mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Initial changes to support multiple search model configurations
- All search models are loaded into memory and stored in a dictionary indexed by name. - Database migrations and a UI for users to select their preferred search model still need to be added. For now, the default search model is always used.
This commit is contained in:
@@ -23,7 +23,7 @@ from starlette.authentication import (
|
||||
from khoj.database.models import KhojUser, Subscription
|
||||
from khoj.database.adapters import (
|
||||
get_all_users,
|
||||
get_or_create_search_model,
|
||||
get_or_create_search_models,
|
||||
aget_user_subscription_state,
|
||||
SubscriptionState,
|
||||
)
|
||||
@@ -140,8 +140,14 @@ def configure_server(
|
||||
|
||||
# Initialize Search Models from Config and initialize content
|
||||
try:
|
||||
state.embeddings_model = EmbeddingsModel(get_or_create_search_model().bi_encoder)
|
||||
state.cross_encoder_model = CrossEncoderModel(get_or_create_search_model().cross_encoder)
|
||||
search_models = get_or_create_search_models()
|
||||
state.embeddings_model = dict()
|
||||
state.cross_encoder_model = dict()
|
||||
|
||||
for model in search_models:
|
||||
state.embeddings_model.update({model.name: EmbeddingsModel(model.bi_encoder)})
|
||||
state.cross_encoder_model.update({model.name: CrossEncoderModel(model.cross_encoder)})
|
||||
|
||||
state.SearchType = configure_search_types()
|
||||
state.search_models = configure_search(state.search_models, state.config.search_type)
|
||||
initialize_content(regenerate, search_type, init, user)
|
||||
|
||||
@@ -249,12 +249,19 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
|
||||
return config
|
||||
|
||||
|
||||
def get_or_create_search_model():
|
||||
search_model = SearchModelConfig.objects.filter().first()
|
||||
if not search_model:
|
||||
search_model = SearchModelConfig.objects.create()
|
||||
def get_default_search_model():
    """Return the search model configuration to use by default.

    Prefers the configuration whose name is "default". When no such
    configuration exists, falls back to the first configuration available.

    Returns:
        A SearchModelConfig instance, or None when no search model
        configuration exists at all.
    """
    # A single lookup replaces the exists() + first() pair: it saves a query
    # and removes the window in which the "default" row could vanish between
    # the two queries, which would have skipped the fallback below.
    default_model = SearchModelConfig.objects.filter(name="default").first()
    if default_model is not None:
        return default_model
    return SearchModelConfig.objects.first()
|
||||
|
||||
return search_model
|
||||
|
||||
def get_or_create_search_models():
    """Return all search model configurations, creating one if none exist.

    When the table is empty, a single SearchModelConfig is created using the
    model's field defaults so callers always receive at least one entry.

    Returns:
        A queryset over every SearchModelConfig.
    """
    # exists() only asks the database whether any row is present, which is
    # cheaper than counting every row just to compare the count with zero.
    if not SearchModelConfig.objects.exists():
        SearchModelConfig.objects.create()

    # Build the queryset after the optional create so it reflects the new row.
    return SearchModelConfig.objects.all()
|
||||
|
||||
|
||||
class ConversationAdapters:
|
||||
|
||||
@@ -12,6 +12,7 @@ from khoj.database.models import (
|
||||
SpeechToTextModelOptions,
|
||||
Subscription,
|
||||
ReflectiveQuestion,
|
||||
UserSearchModelConfig,
|
||||
)
|
||||
|
||||
admin.site.register(KhojUser, UserAdmin)
|
||||
@@ -23,3 +24,4 @@ admin.site.register(OfflineChatProcessorConversationConfig)
|
||||
admin.site.register(SearchModelConfig)
|
||||
admin.site.register(Subscription)
|
||||
admin.site.register(ReflectiveQuestion)
|
||||
admin.site.register(UserSearchModelConfig)
|
||||
|
||||
@@ -145,6 +145,11 @@ class UserConversationConfig(BaseModel):
|
||||
setting = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||
|
||||
|
||||
class UserSearchModelConfig(BaseModel):
    """Links a user to the search model configuration selected for them."""

    # One-to-one: each user has at most one row; deleted along with the user.
    user = models.OneToOneField(to=KhojUser, on_delete=models.CASCADE)
    # Deleted along with the referenced search model configuration.
    setting = models.ForeignKey(to=SearchModelConfig, on_delete=models.CASCADE)
|
||||
|
||||
|
||||
class Conversation(BaseModel):
|
||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||
conversation_log = models.JSONField(default=dict)
|
||||
|
||||
@@ -14,7 +14,7 @@ from khoj.utils.helpers import is_none_or_empty, timer, batcher
|
||||
from khoj.utils.rawconfig import Entry
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
from khoj.database.models import KhojUser, Entry as DbEntry, EntryDates
|
||||
from khoj.database.adapters import EntryAdapters
|
||||
from khoj.database.adapters import EntryAdapters, get_default_search_model
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -112,7 +112,8 @@ class TextToEntries(ABC):
|
||||
with timer("Generated embeddings for entries to add to database in", logger):
|
||||
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
|
||||
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
|
||||
embeddings += self.embeddings_model.embed_documents(data_to_embed)
|
||||
model = get_default_search_model()
|
||||
embeddings += self.embeddings_model[model.name].embed_documents(data_to_embed)
|
||||
|
||||
added_entries: list[DbEntry] = []
|
||||
with timer("Added entries to database in", logger):
|
||||
|
||||
@@ -18,7 +18,7 @@ from starlette.authentication import requires
|
||||
# Internal Packages
|
||||
from khoj.configure import configure_server
|
||||
from khoj.database import adapters
|
||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters
|
||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters, get_default_search_model
|
||||
from khoj.database.models import ChatModelOptions
|
||||
from khoj.database.models import Entry as DbEntry
|
||||
from khoj.database.models import (
|
||||
@@ -412,7 +412,8 @@ async def search(
|
||||
]
|
||||
if text_search_models:
|
||||
with timer("Encoding query took", logger=logger):
|
||||
encoded_asymmetric_query = state.embeddings_model.embed_query(defiltered_query)
|
||||
search_model = await sync_to_async(get_default_search_model)()
|
||||
encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
if t in [
|
||||
|
||||
@@ -19,7 +19,7 @@ from khoj.utils.state import SearchType
|
||||
from khoj.utils.rawconfig import SearchResponse, Entry
|
||||
from khoj.utils.jsonl import load_jsonl
|
||||
from khoj.processor.content.text_to_entries import TextToEntries
|
||||
from khoj.database.adapters import EntryAdapters
|
||||
from khoj.database.adapters import EntryAdapters, get_default_search_model
|
||||
from khoj.database.models import KhojUser, Entry as DbEntry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -115,7 +115,8 @@ async def query(
|
||||
# Encode the query using the bi-encoder
|
||||
if question_embedding is None:
|
||||
with timer("Query Encode Time", logger, state.device):
|
||||
question_embedding = state.embeddings_model.embed_query(query)
|
||||
search_model = await sync_to_async(get_default_search_model)()
|
||||
question_embedding = state.embeddings_model[search_model.name].embed_query(query)
|
||||
|
||||
# Find relevant entries for the query
|
||||
top_k = 10
|
||||
|
||||
@@ -18,7 +18,7 @@ from khoj.utils.rawconfig import FullConfig
|
||||
# Application Global State
|
||||
config = FullConfig()
|
||||
search_models = SearchModels()
|
||||
embeddings_model: EmbeddingsModel = None
|
||||
embeddings_model: Dict[str, EmbeddingsModel] = None
|
||||
# Cross-encoder models keyed by search model name, mirroring embeddings_model.
cross_encoder_model: Dict[str, CrossEncoderModel] = None
|
||||
content_index = ContentIndex()
|
||||
gpt4all_processor_config: GPT4AllProcessorModel = None
|
||||
|
||||
@@ -45,8 +45,10 @@ def enable_db_access_for_all_tests(db):
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def search_config() -> SearchConfig:
|
||||
state.embeddings_model = EmbeddingsModel()
|
||||
state.cross_encoder_model = CrossEncoderModel()
|
||||
state.embeddings_model = dict()
|
||||
state.embeddings_model["default"] = EmbeddingsModel()
|
||||
state.cross_encoder_model = dict()
|
||||
state.cross_encoder_model["default"] = CrossEncoderModel()
|
||||
|
||||
model_dir = resolve_absolute_path("~/.khoj/search")
|
||||
model_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -317,8 +319,10 @@ def client(
|
||||
state.config.content_type = content_config
|
||||
state.config.search_type = search_config
|
||||
state.SearchType = configure_search_types()
|
||||
state.embeddings_model = EmbeddingsModel()
|
||||
state.cross_encoder_model = CrossEncoderModel()
|
||||
state.embeddings_model = dict()
|
||||
state.embeddings_model["default"] = EmbeddingsModel()
|
||||
state.cross_encoder_model = dict()
|
||||
state.cross_encoder_model["default"] = CrossEncoderModel()
|
||||
|
||||
# These lines help us Mock the Search models for these search types
|
||||
state.search_models.image_search = image_search.initialize_model(search_config.image)
|
||||
|
||||
Reference in New Issue
Block a user