Make search model configurable on server

- Expose ability to modify search model via Django admin interface
- Previously the bi_encoder and cross_encoder models to use were set
  in code
- Now it's user configurable, with a default config generated
  automatically
This commit is contained in:
Debanjum Singh Solanky
2023-11-14 16:56:26 -08:00
parent b734984d6d
commit 4af194d74b
10 changed files with 91 additions and 28 deletions

View File

@@ -1,8 +1,8 @@
import math import math
from typing import Optional, Type, TypeVar, List from typing import Optional, Type, List
from datetime import date, datetime, timedelta from datetime import date, datetime
import secrets import secrets
from typing import Type, TypeVar, List from typing import Type, List
from datetime import date, timezone from datetime import date, timezone
from django.db import models from django.db import models
@@ -31,6 +31,7 @@ from database.models import (
GithubRepoConfig, GithubRepoConfig,
Conversation, Conversation,
ChatModelOptions, ChatModelOptions,
SearchModel,
Subscription, Subscription,
UserConversationConfig, UserConversationConfig,
OpenAIProcessorConversationConfig, OpenAIProcessorConversationConfig,
@@ -41,15 +42,6 @@ from khoj.search_filter.word_filter import WordFilter
from khoj.search_filter.file_filter import FileFilter from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.date_filter import DateFilter
ModelType = TypeVar("ModelType", bound=models.Model)
async def retrieve_object(model_class: Type[ModelType], id: int) -> ModelType:
    """Fetch the row of ``model_class`` with the given ``id``.

    Args:
        model_class: Model class whose ``objects`` manager is queried.
        id: Value matched against the ``id`` field.

    Returns:
        The first matching instance.

    Raises:
        HTTPException: With status 404 when no matching instance is found.
    """
    matches = model_class.objects.filter(id=id)
    found = await matches.afirst()
    if found:
        return found
    raise HTTPException(status_code=404, detail=f"{model_class.__name__} not found")
async def set_notion_config(token: str, user: KhojUser): async def set_notion_config(token: str, user: KhojUser):
notion_config = await NotionConfig.objects.filter(user=user).afirst() notion_config = await NotionConfig.objects.filter(user=user).afirst()
@@ -220,6 +212,10 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
return config return config
def get_or_create_search_model():
    """Return the search model config, creating a default one if none exists.

    Returns:
        The first ``SearchModel`` row, or a newly created row populated with
        the model's field defaults when the table is empty.
    """
    # A bare get_or_create() performs an unfiltered get(), which raises
    # MultipleObjectsReturned as soon as more than one SearchModel row exists.
    # Picking the first row keeps this working with any number of rows.
    search_model = SearchModel.objects.filter().first()
    if search_model is None:
        search_model = SearchModel.objects.create()
    return search_model
class ConversationAdapters: class ConversationAdapters:
@staticmethod @staticmethod
def get_conversation_by_user(user: KhojUser): def get_conversation_by_user(user: KhojUser):

View File

@@ -8,6 +8,7 @@ from database.models import (
ChatModelOptions, ChatModelOptions,
OpenAIProcessorConversationConfig, OpenAIProcessorConversationConfig,
OfflineChatProcessorConversationConfig, OfflineChatProcessorConversationConfig,
SearchModel,
Subscription, Subscription,
) )
@@ -16,4 +17,5 @@ admin.site.register(KhojUser, UserAdmin)
admin.site.register(ChatModelOptions) admin.site.register(ChatModelOptions)
admin.site.register(OpenAIProcessorConversationConfig) admin.site.register(OpenAIProcessorConversationConfig)
admin.site.register(OfflineChatProcessorConversationConfig) admin.site.register(OfflineChatProcessorConversationConfig)
admin.site.register(SearchModel)
admin.site.register(Subscription) admin.site.register(Subscription)

View File

@@ -0,0 +1,32 @@
# Generated by Django 4.2.5 on 2023-11-14 23:25
from django.db import migrations, models
class Migration(migrations.Migration):
    """Create the ``SearchModel`` table holding the search encoder names."""

    # Must run after migration 0016 of the "database" app.
    dependencies = [
        ("database", "0016_alter_subscription_renewal_date"),
    ]
    operations = [
        migrations.CreateModel(
            name="SearchModel",
            fields=[
                ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
                # Timestamp columns: auto_now_add on creation, auto_now on every save.
                ("created_at", models.DateTimeField(auto_now_add=True)),
                ("updated_at", models.DateTimeField(auto_now=True)),
                ("name", models.CharField(default="default", max_length=200)),
                # "text" is the only model type choice defined here.
                ("model_type", models.CharField(choices=[("text", "Text")], default="text", max_length=200)),
                ("bi_encoder", models.CharField(default="thenlper/gte-small", max_length=200)),
                (
                    # Unlike bi_encoder, the cross encoder may be left blank or NULL.
                    "cross_encoder",
                    models.CharField(
                        blank=True, default="cross-encoder/ms-marco-MiniLM-L-6-v2", max_length=200, null=True
                    ),
                ),
            ],
            options={
                "abstract": False,
            },
        ),
    ]

View File

@@ -102,6 +102,18 @@ class LocalPlaintextConfig(BaseModel):
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
class SearchModel(BaseModel):
    """Names of the encoder models to use for search."""

    class ModelType(models.TextChoices):
        """Kinds of search model that can be configured."""

        TEXT = "text"

    # Label for this configuration.
    name = models.CharField(default="default", max_length=200)
    # Which kind of search this configuration applies to.
    model_type = models.CharField(
        choices=ModelType.choices,
        default=ModelType.TEXT,
        max_length=200,
    )
    # Name of the bi-encoder model.
    bi_encoder = models.CharField(default="thenlper/gte-small", max_length=200)
    # Name of the cross-encoder model; may be left blank or NULL.
    cross_encoder = models.CharField(
        blank=True,
        null=True,
        default="cross-encoder/ms-marco-MiniLM-L-6-v2",
        max_length=200,
    )
class OpenAIProcessorConversationConfig(BaseModel): class OpenAIProcessorConversationConfig(BaseModel):
api_key = models.CharField(max_length=200) api_key = models.CharField(max_length=200)

View File

@@ -3,7 +3,6 @@ import logging
import json import json
from enum import Enum from enum import Enum
from typing import Optional from typing import Optional
from fastapi import Request
import requests import requests
import os import os
@@ -21,15 +20,16 @@ from starlette.authentication import (
) )
# Internal Packages # Internal Packages
from database.models import KhojUser, Subscription
from database.adapters import get_all_users, get_or_create_search_model
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.routers.indexer import configure_content, load_content, configure_search
from khoj.utils import constants, state from khoj.utils import constants, state
from khoj.utils.config import ( from khoj.utils.config import (
SearchType, SearchType,
) )
from khoj.utils.fs_syncer import collect_files from khoj.utils.fs_syncer import collect_files
from khoj.utils.rawconfig import FullConfig from khoj.utils.rawconfig import FullConfig
from khoj.routers.indexer import configure_content, load_content, configure_search
from database.models import KhojUser, Subscription
from database.adapters import get_all_users
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -113,6 +113,9 @@ def configure_server(
# Initialize Search Models from Config and initialize content # Initialize Search Models from Config and initialize content
try: try:
state.embeddings_model = EmbeddingsModel(get_or_create_search_model().bi_encoder)
state.cross_encoder_model = CrossEncoderModel(get_or_create_search_model().cross_encoder)
state.config_lock.acquire() state.config_lock.acquire()
state.SearchType = configure_search_types(state.config) state.SearchType = configure_search_types(state.config)
state.search_models = configure_search(state.search_models, state.config.search_type) state.search_models = configure_search(state.search_models, state.config.search_type)

View File

@@ -7,10 +7,10 @@ from khoj.utils.rawconfig import SearchResponse
class EmbeddingsModel: class EmbeddingsModel:
def __init__(self): def __init__(self, model_name: str = "thenlper/gte-small"):
self.encode_kwargs = {"normalize_embeddings": True} self.encode_kwargs = {"normalize_embeddings": True}
self.model_kwargs = {"device": get_device()} self.model_kwargs = {"device": get_device()}
self.model_name = "thenlper/gte-small" self.model_name = model_name
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs) self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
def embed_query(self, query): def embed_query(self, query):
@@ -21,11 +21,11 @@ class EmbeddingsModel:
class CrossEncoderModel: class CrossEncoderModel:
def __init__(self): def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
self.model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2" self.model_name = model_name
self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=get_device()) self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=get_device())
def predict(self, query, hits: List[SearchResponse]): def predict(self, query, hits: List[SearchResponse], key: str = "compiled"):
cross__inp = [[query, hit.additional["compiled"]] for hit in hits] cross__inp = [[query, hit.additional[key]] for hit in hits]
cross_scores = self.cross_encoder_model.predict(cross__inp, apply_softmax=True) cross_scores = self.cross_encoder_model.predict(cross__inp, apply_softmax=True)
return cross_scores return cross_scores

View File

@@ -14,7 +14,7 @@ from khoj.utils.rawconfig import Entry
from khoj.processor.embeddings import EmbeddingsModel from khoj.processor.embeddings import EmbeddingsModel
from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.date_filter import DateFilter
from database.models import KhojUser, Entry as DbEntry, EntryDates from database.models import KhojUser, Entry as DbEntry, EntryDates
from database.adapters import EntryAdapters from database.adapters import EntryAdapters, get_or_create_search_model
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -22,7 +22,8 @@ logger = logging.getLogger(__name__)
class TextToEntries(ABC): class TextToEntries(ABC):
def __init__(self, config: Any = None): def __init__(self, config: Any = None):
self.embeddings_model = EmbeddingsModel() bi_encoder_name = get_or_create_search_model().bi_encoder
self.embeddings_model = EmbeddingsModel(bi_encoder_name)
self.config = config self.config = config
self.date_filter = DateFilter() self.date_filter = DateFilter()

View File

@@ -5,21 +5,20 @@ from typing import List, Dict
from collections import defaultdict from collections import defaultdict
# External Packages # External Packages
import torch
from pathlib import Path from pathlib import Path
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
# Internal Packages # Internal Packages
from khoj.utils import config as utils_config from khoj.utils import config as utils_config
from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel
from khoj.utils.helpers import LRU, get_device from khoj.utils.helpers import LRU, get_device
from khoj.utils.rawconfig import FullConfig from khoj.utils.rawconfig import FullConfig
from khoj.processor.embeddings import EmbeddingsModel, CrossEncoderModel
# Application Global State # Application Global State
config = FullConfig() config = FullConfig()
search_models = SearchModels() search_models = SearchModels()
embeddings_model = EmbeddingsModel() embeddings_model: EmbeddingsModel = None
cross_encoder_model = CrossEncoderModel() cross_encoder_model: CrossEncoderModel = None
content_index = ContentIndex() content_index = ContentIndex()
gpt4all_processor_config: GPT4AllProcessorModel = None gpt4all_processor_config: GPT4AllProcessorModel = None
config_file: Path = None config_file: Path = None

View File

@@ -8,11 +8,13 @@ from fastapi import FastAPI
import os import os
from fastapi import FastAPI from fastapi import FastAPI
app = FastAPI() app = FastAPI()
# Internal Packages # Internal Packages
from khoj.configure import configure_routes, configure_search_types, configure_middleware from khoj.configure import configure_routes, configure_search_types, configure_middleware
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.processor.plaintext.plaintext_to_entries import PlaintextToEntries from khoj.processor.plaintext.plaintext_to_entries import PlaintextToEntries
from khoj.search_type import image_search, text_search from khoj.search_type import image_search, text_search
from khoj.utils.config import SearchModels from khoj.utils.config import SearchModels
@@ -54,6 +56,9 @@ def enable_db_access_for_all_tests(db):
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def search_config() -> SearchConfig: def search_config() -> SearchConfig:
state.embeddings_model = EmbeddingsModel()
state.cross_encoder_model = CrossEncoderModel()
model_dir = resolve_absolute_path("~/.khoj/search") model_dir = resolve_absolute_path("~/.khoj/search")
model_dir.mkdir(parents=True, exist_ok=True) model_dir.mkdir(parents=True, exist_ok=True)
search_config = SearchConfig() search_config = SearchConfig()
@@ -292,6 +297,8 @@ def client(
state.config.content_type = content_config state.config.content_type = content_config
state.config.search_type = search_config state.config.search_type = search_config
state.SearchType = configure_search_types(state.config) state.SearchType = configure_search_types(state.config)
state.embeddings_model = EmbeddingsModel()
state.cross_encoder_model = CrossEncoderModel()
# These lines help us Mock the Search models for these search types # These lines help us Mock the Search models for these search types
state.search_models.image_search = image_search.initialize_model(search_config.image) state.search_models.image_search = image_search.initialize_model(search_config.image)

View File

@@ -7,6 +7,7 @@ from database.models import (
ChatModelOptions, ChatModelOptions,
OfflineChatProcessorConversationConfig, OfflineChatProcessorConversationConfig,
OpenAIProcessorConversationConfig, OpenAIProcessorConversationConfig,
SearchModel,
UserConversationConfig, UserConversationConfig,
Conversation, Conversation,
Subscription, Subscription,
@@ -71,6 +72,16 @@ class ConversationFactory(factory.django.DjangoModelFactory):
user = factory.SubFactory(UserFactory) user = factory.SubFactory(UserFactory)
class SearchModelFactory(factory.django.DjangoModelFactory):
    """Test factory producing ``SearchModel`` rows with fixed encoder names."""

    class Meta:
        model = SearchModel

    # Encoder model names.
    bi_encoder = "thenlper/gte-small"
    cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2"

    # Identification of the configuration.
    name = "default"
    model_type = "text"
class SubscriptionFactory(factory.django.DjangoModelFactory): class SubscriptionFactory(factory.django.DjangoModelFactory):
class Meta: class Meta:
model = Subscription model = Subscription