Unify types used in GitHub and other text-based configs. Fix typing issues.

This commit is contained in:
Saba
2023-06-13 15:52:36 -07:00
parent a6cd96a6a9
commit 08d79f5ba4
6 changed files with 19 additions and 12 deletions

View File

@@ -49,6 +49,7 @@ def configure_server(args, required=False):
# Initialize the search type and model from Config # Initialize the search type and model from Config
state.search_index_lock.acquire() state.search_index_lock.acquire()
state.SearchType = configure_search_types(state.config) state.SearchType = configure_search_types(state.config)
state.model = SearchModels()
state.model = configure_search(state.model, state.config, args.regenerate) state.model = configure_search(state.model, state.config, args.regenerate)
state.search_index_lock.release() state.search_index_lock.release()

View File

@@ -163,7 +163,9 @@ class MainWindow(QtWidgets.QMainWindow):
processor_type_layout = QtWidgets.QVBoxLayout(processor_type_settings) processor_type_layout = QtWidgets.QVBoxLayout(processor_type_settings)
enable_conversation = ProcessorCheckBox(f"Conversation", processor_type) enable_conversation = ProcessorCheckBox(f"Conversation", processor_type)
# Add file browser to set input files for given processor type # Add file browser to set input files for given processor type
input_field = LabelledTextField("OpenAI API Key", processor_type, current_openai_api_key) input_field = LabelledTextField(
"OpenAI API Key", processor_type=processor_type, default_value=current_openai_api_key
)
# Set enabled/disabled based on checkbox state # Set enabled/disabled based on checkbox state
enable_conversation.setChecked(current_openai_api_key is not None) enable_conversation.setChecked(current_openai_api_key is not None)

View File

@@ -11,9 +11,9 @@ from khoj.utils import state
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class GithubToJsonl: class GithubToJsonl(TextToJsonl):
def __init__(self, config: GithubContentConfig): def __init__(self, config: GithubContentConfig):
self.config = config super().__init__(config)
download_loader("GithubRepositoryReader") download_loader("GithubRepositoryReader")
def process(self, previous_entries=None): def process(self, previous_entries=None):

View File

@@ -6,14 +6,14 @@ from typing import Callable, List, Tuple
from khoj.utils.helpers import timer from khoj.utils.helpers import timer
# Internal Packages # Internal Packages
from khoj.utils.rawconfig import Entry, TextContentConfig from khoj.utils.rawconfig import Entry, TextConfigBase
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TextToJsonl(ABC): class TextToJsonl(ABC):
def __init__(self, config: TextContentConfig): def __init__(self, config: TextConfigBase):
self.config = config self.config = config
@abstractmethod @abstractmethod

View File

@@ -43,7 +43,10 @@ def get_config_types():
return [ return [
search_type.value search_type.value
for search_type in SearchType for search_type in SearchType
if search_type.value in configured_content_types if (
search_type.value in configured_content_types
and getattr(state.model, f"{search_type.value}_search") is not None
)
or ("plugins" in configured_content_types and search_type.name in configured_content_types["plugins"]) or ("plugins" in configured_content_types and search_type.name in configured_content_types["plugins"])
] ]

View File

@@ -16,11 +16,14 @@ class ConfigBase(BaseModel):
allow_population_by_field_name = True allow_population_by_field_name = True
class TextContentConfig(ConfigBase): class TextConfigBase(ConfigBase):
input_files: Optional[List[Path]]
input_filter: Optional[List[str]]
compressed_jsonl: Path compressed_jsonl: Path
embeddings_file: Path embeddings_file: Path
class TextContentConfig(TextConfigBase):
input_files: Optional[List[Path]]
input_filter: Optional[List[str]]
index_heading_entries: Optional[bool] = False index_heading_entries: Optional[bool] = False
@validator("input_filter") @validator("input_filter")
@@ -32,13 +35,11 @@ class TextContentConfig(ConfigBase):
return input_filter return input_filter
class GithubContentConfig(ConfigBase): class GithubContentConfig(TextConfigBase):
pat_token: str pat_token: str
repo_name: str repo_name: str
repo_owner: str repo_owner: str
repo_branch: Optional[str] = "master" repo_branch: Optional[str] = "master"
compressed_jsonl: Path
embeddings_file: Path
class ImageContentConfig(ConfigBase): class ImageContentConfig(ConfigBase):