From e919d28f1c161babc61874a4314c3ea930bcec63 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Sat, 24 Aug 2024 19:28:26 -0700 Subject: [PATCH 1/2] Add support for custom search model-specific thresholds --- ...delconfig_bi_encoder_confidence_threshold.py | 17 +++++++++++++++++ src/khoj/database/models/__init__.py | 2 ++ src/khoj/routers/api.py | 1 - src/khoj/routers/api_chat.py | 4 ++-- src/khoj/search_type/text_search.py | 9 +++++++-- 5 files changed, 28 insertions(+), 5 deletions(-) create mode 100644 src/khoj/database/migrations/0059_searchmodelconfig_bi_encoder_confidence_threshold.py diff --git a/src/khoj/database/migrations/0059_searchmodelconfig_bi_encoder_confidence_threshold.py b/src/khoj/database/migrations/0059_searchmodelconfig_bi_encoder_confidence_threshold.py new file mode 100644 index 00000000..24ea656b --- /dev/null +++ b/src/khoj/database/migrations/0059_searchmodelconfig_bi_encoder_confidence_threshold.py @@ -0,0 +1,17 @@ +# Generated by Django 5.0.7 on 2024-08-24 18:19 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0058_alter_chatmodeloptions_chat_model"), + ] + + operations = [ + migrations.AddField( + model_name="searchmodelconfig", + name="bi_encoder_confidence_threshold", + field=models.FloatField(default=0.18), + ), + ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 2468ffc9..99e5191a 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -270,6 +270,8 @@ class SearchModelConfig(BaseModel): cross_encoder_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True) # Inference server API Key to use for embeddings inference. Cross-encoder model should be hosted on this server cross_encoder_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True) + # The confidence threshold of the bi_encoder model to consider the embeddings as relevant + bi_encoder_confidence_threshold = models.FloatField(default=0.18) class TextToImageModelConfig(BaseModel): diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 953449d3..2baf6838 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -117,7 +117,6 @@ async def execute_search( # initialize variables user_query = q.strip() results_count = n or 5 - max_distance = max_distance or math.inf search_futures: List[concurrent.futures.Future] = [] # return cached results, if available diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index e1fa4d5c..39c1e48b 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -524,7 +524,7 @@ async def chat( common: CommonQueryParams, q: str, n: int = 7, - d: float = 0.18, + d: float = None, stream: Optional[bool] = False, title: Optional[str] = None, conversation_id: Optional[int] = None, @@ -764,7 +764,7 @@ async def chat( meta_log, q, (n or 7), - (d or 0.18), + d, conversation_id, conversation_commands, location, diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index 93a2b724..569f0b50 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -100,18 +100,23 @@ async def query( raw_query: str, type: SearchType = SearchType.All, question_embedding: Union[torch.Tensor, None] = None, - max_distance: float = math.inf, + max_distance: float = None, ) -> Tuple[List[dict], List[Entry]]: "Search for entries that answer the query" file_type = search_type_to_embeddings_type[type.value] query = raw_query + search_model = await sync_to_async(get_user_search_model_or_default)(user) + if not max_distance: + if search_model.bi_encoder_confidence_threshold: + max_distance = search_model.bi_encoder_confidence_threshold + else: + max_distance = math.inf # Encode the query using the bi-encoder if question_embedding is None: with timer("Query Encode Time", logger, state.device): - search_model = await sync_to_async(get_user_search_model_or_default)(user) question_embedding = state.embeddings_model[search_model.name].embed_query(query) # Find relevant entries for the query From 4b77325f637ff7f95ddee27f60910bf8783e20f0 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Sat, 24 Aug 2024 19:57:49 -0700 Subject: [PATCH 2/2] Default to infinite distance when using the search API --- src/khoj/routers/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 2baf6838..5fc79bcd 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -83,7 +83,7 @@ async def search( n=n, t=t, r=r, - max_distance=max_distance, + max_distance=max_distance or math.inf, dedupe=dedupe, )