From d73042426d70a7832d5f3f8e891e38c9dfe94f9b Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sun, 5 Mar 2023 15:43:27 -0600 Subject: [PATCH] Support filtering for results above threshold score in search API --- src/khoj/routers/api.py | 33 ++++++++++++++++++++++------ src/khoj/search_type/image_search.py | 6 ++++- src/khoj/search_type/text_search.py | 10 +++++++-- 3 files changed, 39 insertions(+), 10 deletions(-) diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 2d5329ea..637136a4 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -1,4 +1,5 @@ # Standard Packages +import math import yaml import logging from typing import List, Optional @@ -53,7 +54,13 @@ async def set_config_data(updated_config: FullConfig): @api.get("/search", response_model=List[SearchResponse]) -def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False): +def search( + q: str, + n: Optional[int] = 5, + t: Optional[SearchType] = None, + r: Optional[bool] = False, + score_threshold: Optional[float | None] = None, +): results: List[SearchResponse] = [] if q is None or q == "": logger.warn(f"No query param (q) passed in API call to initiate search") @@ -62,9 +69,10 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti # initialize variables user_query = q.strip() results_count = n + score_threshold = score_threshold if score_threshold is not None else -math.inf # return cached results, if available - query_cache_key = f"{user_query}-{n}-{t}-{r}" + query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}" if query_cache_key in state.query_cache: logger.debug(f"Return response from query cache") return state.query_cache[query_cache_key] @@ -72,7 +80,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti if (t == SearchType.Org or t == None) and state.model.orgmode_search: # query org-mode notes with timer("Query took", logger): - hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r) + hits, entries = text_search.query( + user_query, state.model.orgmode_search, rank_results=r, score_threshold=score_threshold + ) # collate and return results with timer("Collating results took", logger): @@ -81,7 +91,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti elif (t == SearchType.Markdown or t == None) and state.model.markdown_search: # query markdown files with timer("Query took", logger): - hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r) + hits, entries = text_search.query( + user_query, state.model.markdown_search, rank_results=r, score_threshold=score_threshold + ) # collate and return results with timer("Collating results took", logger): @@ -90,7 +102,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti elif (t == SearchType.Ledger or t == None) and state.model.ledger_search: # query transactions with timer("Query took", logger): - hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r) + hits, entries = text_search.query( + user_query, state.model.ledger_search, rank_results=r, score_threshold=score_threshold + ) # collate and return results with timer("Collating results took", logger): @@ -99,7 +113,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti elif (t == SearchType.Music or t == None) and state.model.music_search: # query music library with timer("Query took", logger): - hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r) + hits, entries = text_search.query( + user_query, state.model.music_search, rank_results=r, score_threshold=score_threshold + ) # collate and return results with timer("Collating results took", logger): @@ -108,7 +124,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti elif (t == SearchType.Image or t == None) and state.model.image_search: # query images with timer("Query took", logger): - hits = image_search.query(user_query, results_count, state.model.image_search) + hits = image_search.query( + user_query, results_count, state.model.image_search, score_threshold=score_threshold + ) output_directory = constants.web_directory / "images" # collate and return results @@ -129,6 +147,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti # Get plugin search model for specified search type, or the first one if none specified state.model.plugin_search.get(t.value) or next(iter(state.model.plugin_search.values())), rank_results=r, + score_threshold=score_threshold, ) # collate and return results diff --git a/src/khoj/search_type/image_search.py b/src/khoj/search_type/image_search.py index 024dc79d..092353c7 100644 --- a/src/khoj/search_type/image_search.py +++ b/src/khoj/search_type/image_search.py @@ -1,5 +1,6 @@ # Standard Packages import glob +import math import pathlib import copy import shutil @@ -142,7 +143,7 @@ def extract_metadata(image_name): return image_processed_metadata -def query(raw_query, count, model: ImageSearchModel): +def query(raw_query, count, model: ImageSearchModel, score_threshold: float = -math.inf): # Set query to image content if query is of form file:/path/to/file.png if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file(): query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True) @@ -198,6 +199,9 @@ def query(raw_query, count, model: ImageSearchModel): for corpus_id, scores in image_hits.items() ] + # Filter results by score threshold + hits = [hit for hit in hits if hit["image_score"] >= score_threshold] + # Sort the images based on their combined metadata, image scores return sorted(hits, key=lambda hit: hit["score"], reverse=True) diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index af805c42..5bc430c8 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -1,5 +1,6 @@ # Standard Packages import logging +import math from pathlib import Path from typing import List, Tuple, Type @@ -99,7 +100,9 @@ def compute_embeddings( return corpus_embeddings -def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) -> Tuple[List[dict], List[Entry]]: +def query( + raw_query: str, model: TextSearchModel, rank_results: bool = False, score_threshold: float = -math.inf +) -> Tuple[List[dict], List[Entry]]: "Search for entries that answer the query" query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings @@ -129,6 +132,9 @@ def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) -> if rank_results: hits = cross_encoder_score(model.cross_encoder, query, entries, hits) + # Filter results by score threshold + hits = [hit for hit in hits if hit.get("cross-score", hit.get("score")) >= score_threshold] + # Order results by cross-encoder score followed by bi-encoder score hits = sort_results(rank_results, hits) @@ -143,7 +149,7 @@ def collate_results(hits, entries: List[Entry], count=5) -> List[SearchResponse] SearchResponse.parse_obj( { "entry": entries[hit["corpus_id"]].raw, - "score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}", + "score": f"{hit.get('cross-score', 'score')}:.3f", "additional": {"file": entries[hit["corpus_id"]].file, "compiled": entries[hit["corpus_id"]].compiled}, } )