Configure and use a dynamically instantiated SearchType enum on app start

The SearchType enum is now dynamically populated with the core search
types and the configured plugin search types

Use the new dynamic SearchType enum from state.py across the codebase
This commit is contained in:
Debanjum Singh Solanky
2023-02-24 01:25:20 -06:00
parent ab0d3a08e2
commit 47b58a2a4d
4 changed files with 24 additions and 11 deletions

View File

@@ -2,6 +2,7 @@
import sys import sys
import logging import logging
import json import json
from enum import Enum
# External Packages # External Packages
import schedule import schedule
@@ -14,7 +15,7 @@ from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.search_type import image_search, text_search from khoj.search_type import image_search, text_search
from khoj.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel from khoj.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel
from khoj.utils import state from khoj.utils import state
from khoj.utils.helpers import LRU, resolve_absolute_path from khoj.utils.helpers import LRU, resolve_absolute_path, merge_dicts
from khoj.utils.rawconfig import FullConfig, ProcessorConfig from khoj.utils.rawconfig import FullConfig, ProcessorConfig
from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.word_filter import WordFilter from khoj.search_filter.word_filter import WordFilter
@@ -40,8 +41,9 @@ def configure_server(args, required=False):
# Initialize Processor from Config # Initialize Processor from Config
state.processor_config = configure_processor(args.config.processor) state.processor_config = configure_processor(args.config.processor)
# Initialize the search model from Config # Initialize the search type and model from Config
state.search_index_lock.acquire() state.search_index_lock.acquire()
state.SearchType = configure_search_types(state.config)
state.model = configure_search(state.model, state.config, args.regenerate) state.model = configure_search(state.model, state.config, args.regenerate)
state.search_index_lock.release() state.search_index_lock.release()
@@ -54,9 +56,19 @@ def update_search_index():
logger.info("Search Index updated via Scheduler") logger.info("Search Index updated via Scheduler")
def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: SearchType = None): def configure_search_types(config: FullConfig):
# Extract core search types
core_search_types = {e.name: e.value for e in SearchType}
# Extract configured plugin search types
plugin_search_types = {plugin_type: plugin_type for plugin_type in config.content_type.plugins.keys()}
# Dynamically generate search type enum by merging core search types with configured plugin search types
return Enum("SearchType", merge_dicts(core_search_types, plugin_search_types))
def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: state.SearchType = None):
# Initialize Org Notes Search # Initialize Org Notes Search
if (t == SearchType.Org or t == None) and config.content_type.org: if (t == state.SearchType.Org or t == None) and config.content_type.org:
# Extract Entries, Generate Notes Embeddings # Extract Entries, Generate Notes Embeddings
model.orgmode_search = text_search.setup( model.orgmode_search = text_search.setup(
OrgToJsonl, OrgToJsonl,
@@ -67,7 +79,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
) )
# Initialize Org Music Search # Initialize Org Music Search
if (t == SearchType.Music or t == None) and config.content_type.music: if (t == state.SearchType.Music or t == None) and config.content_type.music:
# Extract Entries, Generate Music Embeddings # Extract Entries, Generate Music Embeddings
model.music_search = text_search.setup( model.music_search = text_search.setup(
OrgToJsonl, OrgToJsonl,
@@ -78,7 +90,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
) )
# Initialize Markdown Search # Initialize Markdown Search
if (t == SearchType.Markdown or t == None) and config.content_type.markdown: if (t == state.SearchType.Markdown or t == None) and config.content_type.markdown:
# Extract Entries, Generate Markdown Embeddings # Extract Entries, Generate Markdown Embeddings
model.markdown_search = text_search.setup( model.markdown_search = text_search.setup(
MarkdownToJsonl, MarkdownToJsonl,
@@ -89,7 +101,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
) )
# Initialize Ledger Search # Initialize Ledger Search
if (t == SearchType.Ledger or t == None) and config.content_type.ledger: if (t == state.SearchType.Ledger or t == None) and config.content_type.ledger:
# Extract Entries, Generate Ledger Embeddings # Extract Entries, Generate Ledger Embeddings
model.ledger_search = text_search.setup( model.ledger_search = text_search.setup(
BeancountToJsonl, BeancountToJsonl,
@@ -100,7 +112,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
) )
# Initialize Image Search # Initialize Image Search
if (t == SearchType.Image or t == None) and config.content_type.image: if (t == state.SearchType.Image or t == None) and config.content_type.image:
# Extract Entries, Generate Image Embeddings # Extract Entries, Generate Image Embeddings
model.image_search = image_search.setup( model.image_search = image_search.setup(
config.content_type.image, search_config=config.search_type.image, regenerate=regenerate config.content_type.image, search_config=config.search_type.image, regenerate=regenerate

View File

@@ -12,10 +12,9 @@ from khoj.configure import configure_processor, configure_search
from khoj.search_type import image_search, text_search from khoj.search_type import image_search, text_search
from khoj.utils.helpers import timer from khoj.utils.helpers import timer
from khoj.utils.rawconfig import FullConfig, SearchResponse from khoj.utils.rawconfig import FullConfig, SearchResponse
from khoj.utils.config import SearchType from khoj.utils.state import SearchType
from khoj.utils import state, constants from khoj.utils import state, constants
# Initialize Router # Initialize Router
api = APIRouter() api = APIRouter()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -17,7 +17,7 @@ from khoj.processor.conversation.gpt import (
understand, understand,
summarize, summarize,
) )
from khoj.utils.config import SearchType from khoj.utils.state import SearchType
from khoj.utils.helpers import get_from_dict, resolve_absolute_path from khoj.utils.helpers import get_from_dict, resolve_absolute_path
from khoj.utils import state from khoj.utils import state

View File

@@ -8,6 +8,7 @@ import torch
from pathlib import Path from pathlib import Path
# Internal Packages # Internal Packages
from khoj.utils import config as utils_config
from khoj.utils.config import SearchModels, ProcessorConfigModel from khoj.utils.config import SearchModels, ProcessorConfigModel
from khoj.utils.helpers import LRU from khoj.utils.helpers import LRU
from khoj.utils.rawconfig import FullConfig from khoj.utils.rawconfig import FullConfig
@@ -23,6 +24,7 @@ port: int = None
cli_args: List[str] = None cli_args: List[str] = None
query_cache = LRU() query_cache = LRU()
search_index_lock = threading.Lock() search_index_lock = threading.Lock()
SearchType = utils_config.SearchType
if torch.cuda.is_available(): if torch.cuda.is_available():
# Use CUDA GPU # Use CUDA GPU