diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 266eaed0..822dd278 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -1,5 +1,4 @@ # Standard Packages -from collections import defaultdict import concurrent.futures import math import time @@ -21,6 +20,7 @@ from khoj.search_type import image_search, text_search from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.file_filter import FileFilter from khoj.search_filter.word_filter import WordFilter +from khoj.utils.config import TextSearchModel from khoj.utils.helpers import log_telemetry, timer from khoj.utils.rawconfig import ( ContentConfig, @@ -144,10 +144,14 @@ async def search( ): start_time = time.time() + # Run validation checks results: List[SearchResponse] = [] if q is None or q == "": logger.warn(f"No query param (q) passed in API call to initiate search") return results + if not state.model or not any(state.model.__dict__.values()): + logger.warn(f"No search models loaded. Configure a search model before initiating search") + return results # initialize variables user_query = q.strip() @@ -168,14 +172,20 @@ async def search( encoded_asymmetric_query = None if t == SearchType.All or (t != SearchType.Ledger and t != SearchType.Image): - with timer("Encoding query took", logger=logger): - encoded_asymmetric_query = util.normalize_embeddings( - state.model.org_search.bi_encoder.encode( - [defiltered_query], - convert_to_tensor=True, - device=state.device, + text_search_models: List[TextSearchModel] = [ + model + for model_name, model in state.model.__dict__.items() + if isinstance(model, TextSearchModel) and model_name != "ledger_search" + ] + if text_search_models: + with timer("Encoding query took", logger=logger): + encoded_asymmetric_query = util.normalize_embeddings( + text_search_models[0].bi_encoder.encode( + [defiltered_query], + convert_to_tensor=True, + device=state.device, + ) ) - ) with concurrent.futures.ThreadPoolExecutor() as executor: if (t == SearchType.Org or t == SearchType.All) and state.model.org_search: diff --git a/src/khoj/search_type/image_search.py b/src/khoj/search_type/image_search.py index 092353c7..d6cc33d6 100644 --- a/src/khoj/search_type/image_search.py +++ b/src/khoj/search_type/image_search.py @@ -143,7 +143,7 @@ def extract_metadata(image_name): return image_processed_metadata -def query(raw_query, count, model: ImageSearchModel, score_threshold: float = -math.inf): +async def query(raw_query, count, model: ImageSearchModel, score_threshold: float = -math.inf): # Set query to image content if query is of form file:/path/to/file.png if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file(): query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True)