mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Make search model configurable on server
- Expose the ability to modify the search model via the Django admin interface - Previously, the bi_encoder and cross_encoder models to use were set in code - Now they are user-configurable, with a default config created automatically when none exists
This commit is contained in:
@@ -1,8 +1,8 @@
|
|||||||
import math
|
import math
|
||||||
from typing import Optional, Type, TypeVar, List
|
from typing import Optional, Type, List
|
||||||
from datetime import date, datetime, timedelta
|
from datetime import date, datetime
|
||||||
import secrets
|
import secrets
|
||||||
from typing import Type, TypeVar, List
|
from typing import Type, List
|
||||||
from datetime import date, timezone
|
from datetime import date, timezone
|
||||||
|
|
||||||
from django.db import models
|
from django.db import models
|
||||||
@@ -31,6 +31,7 @@ from database.models import (
|
|||||||
GithubRepoConfig,
|
GithubRepoConfig,
|
||||||
Conversation,
|
Conversation,
|
||||||
ChatModelOptions,
|
ChatModelOptions,
|
||||||
|
SearchModel,
|
||||||
Subscription,
|
Subscription,
|
||||||
UserConversationConfig,
|
UserConversationConfig,
|
||||||
OpenAIProcessorConversationConfig,
|
OpenAIProcessorConversationConfig,
|
||||||
@@ -41,15 +42,6 @@ from khoj.search_filter.word_filter import WordFilter
|
|||||||
from khoj.search_filter.file_filter import FileFilter
|
from khoj.search_filter.file_filter import FileFilter
|
||||||
from khoj.search_filter.date_filter import DateFilter
|
from khoj.search_filter.date_filter import DateFilter
|
||||||
|
|
||||||
ModelType = TypeVar("ModelType", bound=models.Model)
|
|
||||||
|
|
||||||
|
|
||||||
async def retrieve_object(model_class: Type[ModelType], id: int) -> ModelType:
|
|
||||||
instance = await model_class.objects.filter(id=id).afirst()
|
|
||||||
if not instance:
|
|
||||||
raise HTTPException(status_code=404, detail=f"{model_class.__name__} not found")
|
|
||||||
return instance
|
|
||||||
|
|
||||||
|
|
||||||
async def set_notion_config(token: str, user: KhojUser):
|
async def set_notion_config(token: str, user: KhojUser):
|
||||||
notion_config = await NotionConfig.objects.filter(user=user).afirst()
|
notion_config = await NotionConfig.objects.filter(user=user).afirst()
|
||||||
@@ -220,6 +212,10 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def get_or_create_search_model():
|
||||||
|
return SearchModel.objects.filter().get_or_create()[0]
|
||||||
|
|
||||||
|
|
||||||
class ConversationAdapters:
|
class ConversationAdapters:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_conversation_by_user(user: KhojUser):
|
def get_conversation_by_user(user: KhojUser):
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from database.models import (
|
|||||||
ChatModelOptions,
|
ChatModelOptions,
|
||||||
OpenAIProcessorConversationConfig,
|
OpenAIProcessorConversationConfig,
|
||||||
OfflineChatProcessorConversationConfig,
|
OfflineChatProcessorConversationConfig,
|
||||||
|
SearchModel,
|
||||||
Subscription,
|
Subscription,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -16,4 +17,5 @@ admin.site.register(KhojUser, UserAdmin)
|
|||||||
admin.site.register(ChatModelOptions)
|
admin.site.register(ChatModelOptions)
|
||||||
admin.site.register(OpenAIProcessorConversationConfig)
|
admin.site.register(OpenAIProcessorConversationConfig)
|
||||||
admin.site.register(OfflineChatProcessorConversationConfig)
|
admin.site.register(OfflineChatProcessorConversationConfig)
|
||||||
|
admin.site.register(SearchModel)
|
||||||
admin.site.register(Subscription)
|
admin.site.register(Subscription)
|
||||||
|
|||||||
32
src/database/migrations/0017_searchmodel.py
Normal file
32
src/database/migrations/0017_searchmodel.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
# Generated by Django 4.2.5 on 2023-11-14 23:25
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0016_alter_subscription_renewal_date"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.CreateModel(
|
||||||
|
name="SearchModel",
|
||||||
|
fields=[
|
||||||
|
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||||
|
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||||
|
("updated_at", models.DateTimeField(auto_now=True)),
|
||||||
|
("name", models.CharField(default="default", max_length=200)),
|
||||||
|
("model_type", models.CharField(choices=[("text", "Text")], default="text", max_length=200)),
|
||||||
|
("bi_encoder", models.CharField(default="thenlper/gte-small", max_length=200)),
|
||||||
|
(
|
||||||
|
"cross_encoder",
|
||||||
|
models.CharField(
|
||||||
|
blank=True, default="cross-encoder/ms-marco-MiniLM-L-6-v2", max_length=200, null=True
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
options={
|
||||||
|
"abstract": False,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
@@ -102,6 +102,18 @@ class LocalPlaintextConfig(BaseModel):
|
|||||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||||
|
|
||||||
|
|
||||||
|
class SearchModel(BaseModel):
|
||||||
|
class ModelType(models.TextChoices):
|
||||||
|
TEXT = "text"
|
||||||
|
|
||||||
|
name = models.CharField(max_length=200, default="default")
|
||||||
|
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.TEXT)
|
||||||
|
bi_encoder = models.CharField(max_length=200, default="thenlper/gte-small")
|
||||||
|
cross_encoder = models.CharField(
|
||||||
|
max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2", null=True, blank=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class OpenAIProcessorConversationConfig(BaseModel):
|
class OpenAIProcessorConversationConfig(BaseModel):
|
||||||
api_key = models.CharField(max_length=200)
|
api_key = models.CharField(max_length=200)
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ import logging
|
|||||||
import json
|
import json
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from fastapi import Request
|
|
||||||
import requests
|
import requests
|
||||||
import os
|
import os
|
||||||
|
|
||||||
@@ -21,15 +20,16 @@ from starlette.authentication import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
|
from database.models import KhojUser, Subscription
|
||||||
|
from database.adapters import get_all_users, get_or_create_search_model
|
||||||
|
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
|
||||||
|
from khoj.routers.indexer import configure_content, load_content, configure_search
|
||||||
from khoj.utils import constants, state
|
from khoj.utils import constants, state
|
||||||
from khoj.utils.config import (
|
from khoj.utils.config import (
|
||||||
SearchType,
|
SearchType,
|
||||||
)
|
)
|
||||||
from khoj.utils.fs_syncer import collect_files
|
from khoj.utils.fs_syncer import collect_files
|
||||||
from khoj.utils.rawconfig import FullConfig
|
from khoj.utils.rawconfig import FullConfig
|
||||||
from khoj.routers.indexer import configure_content, load_content, configure_search
|
|
||||||
from database.models import KhojUser, Subscription
|
|
||||||
from database.adapters import get_all_users
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -113,6 +113,9 @@ def configure_server(
|
|||||||
|
|
||||||
# Initialize Search Models from Config and initialize content
|
# Initialize Search Models from Config and initialize content
|
||||||
try:
|
try:
|
||||||
|
state.embeddings_model = EmbeddingsModel(get_or_create_search_model().bi_encoder)
|
||||||
|
state.cross_encoder_model = CrossEncoderModel(get_or_create_search_model().cross_encoder)
|
||||||
|
|
||||||
state.config_lock.acquire()
|
state.config_lock.acquire()
|
||||||
state.SearchType = configure_search_types(state.config)
|
state.SearchType = configure_search_types(state.config)
|
||||||
state.search_models = configure_search(state.search_models, state.config.search_type)
|
state.search_models = configure_search(state.search_models, state.config.search_type)
|
||||||
|
|||||||
@@ -7,10 +7,10 @@ from khoj.utils.rawconfig import SearchResponse
|
|||||||
|
|
||||||
|
|
||||||
class EmbeddingsModel:
|
class EmbeddingsModel:
|
||||||
def __init__(self):
|
def __init__(self, model_name: str = "thenlper/gte-small"):
|
||||||
self.encode_kwargs = {"normalize_embeddings": True}
|
self.encode_kwargs = {"normalize_embeddings": True}
|
||||||
self.model_kwargs = {"device": get_device()}
|
self.model_kwargs = {"device": get_device()}
|
||||||
self.model_name = "thenlper/gte-small"
|
self.model_name = model_name
|
||||||
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
|
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
|
||||||
|
|
||||||
def embed_query(self, query):
|
def embed_query(self, query):
|
||||||
@@ -21,11 +21,11 @@ class EmbeddingsModel:
|
|||||||
|
|
||||||
|
|
||||||
class CrossEncoderModel:
|
class CrossEncoderModel:
|
||||||
def __init__(self):
|
def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
|
||||||
self.model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
self.model_name = model_name
|
||||||
self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=get_device())
|
self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=get_device())
|
||||||
|
|
||||||
def predict(self, query, hits: List[SearchResponse]):
|
def predict(self, query, hits: List[SearchResponse], key: str = "compiled"):
|
||||||
cross__inp = [[query, hit.additional["compiled"]] for hit in hits]
|
cross__inp = [[query, hit.additional[key]] for hit in hits]
|
||||||
cross_scores = self.cross_encoder_model.predict(cross__inp, apply_softmax=True)
|
cross_scores = self.cross_encoder_model.predict(cross__inp, apply_softmax=True)
|
||||||
return cross_scores
|
return cross_scores
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from khoj.utils.rawconfig import Entry
|
|||||||
from khoj.processor.embeddings import EmbeddingsModel
|
from khoj.processor.embeddings import EmbeddingsModel
|
||||||
from khoj.search_filter.date_filter import DateFilter
|
from khoj.search_filter.date_filter import DateFilter
|
||||||
from database.models import KhojUser, Entry as DbEntry, EntryDates
|
from database.models import KhojUser, Entry as DbEntry, EntryDates
|
||||||
from database.adapters import EntryAdapters
|
from database.adapters import EntryAdapters, get_or_create_search_model
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -22,7 +22,8 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class TextToEntries(ABC):
|
class TextToEntries(ABC):
|
||||||
def __init__(self, config: Any = None):
|
def __init__(self, config: Any = None):
|
||||||
self.embeddings_model = EmbeddingsModel()
|
bi_encoder_name = get_or_create_search_model().bi_encoder
|
||||||
|
self.embeddings_model = EmbeddingsModel(bi_encoder_name)
|
||||||
self.config = config
|
self.config = config
|
||||||
self.date_filter = DateFilter()
|
self.date_filter = DateFilter()
|
||||||
|
|
||||||
|
|||||||
@@ -5,21 +5,20 @@ from typing import List, Dict
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
import torch
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.utils import config as utils_config
|
from khoj.utils import config as utils_config
|
||||||
from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel
|
from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel
|
||||||
from khoj.utils.helpers import LRU, get_device
|
from khoj.utils.helpers import LRU, get_device
|
||||||
from khoj.utils.rawconfig import FullConfig
|
from khoj.utils.rawconfig import FullConfig
|
||||||
from khoj.processor.embeddings import EmbeddingsModel, CrossEncoderModel
|
|
||||||
|
|
||||||
# Application Global State
|
# Application Global State
|
||||||
config = FullConfig()
|
config = FullConfig()
|
||||||
search_models = SearchModels()
|
search_models = SearchModels()
|
||||||
embeddings_model = EmbeddingsModel()
|
embeddings_model: EmbeddingsModel = None
|
||||||
cross_encoder_model = CrossEncoderModel()
|
cross_encoder_model: CrossEncoderModel = None
|
||||||
content_index = ContentIndex()
|
content_index = ContentIndex()
|
||||||
gpt4all_processor_config: GPT4AllProcessorModel = None
|
gpt4all_processor_config: GPT4AllProcessorModel = None
|
||||||
config_file: Path = None
|
config_file: Path = None
|
||||||
|
|||||||
@@ -8,11 +8,13 @@ from fastapi import FastAPI
|
|||||||
import os
|
import os
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.configure import configure_routes, configure_search_types, configure_middleware
|
from khoj.configure import configure_routes, configure_search_types, configure_middleware
|
||||||
|
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
|
||||||
from khoj.processor.plaintext.plaintext_to_entries import PlaintextToEntries
|
from khoj.processor.plaintext.plaintext_to_entries import PlaintextToEntries
|
||||||
from khoj.search_type import image_search, text_search
|
from khoj.search_type import image_search, text_search
|
||||||
from khoj.utils.config import SearchModels
|
from khoj.utils.config import SearchModels
|
||||||
@@ -54,6 +56,9 @@ def enable_db_access_for_all_tests(db):
|
|||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def search_config() -> SearchConfig:
|
def search_config() -> SearchConfig:
|
||||||
|
state.embeddings_model = EmbeddingsModel()
|
||||||
|
state.cross_encoder_model = CrossEncoderModel()
|
||||||
|
|
||||||
model_dir = resolve_absolute_path("~/.khoj/search")
|
model_dir = resolve_absolute_path("~/.khoj/search")
|
||||||
model_dir.mkdir(parents=True, exist_ok=True)
|
model_dir.mkdir(parents=True, exist_ok=True)
|
||||||
search_config = SearchConfig()
|
search_config = SearchConfig()
|
||||||
@@ -292,6 +297,8 @@ def client(
|
|||||||
state.config.content_type = content_config
|
state.config.content_type = content_config
|
||||||
state.config.search_type = search_config
|
state.config.search_type = search_config
|
||||||
state.SearchType = configure_search_types(state.config)
|
state.SearchType = configure_search_types(state.config)
|
||||||
|
state.embeddings_model = EmbeddingsModel()
|
||||||
|
state.cross_encoder_model = CrossEncoderModel()
|
||||||
|
|
||||||
# These lines help us Mock the Search models for these search types
|
# These lines help us Mock the Search models for these search types
|
||||||
state.search_models.image_search = image_search.initialize_model(search_config.image)
|
state.search_models.image_search = image_search.initialize_model(search_config.image)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from database.models import (
|
|||||||
ChatModelOptions,
|
ChatModelOptions,
|
||||||
OfflineChatProcessorConversationConfig,
|
OfflineChatProcessorConversationConfig,
|
||||||
OpenAIProcessorConversationConfig,
|
OpenAIProcessorConversationConfig,
|
||||||
|
SearchModel,
|
||||||
UserConversationConfig,
|
UserConversationConfig,
|
||||||
Conversation,
|
Conversation,
|
||||||
Subscription,
|
Subscription,
|
||||||
@@ -71,6 +72,16 @@ class ConversationFactory(factory.django.DjangoModelFactory):
|
|||||||
user = factory.SubFactory(UserFactory)
|
user = factory.SubFactory(UserFactory)
|
||||||
|
|
||||||
|
|
||||||
|
class SearchModelFactory(factory.django.DjangoModelFactory):
|
||||||
|
class Meta:
|
||||||
|
model = SearchModel
|
||||||
|
|
||||||
|
name = "default"
|
||||||
|
model_type = "text"
|
||||||
|
bi_encoder = "thenlper/gte-small"
|
||||||
|
cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
||||||
|
|
||||||
|
|
||||||
class SubscriptionFactory(factory.django.DjangoModelFactory):
|
class SubscriptionFactory(factory.django.DjangoModelFactory):
|
||||||
class Meta:
|
class Meta:
|
||||||
model = Subscription
|
model = Subscription
|
||||||
|
|||||||
Reference in New Issue
Block a user