diff --git a/src/khoj/database/migrations/0059_searchmodelconfig_bi_encoder_confidence_threshold.py b/src/khoj/database/migrations/0059_searchmodelconfig_bi_encoder_confidence_threshold.py new file mode 100644 index 00000000..24ea656b --- /dev/null +++ b/src/khoj/database/migrations/0059_searchmodelconfig_bi_encoder_confidence_threshold.py @@ -0,0 +1,17 @@ +# Generated by Django 5.0.7 on 2024-08-24 18:19 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0058_alter_chatmodeloptions_chat_model"), + ] + + operations = [ + migrations.AddField( + model_name="searchmodelconfig", + name="bi_encoder_confidence_threshold", + field=models.FloatField(default=0.18), + ), + ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 2468ffc9..99e5191a 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -270,6 +270,8 @@ class SearchModelConfig(BaseModel): cross_encoder_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True) # Inference server API Key to use for embeddings inference. Cross-encoder model should be hosted on this server cross_encoder_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True) + # The confidence threshold of the bi_encoder model to consider the embeddings as relevant + bi_encoder_confidence_threshold = models.FloatField(default=0.18) class TextToImageModelConfig(BaseModel): diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 953449d3..5fc79bcd 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -83,7 +83,7 @@ async def search( n=n, t=t, r=r, - max_distance=max_distance, + max_distance=max_distance or math.inf, dedupe=dedupe, ) @@ -117,7 +117,6 @@ async def execute_search( # initialize variables user_query = q.strip() results_count = n or 5 - max_distance = max_distance or math.inf search_futures: List[concurrent.futures.Future] = [] # return cached results, if available diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index e1fa4d5c..39c1e48b 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -524,7 +524,7 @@ async def chat( common: CommonQueryParams, q: str, n: int = 7, - d: float = 0.18, + d: float = None, stream: Optional[bool] = False, title: Optional[str] = None, conversation_id: Optional[int] = None, @@ -764,7 +764,7 @@ async def chat( meta_log, q, (n or 7), - (d or 0.18), + d, conversation_id, conversation_commands, location, diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index 93a2b724..569f0b50 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -100,18 +100,23 @@ async def query( raw_query: str, type: SearchType = SearchType.All, question_embedding: Union[torch.Tensor, None] = None, - max_distance: float = math.inf, + max_distance: float = None, ) -> Tuple[List[dict], List[Entry]]: "Search for entries that answer the query" file_type = search_type_to_embeddings_type[type.value] query = raw_query + search_model = await sync_to_async(get_user_search_model_or_default)(user) + if not max_distance: + if search_model.bi_encoder_confidence_threshold: + max_distance = search_model.bi_encoder_confidence_threshold + else: + max_distance = math.inf # Encode the query using the bi-encoder if question_embedding is None: with timer("Query Encode Time", logger, state.device): - search_model = await sync_to_async(get_user_search_model_or_default)(user) question_embedding = state.embeddings_model[search_model.name].embed_query(query) # Find relevant entries for the query