Merge with master

This commit is contained in:
sabaimran
2023-11-15 18:34:46 -08:00
8 changed files with 49 additions and 22 deletions

View File

@@ -31,7 +31,7 @@ from database.models import (
GithubRepoConfig, GithubRepoConfig,
Conversation, Conversation,
ChatModelOptions, ChatModelOptions,
SearchModel, SearchModelConfig,
Subscription, Subscription,
UserConversationConfig, UserConversationConfig,
OpenAIProcessorConversationConfig, OpenAIProcessorConversationConfig,
@@ -73,11 +73,11 @@ async def delete_khoj_token(user: KhojUser, token: str):
async def get_or_create_user(token: dict) -> KhojUser: async def get_or_create_user(token: dict) -> KhojUser:
user = await get_user_by_token(token) user = await get_user_by_token(token)
if not user: if not user:
user = await create_user_by_token(token) user = await create_user_by_google_token(token)
return user return user
async def create_user_by_token(token: dict) -> KhojUser: async def create_user_by_google_token(token: dict) -> KhojUser:
user, _ = await KhojUser.objects.filter(email=token.get("email")).aupdate_or_create( user, _ = await KhojUser.objects.filter(email=token.get("email")).aupdate_or_create(
defaults={"username": token.get("email"), "email": token.get("email")} defaults={"username": token.get("email"), "email": token.get("email")}
) )
@@ -214,9 +214,9 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
def get_or_create_search_model(): def get_or_create_search_model():
search_model = SearchModel.objects.filter().first() search_model = SearchModelConfig.objects.filter().first()
if not search_model: if not search_model:
search_model = SearchModel.objects.create() search_model = SearchModelConfig.objects.create()
return search_model return search_model

View File

@@ -8,7 +8,7 @@ from database.models import (
ChatModelOptions, ChatModelOptions,
OpenAIProcessorConversationConfig, OpenAIProcessorConversationConfig,
OfflineChatProcessorConversationConfig, OfflineChatProcessorConversationConfig,
SearchModel, SearchModelConfig,
Subscription, Subscription,
) )
@@ -17,5 +17,5 @@ admin.site.register(KhojUser, UserAdmin)
admin.site.register(ChatModelOptions) admin.site.register(ChatModelOptions)
admin.site.register(OpenAIProcessorConversationConfig) admin.site.register(OpenAIProcessorConversationConfig)
admin.site.register(OfflineChatProcessorConversationConfig) admin.site.register(OfflineChatProcessorConversationConfig)
admin.site.register(SearchModel) admin.site.register(SearchModelConfig)
admin.site.register(Subscription) admin.site.register(Subscription)

View File

@@ -0,0 +1,30 @@
# Generated by Django 4.2.5 on 2023-11-16 01:13
from django.db import migrations, models
class Migration(migrations.Migration):
    """Rename the ``SearchModel`` model to ``SearchModelConfig``.

    The rename is done in place with ``RenameModel`` rather than by creating
    a new ``SearchModelConfig`` table and deleting ``SearchModel``: a
    create-then-delete pair discards every row stored in the old table,
    whereas a rename keeps any configured bi-encoder / cross-encoder names.

    The resulting model state has the same fields as a freshly created
    ``SearchModelConfig`` (id, created_at, updated_at, name, model_type,
    bi_encoder, cross_encoder).
    """

    dependencies = [
        ("database", "0017_searchmodel"),
    ]

    operations = [
        # Keep the existing table and its rows; only the model/table name changes.
        migrations.RenameModel(
            old_name="SearchModel",
            new_name="SearchModelConfig",
        ),
        # NOTE(review): assumes 0017_searchmodel declared ``cross_encoder`` with
        # null=True/blank=True (as the pre-rename model did) -- confirm against
        # that migration. The field is made non-nullable here so the final
        # state matches the SearchModelConfig model definition.
        migrations.AlterField(
            model_name="searchmodelconfig",
            name="cross_encoder",
            field=models.CharField(default="cross-encoder/ms-marco-MiniLM-L-6-v2", max_length=200),
        ),
    ]

View File

@@ -102,16 +102,14 @@ class LocalPlaintextConfig(BaseModel):
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
class SearchModel(BaseModel): class SearchModelConfig(BaseModel):
class ModelType(models.TextChoices): class ModelType(models.TextChoices):
TEXT = "text" TEXT = "text"
name = models.CharField(max_length=200, default="default") name = models.CharField(max_length=200, default="default")
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.TEXT) model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.TEXT)
bi_encoder = models.CharField(max_length=200, default="thenlper/gte-small") bi_encoder = models.CharField(max_length=200, default="thenlper/gte-small")
cross_encoder = models.CharField( cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2")
max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2", null=True, blank=True
)
class OpenAIProcessorConversationConfig(BaseModel): class OpenAIProcessorConversationConfig(BaseModel):

View File

@@ -64,7 +64,7 @@ from database.models import (
OpenAIProcessorConversationConfig, OpenAIProcessorConversationConfig,
OfflineChatProcessorConversationConfig, OfflineChatProcessorConversationConfig,
ChatModelOptions, ChatModelOptions,
SearchModel, SearchModelConfig,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -87,12 +87,12 @@ def migrate_server_pg(args):
if "search-type" in raw_config and raw_config["search-type"]: if "search-type" in raw_config and raw_config["search-type"]:
if "asymmetric" in raw_config["search-type"]: if "asymmetric" in raw_config["search-type"]:
# Delete all existing search models # Delete all existing search models
SearchModel.objects.filter(model_type=SearchModel.ModelType.TEXT).delete() SearchModelConfig.objects.filter(model_type=SearchModelConfig.ModelType.TEXT).delete()
# Create new search model from existing Khoj YAML config # Create new search model from existing Khoj YAML config
asymmetric_search = raw_config["search-type"]["asymmetric"] asymmetric_search = raw_config["search-type"]["asymmetric"]
SearchModel.objects.create( SearchModelConfig.objects.create(
name="default", name="default",
model_type=SearchModel.ModelType.TEXT, model_type=SearchModelConfig.ModelType.TEXT,
bi_encoder=asymmetric_search.get("encoder"), bi_encoder=asymmetric_search.get("encoder"),
cross_encoder=asymmetric_search.get("cross-encoder"), cross_encoder=asymmetric_search.get("cross-encoder"),
) )

View File

@@ -6,15 +6,15 @@ import logging
import uuid import uuid
from tqdm import tqdm from tqdm import tqdm
from typing import Callable, List, Tuple, Set, Any from typing import Callable, List, Tuple, Set, Any
from khoj.utils import state
from khoj.utils.helpers import is_none_or_empty, timer, batcher from khoj.utils.helpers import is_none_or_empty, timer, batcher
# Internal Packages # Internal Packages
from khoj.utils.rawconfig import Entry from khoj.utils.rawconfig import Entry
from khoj.processor.embeddings import EmbeddingsModel
from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.date_filter import DateFilter
from database.models import KhojUser, Entry as DbEntry, EntryDates from database.models import KhojUser, Entry as DbEntry, EntryDates
from database.adapters import EntryAdapters, get_or_create_search_model from database.adapters import EntryAdapters
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -22,8 +22,7 @@ logger = logging.getLogger(__name__)
class TextToEntries(ABC): class TextToEntries(ABC):
def __init__(self, config: Any = None): def __init__(self, config: Any = None):
bi_encoder_name = get_or_create_search_model().bi_encoder self.embeddings_model = state.embeddings_model
self.embeddings_model = EmbeddingsModel(bi_encoder_name)
self.config = config self.config = config
self.date_filter = DateFilter() self.date_filter = DateFilter()

View File

@@ -581,7 +581,7 @@ async def chat(
request: Request, request: Request,
q: str, q: str,
n: Optional[int] = 5, n: Optional[int] = 5,
d: Optional[float] = 0.4, d: Optional[float] = 0.15,
client: Optional[str] = None, client: Optional[str] = None,
stream: Optional[bool] = False, stream: Optional[bool] = False,
user_agent: Optional[str] = Header(None), user_agent: Optional[str] = Header(None),

View File

@@ -7,7 +7,7 @@ from database.models import (
ChatModelOptions, ChatModelOptions,
OfflineChatProcessorConversationConfig, OfflineChatProcessorConversationConfig,
OpenAIProcessorConversationConfig, OpenAIProcessorConversationConfig,
SearchModel, SearchModelConfig,
UserConversationConfig, UserConversationConfig,
Conversation, Conversation,
Subscription, Subscription,
@@ -74,7 +74,7 @@ class ConversationFactory(factory.django.DjangoModelFactory):
class SearchModelFactory(factory.django.DjangoModelFactory): class SearchModelFactory(factory.django.DjangoModelFactory):
class Meta: class Meta:
model = SearchModel model = SearchModelConfig
name = "default" name = "default"
model_type = "text" model_type = "text"