mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Add a model that specifies the user's search model configuration
- Update all endpoints that generate embeddings to use the new model, including generating text embeddings for entries and creating embeddings for a search query
This commit is contained in:
@@ -32,6 +32,7 @@ from khoj.database.models import (
|
|||||||
SpeechToTextModelOptions,
|
SpeechToTextModelOptions,
|
||||||
Subscription,
|
Subscription,
|
||||||
UserConversationConfig,
|
UserConversationConfig,
|
||||||
|
UserSearchModelConfig,
|
||||||
OpenAIProcessorConversationConfig,
|
OpenAIProcessorConversationConfig,
|
||||||
OfflineChatProcessorConversationConfig,
|
OfflineChatProcessorConversationConfig,
|
||||||
ReflectiveQuestion,
|
ReflectiveQuestion,
|
||||||
@@ -250,7 +251,10 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
def get_default_search_model():
|
def get_user_search_model_or_default(user=None):
|
||||||
|
if user and UserSearchModelConfig.objects.filter(user=user).exists():
|
||||||
|
return UserSearchModelConfig.objects.filter(user=user).first().setting
|
||||||
|
|
||||||
if SearchModelConfig.objects.filter(name="default").exists():
|
if SearchModelConfig.objects.filter(name="default").exists():
|
||||||
return SearchModelConfig.objects.filter(name="default").first()
|
return SearchModelConfig.objects.filter(name="default").first()
|
||||||
return SearchModelConfig.objects.first()
|
return SearchModelConfig.objects.first()
|
||||||
|
|||||||
33
src/khoj/database/migrations/0023_usersearchmodelconfig.py
Normal file
33
src/khoj/database/migrations/0023_usersearchmodelconfig.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
# Generated by Django 4.2.7 on 2023-12-19 15:44
|
||||||
|
|
||||||
|
from django.conf import settings
|
||||||
|
from django.db import migrations, models
|
||||||
|
import django.db.models.deletion
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0022_texttoimagemodelconfig"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.CreateModel(
|
||||||
|
name="UserSearchModelConfig",
|
||||||
|
fields=[
|
||||||
|
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||||
|
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||||
|
("updated_at", models.DateTimeField(auto_now=True)),
|
||||||
|
(
|
||||||
|
"setting",
|
||||||
|
models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="database.searchmodelconfig"),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"user",
|
||||||
|
models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
options={
|
||||||
|
"abstract": False,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
@@ -14,7 +14,7 @@ from khoj.utils.helpers import is_none_or_empty, timer, batcher
|
|||||||
from khoj.utils.rawconfig import Entry
|
from khoj.utils.rawconfig import Entry
|
||||||
from khoj.search_filter.date_filter import DateFilter
|
from khoj.search_filter.date_filter import DateFilter
|
||||||
from khoj.database.models import KhojUser, Entry as DbEntry, EntryDates
|
from khoj.database.models import KhojUser, Entry as DbEntry, EntryDates
|
||||||
from khoj.database.adapters import EntryAdapters, get_default_search_model
|
from khoj.database.adapters import EntryAdapters, get_user_search_model_or_default
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -112,7 +112,7 @@ class TextToEntries(ABC):
|
|||||||
with timer("Generated embeddings for entries to add to database in", logger):
|
with timer("Generated embeddings for entries to add to database in", logger):
|
||||||
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
|
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
|
||||||
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
|
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
|
||||||
model = get_default_search_model()
|
model = get_user_search_model_or_default(user)
|
||||||
embeddings += self.embeddings_model[model.name].embed_documents(data_to_embed)
|
embeddings += self.embeddings_model[model.name].embed_documents(data_to_embed)
|
||||||
|
|
||||||
added_entries: list[DbEntry] = []
|
added_entries: list[DbEntry] = []
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ from starlette.authentication import requires
|
|||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.configure import configure_server
|
from khoj.configure import configure_server
|
||||||
from khoj.database import adapters
|
from khoj.database import adapters
|
||||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters, get_default_search_model
|
from khoj.database.adapters import ConversationAdapters, EntryAdapters, get_user_search_model_or_default
|
||||||
from khoj.database.models import ChatModelOptions, SpeechToTextModelOptions
|
from khoj.database.models import ChatModelOptions, SpeechToTextModelOptions
|
||||||
from khoj.database.models import Entry as DbEntry
|
from khoj.database.models import Entry as DbEntry
|
||||||
from khoj.database.models import (
|
from khoj.database.models import (
|
||||||
@@ -416,7 +416,7 @@ async def search(
|
|||||||
]
|
]
|
||||||
if text_search_models:
|
if text_search_models:
|
||||||
with timer("Encoding query took", logger=logger):
|
with timer("Encoding query took", logger=logger):
|
||||||
search_model = await sync_to_async(get_default_search_model)()
|
search_model = await sync_to_async(get_user_search_model_or_default)(user)
|
||||||
encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query)
|
encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query)
|
||||||
|
|
||||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from khoj.utils.state import SearchType
|
|||||||
from khoj.utils.rawconfig import SearchResponse, Entry
|
from khoj.utils.rawconfig import SearchResponse, Entry
|
||||||
from khoj.utils.jsonl import load_jsonl
|
from khoj.utils.jsonl import load_jsonl
|
||||||
from khoj.processor.content.text_to_entries import TextToEntries
|
from khoj.processor.content.text_to_entries import TextToEntries
|
||||||
from khoj.database.adapters import EntryAdapters, get_default_search_model
|
from khoj.database.adapters import EntryAdapters, get_user_search_model_or_default
|
||||||
from khoj.database.models import KhojUser, Entry as DbEntry
|
from khoj.database.models import KhojUser, Entry as DbEntry
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -115,7 +115,7 @@ async def query(
|
|||||||
# Encode the query using the bi-encoder
|
# Encode the query using the bi-encoder
|
||||||
if question_embedding is None:
|
if question_embedding is None:
|
||||||
with timer("Query Encode Time", logger, state.device):
|
with timer("Query Encode Time", logger, state.device):
|
||||||
search_model = await sync_to_async(get_default_search_model)()
|
search_model = await sync_to_async(get_user_search_model_or_default)(user)
|
||||||
question_embedding = state.embeddings_model[search_model.name].embed_query(query)
|
question_embedding = state.embeddings_model[search_model.name].embed_query(query)
|
||||||
|
|
||||||
# Find relevant entries for the query
|
# Find relevant entries for the query
|
||||||
|
|||||||
Reference in New Issue
Block a user