From 08d79f5ba4e57c3d32c76dbb896559508639dba0 Mon Sep 17 00:00:00 2001 From: Saba Date: Tue, 13 Jun 2023 15:52:36 -0700 Subject: [PATCH] Unify types used in Github and other text-based configs. Fix typing issues --- src/khoj/configure.py | 1 + src/khoj/interface/desktop/main_window.py | 4 +++- src/khoj/processor/github/github_to_jsonl.py | 4 ++-- src/khoj/processor/text_to_jsonl.py | 4 ++-- src/khoj/routers/api.py | 5 ++++- src/khoj/utils/rawconfig.py | 13 +++++++------ 6 files changed, 19 insertions(+), 12 deletions(-) diff --git a/src/khoj/configure.py b/src/khoj/configure.py index bf2de2e2..f9735fea 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -49,6 +49,7 @@ def configure_server(args, required=False): # Initialize the search type and model from Config state.search_index_lock.acquire() state.SearchType = configure_search_types(state.config) + state.model = SearchModels() state.model = configure_search(state.model, state.config, args.regenerate) state.search_index_lock.release() diff --git a/src/khoj/interface/desktop/main_window.py b/src/khoj/interface/desktop/main_window.py index 6fc061bd..5a3df3ec 100644 --- a/src/khoj/interface/desktop/main_window.py +++ b/src/khoj/interface/desktop/main_window.py @@ -163,7 +163,9 @@ class MainWindow(QtWidgets.QMainWindow): processor_type_layout = QtWidgets.QVBoxLayout(processor_type_settings) enable_conversation = ProcessorCheckBox(f"Conversation", processor_type) # Add file browser to set input files for given processor type - input_field = LabelledTextField("OpenAI API Key", processor_type, current_openai_api_key) + input_field = LabelledTextField( + "OpenAI API Key", processor_type=processor_type, default_value=current_openai_api_key + ) # Set enabled/disabled based on checkbox state enable_conversation.setChecked(current_openai_api_key is not None) diff --git a/src/khoj/processor/github/github_to_jsonl.py b/src/khoj/processor/github/github_to_jsonl.py index 6886d9a9..b989c12f 100644 --- a/src/khoj/processor/github/github_to_jsonl.py +++ b/src/khoj/processor/github/github_to_jsonl.py @@ -11,9 +11,9 @@ from khoj.utils import state logger = logging.getLogger(__name__) -class GithubToJsonl: +class GithubToJsonl(TextToJsonl): def __init__(self, config: GithubContentConfig): - self.config = config + super().__init__(config) download_loader("GithubRepositoryReader") def process(self, previous_entries=None): diff --git a/src/khoj/processor/text_to_jsonl.py b/src/khoj/processor/text_to_jsonl.py index d85d6998..f7bca376 100644 --- a/src/khoj/processor/text_to_jsonl.py +++ b/src/khoj/processor/text_to_jsonl.py @@ -6,14 +6,14 @@ from typing import Callable, List, Tuple from khoj.utils.helpers import timer # Internal Packages -from khoj.utils.rawconfig import Entry, TextContentConfig +from khoj.utils.rawconfig import Entry, TextConfigBase logger = logging.getLogger(__name__) class TextToJsonl(ABC): - def __init__(self, config: TextContentConfig): + def __init__(self, config: TextConfigBase): self.config = config @abstractmethod diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 1f98496c..dec496d6 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -43,7 +43,10 @@ def get_config_types(): return [ search_type.value for search_type in SearchType - if search_type.value in configured_content_types + if ( + search_type.value in configured_content_types + and getattr(state.model, f"{search_type.value}_search") is not None + ) or ("plugins" in configured_content_types and search_type.name in configured_content_types["plugins"]) ] diff --git a/src/khoj/utils/rawconfig.py b/src/khoj/utils/rawconfig.py index 371918b6..21ff93d5 100644 --- a/src/khoj/utils/rawconfig.py +++ b/src/khoj/utils/rawconfig.py @@ -16,11 +16,14 @@ class ConfigBase(BaseModel): allow_population_by_field_name = True -class TextContentConfig(ConfigBase): - input_files: Optional[List[Path]] - input_filter: Optional[List[str]] +class TextConfigBase(ConfigBase): compressed_jsonl: Path embeddings_file: Path + + +class TextContentConfig(TextConfigBase): + input_files: Optional[List[Path]] + input_filter: Optional[List[str]] index_heading_entries: Optional[bool] = False @validator("input_filter") @@ -32,13 +35,11 @@ class TextContentConfig(ConfigBase): return input_filter -class GithubContentConfig(ConfigBase): +class GithubContentConfig(TextConfigBase): pat_token: str repo_name: str repo_owner: str repo_branch: Optional[str] = "master" - compressed_jsonl: Path - embeddings_file: Path class ImageContentConfig(ConfigBase):