From 285d17af2add96208e2aa39c6053bd2e301736f1 Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Tue, 6 Jun 2023 19:28:54 +0530
Subject: [PATCH] Search in parallel across all enabled content types requested via API

- Update API to return content from all enabled content types when
  type is not set to specific type in HTTP request param
- To do this efficiently run the search queries in parallel threads
---
Review notes (below the `---` marker, so not part of the commit message):
- The plugin branch appended to the undefined name `search_future`; it now
  appends to `search_futures` like every other branch.
- Image hits were read with `.result()` on the `search_futures` dict rather
  than on the future being iterated; it now uses `search_future.result()`.
- The plugin lookup dereferenced `t.value` although the guard admits
  `t is None`; the lookup key is now None in that case, so the expression
  falls through to the first configured plugin as the inline comment says.
- All fixes are same-line replacements: hunk line counts and the diffstat
  are unchanged. The post-image blob id on the `index` line was not
  recomputed.
- Line breaks and indentation were reconstructed from a whitespace-collapsed
  copy of this patch; verify against the target file before applying.

 src/khoj/routers/api.py | 211 ++++++++++++++++++++++------------------
 1 file changed, 114 insertions(+), 97 deletions(-)

diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index f7658caa..93fa0fda 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -1,4 +1,6 @@
 # Standard Packages
+from collections import defaultdict
+import concurrent.futures
 import math
 import yaml
 import logging
@@ -121,6 +123,7 @@ def search(
     user_query = q.strip()
     results_count = n
     score_threshold = score_threshold if score_threshold is not None else -math.inf
+    search_futures = defaultdict(list)
 
     # return cached results, if available
     query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}"
@@ -128,105 +131,119 @@ def search(
         logger.debug(f"Return response from query cache")
         return state.query_cache[query_cache_key]
 
-    if (t == SearchType.Org or t == None) and state.model.org_search:
-        # query org-mode notes
+    with concurrent.futures.ThreadPoolExecutor() as executor:
+        if (t == SearchType.Org or t == None) and state.model.org_search:
+            # query org-mode notes
+            search_futures[t] += [
+                executor.submit(
+                    text_search.query,
+                    user_query,
+                    state.model.org_search,
+                    rank_results=r,
+                    score_threshold=score_threshold,
+                    dedupe=dedupe,
+                )
+            ]
+
+        if (t == SearchType.Markdown or t == None) and state.model.markdown_search:
+            # query markdown notes
+            search_futures[t] += [
+                executor.submit(
+                    text_search.query,
+                    user_query,
+                    state.model.markdown_search,
+                    rank_results=r,
+                    score_threshold=score_threshold,
+                    dedupe=dedupe,
+                )
+            ]
+
+        if (t == SearchType.Pdf or t == None) and state.model.pdf_search:
+            # query pdf files
+            search_futures[t] += [
+                executor.submit(
+                    text_search.query,
+                    user_query,
+                    state.model.pdf_search,
+                    rank_results=r,
+                    score_threshold=score_threshold,
+                    dedupe=dedupe,
+                )
+            ]
+
+        if (t == SearchType.Ledger or t == None) and state.model.ledger_search:
+            # query transactions
+            search_futures[t] += [
+                executor.submit(
+                    text_search.query,
+                    user_query,
+                    state.model.ledger_search,
+                    rank_results=r,
+                    score_threshold=score_threshold,
+                    dedupe=dedupe,
+                )
+            ]
+
+        if (t == SearchType.Music or t == None) and state.model.music_search:
+            # query music library
+            search_futures[t] += [
+                executor.submit(
+                    text_search.query,
+                    user_query,
+                    state.model.music_search,
+                    rank_results=r,
+                    score_threshold=score_threshold,
+                    dedupe=dedupe,
+                )
+            ]
+
+        if (t == SearchType.Image) and state.model.image_search:
+            # query images
+            search_futures[t] += [
+                executor.submit(
+                    image_search.query,
+                    user_query,
+                    results_count,
+                    state.model.image_search,
+                    score_threshold=score_threshold,
+                )
+            ]
+
+        if (t is None or t in SearchType) and state.model.plugin_search:
+            # query specified plugin type
+            search_futures[t] += [
+                executor.submit(
+                    text_search.query,
+                    user_query,
+                    # Get plugin search model for specified search type, or the first one if none specified
+                    state.model.plugin_search.get(t.value if t is not None else None) or next(iter(state.model.plugin_search.values())),
+                    rank_results=r,
+                    score_threshold=score_threshold,
+                    dedupe=dedupe,
+                )
+            ]
+
+        # Query across each requested content types in parallel
         with timer("Query took", logger):
-            hits, entries = text_search.query(
-                user_query, state.model.org_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
-            )
+            for search_future in search_futures[t]:
+                if t == SearchType.Image:
+                    hits = search_future.result()
+                    output_directory = constants.web_directory / "images"
+                    # Collate results
+                    results += image_search.collate_results(
+                        hits,
+                        image_names=state.model.image_search.image_names,
+                        output_directory=output_directory,
+                        image_files_url="/static/images",
+                        count=results_count,
+                    )
+                else:
+                    hits, entries = search_future.result()
+                    # Collate results
+                    results += text_search.collate_results(hits, entries, results_count)
 
-        # collate and return results
-        with timer("Collating results took", logger):
-            results = text_search.collate_results(hits, entries, results_count)
-
-    elif (t == SearchType.Markdown or t == None) and state.model.markdown_search:
-        # query markdown files
-        with timer("Query took", logger):
-            hits, entries = text_search.query(
-                user_query, state.model.markdown_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
-            )
-
-        # collate and return results
-        with timer("Collating results took", logger):
-            results = text_search.collate_results(hits, entries, results_count)
-
-    elif (t == SearchType.Pdf or t == None) and state.model.pdf_search:
-        # query pdf files
-        with timer("Query took", logger):
-            hits, entries = text_search.query(
-                user_query, state.model.pdf_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
-            )
-
-        # collate and return results
-        with timer("Collating results took", logger):
-            results = text_search.collate_results(hits, entries, results_count)
-
-    elif (t == SearchType.Github or t == None) and state.model.github_search:
-        # query github embeddings
-        with timer("Query took", logger):
-            hits, entries = text_search.query(
-                user_query, state.model.github_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
-            )
-
-        # collate and return results
-        with timer("Collating results took", logger):
-            results = text_search.collate_results(hits, entries, results_count)
-
-    elif (t == SearchType.Ledger or t == None) and state.model.ledger_search:
-        # query transactions
-        with timer("Query took", logger):
-            hits, entries = text_search.query(
-                user_query, state.model.ledger_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
-            )
-
-        # collate and return results
-        with timer("Collating results took", logger):
-            results = text_search.collate_results(hits, entries, results_count)
-
-    elif (t == SearchType.Music or t == None) and state.model.music_search:
-        # query music library
-        with timer("Query took", logger):
-            hits, entries = text_search.query(
-                user_query, state.model.music_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
-            )
-
-        # collate and return results
-        with timer("Collating results took", logger):
-            results = text_search.collate_results(hits, entries, results_count)
-
-    elif (t == SearchType.Image or t == None) and state.model.image_search:
-        # query images
-        with timer("Query took", logger):
-            hits = image_search.query(
-                user_query, results_count, state.model.image_search, score_threshold=score_threshold
-            )
-            output_directory = constants.web_directory / "images"
-
-        # collate and return results
-        with timer("Collating results took", logger):
-            results = image_search.collate_results(
-                hits,
-                image_names=state.model.image_search.image_names,
-                output_directory=output_directory,
-                image_files_url="/static/images",
-                count=results_count,
-            )
-
-    elif (t in SearchType or t == None) and state.model.plugin_search:
-        # query specified plugin type
-        with timer("Query took", logger):
-            hits, entries = text_search.query(
-                user_query,
-                # Get plugin search model for specified search type, or the first one if none specified
-                state.model.plugin_search.get(t.value) or next(iter(state.model.plugin_search.values())),
-                rank_results=r,
-                score_threshold=score_threshold,
-                dedupe=dedupe,
-            )
-
-        # collate and return results
-        with timer("Collating results took", logger):
-            results = text_search.collate_results(hits, entries, results_count)
+    # Sort results across all content types
+    results.sort(key=lambda x: float(x.score), reverse=True)
 
     # Cache results
     state.query_cache[query_cache_key] = results