diff --git a/src/database/adapters/__init__.py b/src/database/adapters/__init__.py index 28999369..4b9b54ef 100644 --- a/src/database/adapters/__init__.py +++ b/src/database/adapters/__init__.py @@ -1,3 +1,4 @@ +import math from typing import Optional, Type, TypeVar, List from datetime import date, datetime, timedelta import secrets @@ -437,12 +438,19 @@ class EntryAdapters: @staticmethod def search_with_embeddings( - user: KhojUser, embeddings: Tensor, max_results: int = 10, file_type_filter: str = None, raw_query: str = None + user: KhojUser, + embeddings: Tensor, + max_results: int = 10, + file_type_filter: str = None, + raw_query: str = None, + max_distance: float = math.inf, ): relevant_entries = EntryAdapters.apply_filters(user, raw_query, file_type_filter) relevant_entries = relevant_entries.filter(user=user).annotate( distance=CosineDistance("embeddings", embeddings) ) + relevant_entries = relevant_entries.filter(distance__lte=max_distance) + if file_type_filter: relevant_entries = relevant_entries.filter(file_type=file_type_filter) relevant_entries = relevant_entries.order_by("distance") diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 4e050eee..fbdfbd63 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -356,7 +356,7 @@ async def search( n: Optional[int] = 5, t: Optional[SearchType] = SearchType.All, r: Optional[bool] = False, - score_threshold: Optional[Union[float, None]] = None, + max_distance: Optional[Union[float, None]] = None, dedupe: Optional[bool] = True, client: Optional[str] = None, user_agent: Optional[str] = Header(None), @@ -375,12 +375,12 @@ async def search( # initialize variables user_query = q.strip() results_count = n or 5 - score_threshold = score_threshold if score_threshold is not None else -math.inf + max_distance = max_distance if max_distance is not None else math.inf search_futures: List[concurrent.futures.Future] = [] # return cached results, if available if user: - query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}" + query_cache_key = f"{user_query}-{n}-{t}-{r}-{max_distance}-{dedupe}" if query_cache_key in state.query_cache[user.uuid]: logger.debug(f"Return response from query cache") return state.query_cache[user.uuid][query_cache_key] @@ -418,7 +418,7 @@ async def search( t, question_embedding=encoded_asymmetric_query, rank_results=r or False, - score_threshold=score_threshold, + max_distance=max_distance, ) ] @@ -431,7 +431,6 @@ async def search( results_count, state.search_models.image_search, state.content_index.image, - score_threshold=score_threshold, ) ] @@ -454,11 +453,10 @@ async def search( # Collate results results += text_search.collate_results(hits, dedupe=dedupe) - if r: - results = text_search.rerank_and_sort_results(results, query=defiltered_query)[:results_count] - else: # Sort results across all content types and take top results - results = sorted(results, key=lambda x: float(x.score))[:results_count] + results = text_search.rerank_and_sort_results(results, query=defiltered_query, rank_results=r)[ + :results_count + ] # Cache results if user: @@ -583,6 +581,7 @@ async def chat( request: Request, q: str, n: Optional[int] = 5, + d: Optional[float] = 0.15, client: Optional[str] = None, stream: Optional[bool] = False, user_agent: Optional[str] = Header(None), @@ -599,7 +598,7 @@ async def chat( meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions( - request, meta_log, q, (n or 5), conversation_command + request, meta_log, q, (n or 5), (d or math.inf), conversation_command ) if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references): @@ -663,6 +662,7 @@ async def extract_references_and_questions( meta_log: dict, q: str, n: int, + d: float, conversation_type: ConversationCommand = ConversationCommand.Default, ): user = request.user.object if request.user.is_authenticated else None @@ -723,7 +723,7 @@ async def extract_references_and_questions( request=request, n=n_items, r=True, - score_threshold=-5.0, + max_distance=d, dedupe=False, ) ) diff --git a/src/khoj/search_type/image_search.py b/src/khoj/search_type/image_search.py index d7f486af..214118fc 100644 --- a/src/khoj/search_type/image_search.py +++ b/src/khoj/search_type/image_search.py @@ -146,7 +146,7 @@ def extract_metadata(image_name): async def query( - raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = -math.inf + raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = math.inf ): # Set query to image content if query is of form file:/path/to/file.png if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file(): @@ -167,7 +167,8 @@ async def query( # Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings. with timer("Search Time", logger): image_hits = { - result["corpus_id"]: {"image_score": result["score"], "score": result["score"]} + # Map scores to distance metric by multiplying by -1 + result["corpus_id"]: {"image_score": -1 * result["score"], "score": -1 * result["score"]} for result in util.semantic_search(query_embedding, content.image_embeddings, top_k=count)[0] } @@ -204,7 +205,7 @@ async def query( ] # Filter results by score threshold - hits = [hit for hit in hits if hit["image_score"] >= score_threshold] + hits = [hit for hit in hits if hit["image_score"] <= score_threshold] # Sort the images based on their combined metadata, image scores return sorted(hits, key=lambda hit: hit["score"], reverse=True) diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index ba2fc9ec..041c385f 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -105,7 +105,7 @@ async def query( type: SearchType = SearchType.All, question_embedding: Union[torch.Tensor, None] = None, rank_results: bool = False, - score_threshold: float = -math.inf, + max_distance: float = math.inf, ) -> Tuple[List[dict], List[Entry]]: "Search for entries that answer the query" @@ -127,6 +127,7 @@ async def query( max_results=top_k, file_type_filter=file_type, raw_query=raw_query, + max_distance=max_distance, ).all() hits = await sync_to_async(list)(hits) # type: ignore[call-arg] @@ -177,12 +178,16 @@ def deduplicated_search_responses(hits: List[SearchResponse]): ) -def rerank_and_sort_results(hits, query): +def rerank_and_sort_results(hits, query, rank_results): + # If we have more than one result and reranking is enabled + rank_results = rank_results and len(list(hits)) > 1 + # Score all retrieved entries using the cross-encoder - hits = cross_encoder_score(query, hits) + if rank_results: + hits = cross_encoder_score(query, hits) # Sort results by cross-encoder score followed by bi-encoder score - hits = sort_results(rank_results=True, hits=hits) + hits = sort_results(rank_results=rank_results, hits=hits) return hits @@ -217,9 +222,9 @@ def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchRe with timer("Cross-Encoder Predict Time", logger, state.device): cross_scores = state.cross_encoder_model.predict(query, hits) - # Store cross-encoder scores in results dictionary for ranking + # Convert cross-encoder scores to distances and pass in hits for reranking for idx in range(len(cross_scores)): - hits[idx]["cross_score"] = cross_scores[idx] + hits[idx]["cross_score"] = -1 * cross_scores[idx] return hits @@ -227,7 +232,7 @@ def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchRe def sort_results(rank_results: bool, hits: List[dict]) -> List[dict]: """Order results by cross-encoder score followed by bi-encoder score""" with timer("Rank Time", logger, state.device): - hits.sort(key=lambda x: x["score"], reverse=True) # sort by bi-encoder score + hits.sort(key=lambda x: x["score"]) # sort by bi-encoder score if rank_results: - hits.sort(key=lambda x: x["cross_score"], reverse=True) # sort by cross-encoder score + hits.sort(key=lambda x: x["cross_score"]) # sort by cross-encoder score return hits