From 2cd3e799d3692dac9184bf682c72283afa0514ef Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Tue, 20 Jun 2023 22:22:43 -0700 Subject: [PATCH] Improve null and type checks --- src/khoj/configure.py | 23 ++++++++++++++--------- src/khoj/routers/api.py | 22 +++++++++++----------- src/khoj/search_type/text_search.py | 2 +- 3 files changed, 26 insertions(+), 21 deletions(-) diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 3aa39f10..df031dfa 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -3,6 +3,7 @@ import sys import logging import json from enum import Enum +from typing import Optional import requests # External Packages @@ -78,16 +79,20 @@ def configure_search_types(config: FullConfig): core_search_types = {e.name: e.value for e in SearchType} # Extract configured plugin search types plugin_search_types = {} - if config.content_type.plugins: + if config.content_type and config.content_type.plugins: plugin_search_types = {plugin_type: plugin_type for plugin_type in config.content_type.plugins.keys()} # Dynamically generate search type enum by merging core search types with configured plugin search types return Enum("SearchType", merge_dicts(core_search_types, plugin_search_types)) -def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: state.SearchType = None): +def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: Optional[state.SearchType] = None): + if config.content_type is None or config.search_type is None: + logger.error("🚨 Content Type or Search Type not configured.") + return + # Initialize Org Notes Search - if (t == state.SearchType.Org or t == None) and config.content_type.org: + if (t == state.SearchType.Org or t == None) and config.content_type.org and config.search_type.asymmetric: logger.info("🦄 Setting up search for orgmode notes") # Extract Entries, Generate Notes Embeddings model.org_search = text_search.setup( @@ -99,7 +104,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, ) # Initialize Org Music Search - if (t == state.SearchType.Music or t == None) and config.content_type.music: + if (t == state.SearchType.Music or t == None) and config.content_type.music and config.search_type.asymmetric: logger.info("🎺 Setting up search for org-music") # Extract Entries, Generate Music Embeddings model.music_search = text_search.setup( @@ -111,7 +116,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, ) # Initialize Markdown Search - if (t == state.SearchType.Markdown or t == None) and config.content_type.markdown: + if (t == state.SearchType.Markdown or t == None) and config.content_type.markdown and config.search_type.asymmetric: logger.info("💎 Setting up search for markdown notes") # Extract Entries, Generate Markdown Embeddings model.markdown_search = text_search.setup( @@ -123,7 +128,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, ) # Initialize Ledger Search - if (t == state.SearchType.Ledger or t == None) and config.content_type.ledger: + if (t == state.SearchType.Ledger or t == None) and config.content_type.ledger and config.search_type.symmetric: logger.info("💸 Setting up search for ledger") # Extract Entries, Generate Ledger Embeddings model.ledger_search = text_search.setup( @@ -135,7 +140,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, ) # Initialize PDF Search - if (t == state.SearchType.Pdf or t == None) and config.content_type.pdf: + if (t == state.SearchType.Pdf or t == None) and config.content_type.pdf and config.search_type.asymmetric: logger.info("🖨️ Setting up search for pdf") # Extract Entries, Generate PDF Embeddings model.pdf_search = text_search.setup( @@ -147,14 +152,14 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, ) # Initialize Image Search - if (t == state.SearchType.Image or t == None) and config.content_type.image: + if (t == state.SearchType.Image or t == None) and config.content_type.image and config.search_type.image: logger.info("🌄 Setting up search for images") # Extract Entries, Generate Image Embeddings model.image_search = image_search.setup( config.content_type.image, search_config=config.search_type.image, regenerate=regenerate ) - if (t == state.SearchType.Github or t == None) and config.content_type.github: + if (t == state.SearchType.Github or t == None) and config.content_type.github and config.search_type.asymmetric: logger.info("🐙 Setting up search for github") # Extract Entries, Generate Github Embeddings model.github_search = text_search.setup( diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 785b08c0..fc8ff7ce 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -134,7 +134,7 @@ async def search( user_query = q.strip() results_count = n score_threshold = score_threshold if score_threshold is not None else -math.inf - search_futures = defaultdict(list) + search_futures: list[concurrent.futures.Future] = [] # return cached results, if available query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}" @@ -161,7 +161,7 @@ async def search( with concurrent.futures.ThreadPoolExecutor() as executor: if (t == SearchType.Org or t == None) and state.model.org_search: # query org-mode notes - search_futures[t] += [ + search_futures += [ executor.submit( text_search.query, user_query, @@ -175,7 +175,7 @@ async def search( if (t == SearchType.Markdown or t == None) and state.model.markdown_search: # query markdown notes - search_futures[t] += [ + search_futures += [ executor.submit( text_search.query, user_query, @@ -189,7 +189,7 @@ async def search( if (t == SearchType.Pdf or t == None) and state.model.pdf_search: # query pdf files - search_futures[t] += [ + search_futures += [ executor.submit( text_search.query, user_query, @@ -203,7 +203,7 @@ async def search( if (t == SearchType.Ledger) and state.model.ledger_search: # query transactions - search_futures[t] += [ + search_futures += [ executor.submit( text_search.query, user_query, @@ -216,7 +216,7 @@ async def search( if (t == SearchType.Music or t == None) and state.model.music_search: # query music library - search_futures[t] += [ + search_futures += [ executor.submit( text_search.query, user_query, @@ -230,7 +230,7 @@ async def search( if (t == SearchType.Image) and state.model.image_search: # query images - search_futures[t] += [ + search_futures += [ executor.submit( image_search.query, user_query, @@ -242,7 +242,7 @@ async def search( if (t is None or t in SearchType) and state.model.plugin_search: # query specified plugin type - search_futures[t] += [ + search_futures += [ executor.submit( text_search.query, user_query, @@ -257,7 +257,7 @@ async def search( # Query across each requested content types in parallel with timer("Query took", logger): - for search_future in concurrent.futures.as_completed(search_futures[t]): + for search_future in concurrent.futures.as_completed(search_futures): if t == SearchType.Image: hits = await search_future.result() output_directory = constants.web_directory / "images" @@ -288,7 +288,7 @@ async def search( state.previous_query = user_query end_time = time.time() - logger.debug(f"🔍 Search took: {end_time - start_time:.2f} seconds") + logger.debug(f"🔍 Search took: {end_time - start_time:.3f} seconds") return results @@ -297,7 +297,7 @@ async def search( def update(t: Optional[SearchType] = None, force: Optional[bool] = False, client: Optional[str] = None): try: state.search_index_lock.acquire() - state.model = configure_search(state.model, state.config, regenerate=force, t=t) + state.model = configure_search(state.model, state.config, regenerate=force or False, t=t) state.search_index_lock.release() except ValueError as e: logger.error(e) diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index 14e2015f..83f15918 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -181,7 +181,7 @@ def setup( previous_entries = ( extract_entries(config.compressed_jsonl) if config.compressed_jsonl.exists() and not regenerate else None ) - entries_with_indices = text_to_jsonl(config).process(previous_entries) + entries_with_indices = text_to_jsonl(config).process(previous_entries or []) # Extract Updated Entries entries = extract_entries(config.compressed_jsonl)