diff --git a/src/utils/rawconfig.py b/src/utils/rawconfig.py
index 1a514226..4fd4be96 100644
--- a/src/utils/rawconfig.py
+++ b/src/utils/rawconfig.py
@@ -3,7 +3,7 @@ from pathlib import Path
 from typing import List, Optional
 
 # External Packages
-from pydantic import BaseModel
+from pydantic import BaseModel, validator
 
 # Internal Packages
 from src.utils.helpers import to_snake_case_from_dash
@@ -14,17 +14,31 @@ class ConfigBase(BaseModel):
         allow_population_by_field_name = True
 
 class TextContentConfig(ConfigBase):
-    compressed_jsonl: Optional[Path]
-    input_files: Optional[List[str]]
+    input_files: Optional[List[Path]]
     input_filter: Optional[str]
-    embeddings_file: Optional[Path]
+    compressed_jsonl: Path
+    embeddings_file: Path
+
+    @validator('input_filter', always=True)
+    def input_filter_or_files_required(cls, input_filter, values, **kwargs):
+        """Require input_filter or input_files; always=True runs this even when input_filter is omitted."""
+        if input_filter is None and values.get('input_files') is None:
+            raise ValueError("Either input_filter or input_files required in all content-type.<text_search> section of Khoj config file")
+        return input_filter
 
 class ImageContentConfig(ConfigBase):
-    use_xmp_metadata: Optional[bool]
-    batch_size: Optional[int]
     input_directories: Optional[List[Path]]
     input_filter: Optional[str]
-    embeddings_file: Optional[Path]
+    embeddings_file: Path
+    use_xmp_metadata: bool
+    batch_size: int
+
+    @validator('input_filter', always=True)
+    def input_filter_or_directories_required(cls, input_filter, values, **kwargs):
+        """Require input_filter or input_directories; always=True runs this even when input_filter is omitted."""
+        if input_filter is None and values.get('input_directories') is None:
+            raise ValueError("Either input_filter or input_directories required in all content-type.image section of Khoj config file")
+        return input_filter
 
 class ContentConfig(ConfigBase):
     org: Optional[TextContentConfig]
@@ -34,12 +48,12 @@ class ContentConfig(ConfigBase):
     markdown: Optional[TextContentConfig]
 
 class TextSearchConfig(ConfigBase):
-    encoder: Optional[str]
-    cross_encoder: Optional[str]
+    encoder: str
+    cross_encoder: str
     model_directory: Optional[Path]
 
 class ImageSearchConfig(ConfigBase):
-    encoder: Optional[str]
+    encoder: str
     model_directory: Optional[Path]
 
 class SearchConfig(ConfigBase):
@@ -48,8 +62,8 @@ class SearchConfig(ConfigBase):
     image: Optional[ImageSearchConfig]
 
 class ConversationProcessorConfig(ConfigBase):
-    openai_api_key: Optional[str]
-    conversation_logfile: Optional[str]
+    openai_api_key: str
+    conversation_logfile: Path
 
 class ProcessorConfig(ConfigBase):
     conversation: Optional[ConversationProcessorConfig]
diff --git a/tests/test_cli.py b/tests/test_cli.py
index a266ebc0..5976dbdd 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -40,5 +40,5 @@ def test_cli_config_from_file():
     assert actual_args.config_file == Path('tests/data/config.yml')
     assert actual_args.regenerate == True
     assert actual_args.config is not None
-    assert actual_args.config.content_type.org.input_files == ['~/first_from_config.org', '~/second_from_config.org']
+    assert actual_args.config.content_type.org.input_files == [Path('~/first_from_config.org'), Path('~/second_from_config.org')]
     assert actual_args.verbose == 3