diff --git a/src/database/adapters/__init__.py b/src/database/adapters/__init__.py index 4f5f09da..ea391105 100644 --- a/src/database/adapters/__init__.py +++ b/src/database/adapters/__init__.py @@ -31,7 +31,7 @@ from database.models import ( GithubRepoConfig, Conversation, ChatModelOptions, - SearchModel, + SearchModelConfig, Subscription, UserConversationConfig, OpenAIProcessorConversationConfig, @@ -73,11 +73,11 @@ async def delete_khoj_token(user: KhojUser, token: str): async def get_or_create_user(token: dict) -> KhojUser: user = await get_user_by_token(token) if not user: - user = await create_user_by_token(token) + user = await create_user_by_google_token(token) return user -async def create_user_by_token(token: dict) -> KhojUser: +async def create_user_by_google_token(token: dict) -> KhojUser: user, _ = await KhojUser.objects.filter(email=token.get("email")).aupdate_or_create( defaults={"username": token.get("email"), "email": token.get("email")} ) @@ -214,9 +214,9 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list): def get_or_create_search_model(): - search_model = SearchModel.objects.filter().first() + search_model = SearchModelConfig.objects.filter().first() if not search_model: - search_model = SearchModel.objects.create() + search_model = SearchModelConfig.objects.create() return search_model diff --git a/src/database/admin.py b/src/database/admin.py index a2aa85e2..8d2130ba 100644 --- a/src/database/admin.py +++ b/src/database/admin.py @@ -8,7 +8,7 @@ from database.models import ( ChatModelOptions, OpenAIProcessorConversationConfig, OfflineChatProcessorConversationConfig, - SearchModel, + SearchModelConfig, Subscription, ) @@ -17,5 +17,5 @@ admin.site.register(KhojUser, UserAdmin) admin.site.register(ChatModelOptions) admin.site.register(OpenAIProcessorConversationConfig) admin.site.register(OfflineChatProcessorConversationConfig) -admin.site.register(SearchModel) +admin.site.register(SearchModelConfig) admin.site.register(Subscription) diff --git a/src/database/migrations/0018_searchmodelconfig_delete_searchmodel.py b/src/database/migrations/0018_searchmodelconfig_delete_searchmodel.py new file mode 100644 index 00000000..a8100370 --- /dev/null +++ b/src/database/migrations/0018_searchmodelconfig_delete_searchmodel.py @@ -0,0 +1,30 @@ +# Generated by Django 4.2.5 on 2023-11-16 01:13 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0017_searchmodel"), + ] + + operations = [ + migrations.CreateModel( + name="SearchModelConfig", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ("name", models.CharField(default="default", max_length=200)), + ("model_type", models.CharField(choices=[("text", "Text")], default="text", max_length=200)), + ("bi_encoder", models.CharField(default="thenlper/gte-small", max_length=200)), + ("cross_encoder", models.CharField(default="cross-encoder/ms-marco-MiniLM-L-6-v2", max_length=200)), + ], + options={ + "abstract": False, + }, + ), + migrations.DeleteModel( + name="SearchModel", + ), + ] diff --git a/src/database/models/__init__.py b/src/database/models/__init__.py index 5571c5a7..92848e5c 100644 --- a/src/database/models/__init__.py +++ b/src/database/models/__init__.py @@ -102,16 +102,14 @@ class LocalPlaintextConfig(BaseModel): user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) -class SearchModel(BaseModel): +class SearchModelConfig(BaseModel): class ModelType(models.TextChoices): TEXT = "text" name = models.CharField(max_length=200, default="default") model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.TEXT) bi_encoder = models.CharField(max_length=200, default="thenlper/gte-small") - cross_encoder = models.CharField( - max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2", null=True, blank=True - ) + cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2") class OpenAIProcessorConversationConfig(BaseModel): diff --git a/src/khoj/migrations/migrate_server_pg.py b/src/khoj/migrations/migrate_server_pg.py index 27226d9f..434e27d7 100644 --- a/src/khoj/migrations/migrate_server_pg.py +++ b/src/khoj/migrations/migrate_server_pg.py @@ -64,7 +64,7 @@ from database.models import ( OpenAIProcessorConversationConfig, OfflineChatProcessorConversationConfig, ChatModelOptions, - SearchModel, + SearchModelConfig, ) logger = logging.getLogger(__name__) @@ -87,12 +87,12 @@ def migrate_server_pg(args): if "search-type" in raw_config and raw_config["search-type"]: if "asymmetric" in raw_config["search-type"]: # Delete all existing search models - SearchModel.objects.filter(model_type=SearchModel.ModelType.TEXT).delete() + SearchModelConfig.objects.filter(model_type=SearchModelConfig.ModelType.TEXT).delete() # Create new search model from existing Khoj YAML config asymmetric_search = raw_config["search-type"]["asymmetric"] - SearchModel.objects.create( + SearchModelConfig.objects.create( name="default", - model_type=SearchModel.ModelType.TEXT, + model_type=SearchModelConfig.ModelType.TEXT, bi_encoder=asymmetric_search.get("encoder"), cross_encoder=asymmetric_search.get("cross-encoder"), ) diff --git a/src/khoj/processor/text_to_entries.py b/src/khoj/processor/text_to_entries.py index 66a489eb..ac42105a 100644 --- a/src/khoj/processor/text_to_entries.py +++ b/src/khoj/processor/text_to_entries.py @@ -6,15 +6,15 @@ import logging import uuid from tqdm import tqdm from typing import Callable, List, Tuple, Set, Any +from khoj.utils import state from khoj.utils.helpers import is_none_or_empty, timer, batcher # Internal Packages from khoj.utils.rawconfig import Entry -from khoj.processor.embeddings import EmbeddingsModel from khoj.search_filter.date_filter import DateFilter from database.models import KhojUser, Entry as DbEntry, EntryDates -from database.adapters import EntryAdapters, get_or_create_search_model +from database.adapters import EntryAdapters logger = logging.getLogger(__name__) @@ -22,8 +22,7 @@ logger = logging.getLogger(__name__) class TextToEntries(ABC): def __init__(self, config: Any = None): - bi_encoder_name = get_or_create_search_model().bi_encoder - self.embeddings_model = EmbeddingsModel(bi_encoder_name) + self.embeddings_model = state.embeddings_model self.config = config self.date_filter = DateFilter() diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 9993d2ca..190fc260 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -581,7 +581,7 @@ async def chat( request: Request, q: str, n: Optional[int] = 5, - d: Optional[float] = 0.4, + d: Optional[float] = 0.15, client: Optional[str] = None, stream: Optional[bool] = False, user_agent: Optional[str] = Header(None), diff --git a/tests/helpers.py b/tests/helpers.py index bf30a80d..079eb475 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -7,7 +7,7 @@ from database.models import ( ChatModelOptions, OfflineChatProcessorConversationConfig, OpenAIProcessorConversationConfig, - SearchModel, + SearchModelConfig, UserConversationConfig, Conversation, Subscription, @@ -74,7 +74,7 @@ class ConversationFactory(factory.django.DjangoModelFactory): class SearchModelFactory(factory.django.DjangoModelFactory): class Meta: - model = SearchModel + model = SearchModelConfig name = "default" model_type = "text"