diff --git a/src/khoj/configure.py b/src/khoj/configure.py
index dfc0fe4f..e2f16884 100644
--- a/src/khoj/configure.py
+++ b/src/khoj/configure.py
@@ -156,7 +156,15 @@ def configure_server(
                     )
                 }
             )
-            state.cross_encoder_model.update({model.name: CrossEncoderModel(model.cross_encoder)})
+            state.cross_encoder_model.update(
+                {
+                    model.name: CrossEncoderModel(
+                        model.cross_encoder,
+                        model.cross_encoder_inference_endpoint,
+                        model.cross_encoder_inference_endpoint_api_key,
+                    )
+                }
+            )
 
         state.SearchType = configure_search_types()
         state.search_models = configure_search(state.search_models, state.config.search_type)
diff --git a/src/khoj/database/migrations/0026_searchmodelconfig_cross_encoder_inference_endpoint_and_more.py b/src/khoj/database/migrations/0026_searchmodelconfig_cross_encoder_inference_endpoint_and_more.py
new file mode 100644
index 00000000..3c2d87ce
--- /dev/null
+++ b/src/khoj/database/migrations/0026_searchmodelconfig_cross_encoder_inference_endpoint_and_more.py
@@ -0,0 +1,22 @@
+# Generated by Django 4.2.7 on 2024-01-17 04:21
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ("database", "0025_searchmodelconfig_embeddings_inference_endpoint_and_more"),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name="searchmodelconfig",
+            name="cross_encoder_inference_endpoint",
+            field=models.CharField(blank=True, default=None, max_length=200, null=True),
+        ),
+        migrations.AddField(
+            model_name="searchmodelconfig",
+            name="cross_encoder_inference_endpoint_api_key",
+            field=models.CharField(blank=True, default=None, max_length=200, null=True),
+        ),
+    ]
diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py
index 2b8887f7..030e7ea8 100644
--- a/src/khoj/database/models/__init__.py
+++ b/src/khoj/database/models/__init__.py
@@ -112,6 +112,8 @@ class SearchModelConfig(BaseModel):
     cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2")
     embeddings_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
     embeddings_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
+    cross_encoder_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
+    cross_encoder_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
 
 
 class TextToImageModelConfig(BaseModel):
diff --git a/src/khoj/processor/content/pdf/pdf_to_entries.py b/src/khoj/processor/content/pdf/pdf_to_entries.py
index efe57a21..3582cbe0 100644
--- a/src/khoj/processor/content/pdf/pdf_to_entries.py
+++ b/src/khoj/processor/content/pdf/pdf_to_entries.py
@@ -4,7 +4,7 @@ import os
 from datetime import datetime
 from typing import List, Tuple
 
-from langchain.document_loaders import PyMuPDFLoader
+from langchain_community.document_loaders import PyMuPDFLoader
 
 from khoj.database.models import Entry as DbEntry
 from khoj.database.models import KhojUser
diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py
index 58b3f583..c0a9929c 100644
--- a/src/khoj/processor/conversation/openai/utils.py
+++ b/src/khoj/processor/conversation/openai/utils.py
@@ -6,7 +6,7 @@ from typing import Any
 import openai
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
-from langchain.chat_models import ChatOpenAI
+from langchain_community.chat_models import ChatOpenAI
 from tenacity import (
     before_sleep_log,
     retry,
diff --git a/src/khoj/processor/embeddings.py b/src/khoj/processor/embeddings.py
index c0e91ce4..cc710a92 100644
--- a/src/khoj/processor/embeddings.py
+++ b/src/khoj/processor/embeddings.py
@@ -68,11 +68,33 @@ class EmbeddingsModel:
 
 
 class CrossEncoderModel:
-    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
+    def __init__(
+        self,
+        model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
+        cross_encoder_inference_endpoint: str = None,
+        cross_encoder_inference_endpoint_api_key: str = None,
+    ):
         self.model_name = model_name
         self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=get_device())
+        self.inference_endpoint = cross_encoder_inference_endpoint
+        self.api_key = cross_encoder_inference_endpoint_api_key
 
     def predict(self, query, hits: List[SearchResponse], key: str = "compiled"):
+        # Score remotely only when both an endpoint and its API key are configured
+        # and the endpoint URL contains "huggingface"; otherwise use the local model.
+        if (
+            self.api_key is not None
+            and self.inference_endpoint is not None
+            and "huggingface" in self.inference_endpoint
+        ):
+            target_url = f"{self.inference_endpoint}"
+            payload = {"inputs": {"query": query, "passages": [hit.additional[key] for hit in hits]}}
+            headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
+            # Bound the request so an unresponsive endpoint cannot hang search forever.
+            response = requests.post(target_url, json=payload, headers=headers, timeout=60)
+            # Surface HTTP errors directly rather than as a confusing KeyError on "scores".
+            response.raise_for_status()
+            return response.json()["scores"]
+
         cross_inp = [[query, hit.additional[key]] for hit in hits]
         cross_scores = self.cross_encoder_model.predict(cross_inp, activation_fct=nn.Sigmoid())
         return cross_scores