mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-03 21:29:08 +00:00
[Multi-User Part 1]: Enable storage of settings for plaintext files based on user account (#498)
- Partition configuration for indexing local data based on user accounts - Store indexed data in an underlying postgres db using the `pgvector` extension - Add migrations for all relevant user data and embeddings generation. Very little performance optimization has been done for the lookup time - Apply filters using SQL queries - Start removing many server-level configuration settings - Configure GitHub test actions to run during any PR. Update the test action to run in a containerized environment with a DB. - Update the Docker image and docker-compose.yml to work with the new application design
This commit is contained in:
@@ -1,15 +1,30 @@
|
||||
from typing import Type, TypeVar
|
||||
from typing import Type, TypeVar, List
|
||||
import uuid
|
||||
from datetime import date
|
||||
|
||||
from django.db import models
|
||||
from django.contrib.sessions.backends.db import SessionStore
|
||||
from pgvector.django import CosineDistance
|
||||
from django.db.models.manager import BaseManager
|
||||
from django.db.models import Q
|
||||
from torch import Tensor
|
||||
|
||||
# Import sync_to_async from asgiref (the ASGI helper library that Django depends on)
|
||||
from asgiref.sync import sync_to_async
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from database.models import KhojUser, GoogleUser, NotionConfig
|
||||
from database.models import (
|
||||
KhojUser,
|
||||
GoogleUser,
|
||||
NotionConfig,
|
||||
GithubConfig,
|
||||
Embeddings,
|
||||
GithubRepoConfig,
|
||||
)
|
||||
from khoj.search_filter.word_filter import WordFilter
|
||||
from khoj.search_filter.file_filter import FileFilter
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=models.Model)
|
||||
|
||||
@@ -40,9 +55,7 @@ async def get_or_create_user(token: dict) -> KhojUser:
|
||||
|
||||
async def create_google_user(token: dict) -> KhojUser:
|
||||
user_info = token.get("userinfo")
|
||||
user = await KhojUser.objects.acreate(
|
||||
username=user_info.get("email"), email=user_info.get("email"), uuid=uuid.uuid4()
|
||||
)
|
||||
user = await KhojUser.objects.acreate(username=user_info.get("email"), email=user_info.get("email"))
|
||||
await user.asave()
|
||||
await GoogleUser.objects.acreate(
|
||||
sub=user_info.get("sub"),
|
||||
@@ -76,3 +89,149 @@ async def retrieve_user(session_id: str) -> KhojUser:
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="Invalid user")
|
||||
return user
|
||||
|
||||
|
||||
def get_all_users() -> BaseManager[KhojUser]:
    """Return the unfiltered ``KhojUser.objects.all()`` queryset."""
    every_user = KhojUser.objects.all()
    return every_user
|
||||
|
||||
|
||||
def get_user_github_config(user: KhojUser):
    """Look up the first ``GithubConfig`` row belonging to *user*.

    The ``githubrepoconfig`` relation is prefetched on the queryset before the
    first row is taken. Returns ``None`` when the lookup yields nothing.
    """
    owned_configs = GithubConfig.objects.filter(user=user)
    github_config = owned_configs.prefetch_related("githubrepoconfig").first()
    # A falsy result is normalised to None, matching what callers have always received.
    return github_config or None
|
||||
|
||||
|
||||
def get_user_notion_config(user: KhojUser):
    """Look up the first ``NotionConfig`` row belonging to *user*.

    Returns ``None`` when the lookup yields nothing.
    """
    notion_config = NotionConfig.objects.filter(user=user).first()
    # A falsy result is normalised to None, matching what callers have always received.
    return notion_config or None
|
||||
|
||||
|
||||
async def set_text_content_config(user: KhojUser, object: Type[models.Model], updated_config):
    """Replace the text-content configuration stored for *user*.

    Every existing row of the model ``object`` that belongs to *user* is
    deleted, then a single new row is created from ``updated_config``.

    Args:
        user: Owner of the configuration rows.
        object: Model class whose rows hold the configuration. It is created
            with ``input_files``, ``input_filter``, ``index_heading_entries``
            and ``user`` keyword arguments.
        updated_config: Any object exposing ``input_files``, ``input_filter``
            and ``index_heading_entries`` attributes.

    Duplicate entries in ``input_files`` / ``input_filter`` are dropped. An
    empty or missing collection is stored as ``None``.
    """
    # dict.fromkeys removes duplicates while keeping first-seen order. The previous
    # list(set(...)) stored the entries in an arbitrary, hash-dependent order.
    deduped_files = list(dict.fromkeys(updated_config.input_files)) if updated_config.input_files else None
    deduped_filters = list(dict.fromkeys(updated_config.input_filter)) if updated_config.input_filter else None

    # Replace rather than update: the user's previous rows are removed first.
    await object.objects.filter(user=user).adelete()
    await object.objects.acreate(
        input_files=deduped_files,
        input_filter=deduped_filters,
        index_heading_entries=updated_config.index_heading_entries,
        user=user,
    )
|
||||
|
||||
|
||||
async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
    """Store the GitHub token and repository list for *user*.

    An existing ``GithubConfig`` row has its ``pat_token`` overwritten and its
    related ``githubrepoconfig`` rows deleted; otherwise a new row is created.
    One ``GithubRepoConfig`` row is then created per entry in ``repos``.

    Args:
        user: Owner of the configuration.
        pat_token: Token value written to ``GithubConfig.pat_token``.
        repos: Iterable of mappings, each read by its ``"name"``, ``"owner"``
            and ``"branch"`` keys.

    Returns:
        The created or updated ``GithubConfig`` instance.
    """
    github_config = await GithubConfig.objects.filter(user=user).afirst()

    if github_config:
        # Existing configuration: refresh the token and clear the old repository rows.
        github_config.pat_token = pat_token
        await github_config.asave()
        await github_config.githubrepoconfig.all().adelete()
    else:
        github_config = await GithubConfig.objects.acreate(pat_token=pat_token, user=user)

    for repo_spec in repos:
        await GithubRepoConfig.objects.acreate(
            name=repo_spec["name"],
            owner=repo_spec["owner"],
            branch=repo_spec["branch"],
            github_config=github_config,
        )

    return github_config
|
||||
|
||||
|
||||
class EmbeddingsAdapters:
    """Static helpers that read, delete, filter and rank ``Embeddings`` rows for one user."""

    # One shared instance of each query-filter parser, reused by every call.
    word_filter = WordFilter()
    # Backward-compatible alias: the attribute was first published under this misspelled name.
    word_filer = word_filter
    file_filter = FileFilter()
    date_filter = DateFilter()

    @staticmethod
    def does_embedding_exist(user: KhojUser, hashed_value: str) -> bool:
        """Return True when *user* has an ``Embeddings`` row with this ``hashed_value``."""
        return Embeddings.objects.filter(user=user, hashed_value=hashed_value).exists()

    @staticmethod
    def delete_embedding_by_file(user: KhojUser, file_path: str):
        """Delete the user's rows for ``file_path``; return the count reported by ``QuerySet.delete()``."""
        deleted_count, _ = Embeddings.objects.filter(user=user, file_path=file_path).delete()
        return deleted_count

    @staticmethod
    def delete_all_embeddings(user: KhojUser, file_type: str):
        """Delete the user's rows of the given ``file_type``; return the count reported by ``QuerySet.delete()``.

        Despite the name, only rows matching ``file_type`` are removed.
        """
        deleted_count, _ = Embeddings.objects.filter(user=user, file_type=file_type).delete()
        return deleted_count

    @staticmethod
    def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str):
        """Return a flat ``values_list`` queryset of ``hashed_value`` for the user's rows at ``file_path``."""
        return Embeddings.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True)

    @staticmethod
    def delete_embedding_by_hash(user: KhojUser, hashed_values: List[str]):
        """Delete the user's rows whose ``hashed_value`` is in ``hashed_values``.

        Returns the count reported by ``QuerySet.delete()``, like the other
        delete helpers on this class (this method previously returned nothing).
        """
        deleted_count, _ = Embeddings.objects.filter(user=user, hashed_value__in=hashed_values).delete()
        return deleted_count

    @staticmethod
    def get_embeddings_by_date_filter(embeddings: BaseManager[Embeddings], start_date: date, end_date: date):
        """Narrow ``embeddings`` to rows with a related date inside ``[start_date, end_date]``.

        No user scoping is applied here; the caller supplies the queryset.
        """
        # NOTE(review): apply_filters spells this relation `embeddings_dates`; this method
        # uses `embeddingsdates`. Only one can match the model's related name - confirm
        # against the model definition before relying on this helper.
        return embeddings.filter(
            embeddingsdates__date__gte=start_date,
            embeddingsdates__date__lte=end_date,
        )

    @staticmethod
    async def user_has_embeddings(user: KhojUser):
        """Return True when at least one ``Embeddings`` row exists for *user*."""
        return await Embeddings.objects.filter(user=user).aexists()

    @staticmethod
    def apply_filters(user: KhojUser, query: str, file_type_filter: str = None):
        """Build the queryset of the user's ``Embeddings`` rows that satisfy the filters in ``query``.

        Args:
            user: Only rows owned by this user are returned.
            query: Raw query text handed to the word, file and date filter
                parsers. ``None`` or an empty string means "no query filters".
            file_type_filter: When truthy, restrict rows to this ``file_type``.

        Returns:
            A queryset of matching ``Embeddings`` rows.
        """
        relevant_embeddings = Embeddings.objects.filter(user=user)
        # Apply the file-type restriction up front so it holds on every return path;
        # previously it was skipped whenever the query contained no filter terms.
        if file_type_filter:
            relevant_embeddings = relevant_embeddings.filter(file_type=file_type_filter)

        # search_with_embeddings defaults raw_query to None; do not hand that to the parsers.
        if not query:
            return relevant_embeddings

        explicit_word_terms = EmbeddingsAdapters.word_filter.get_filter_terms(query)
        file_filters = EmbeddingsAdapters.file_filter.get_filter_terms(query)
        date_filters = EmbeddingsAdapters.date_filter.get_query_date_range(query)

        if not (explicit_word_terms or file_filters or date_filters):
            return relevant_embeddings

        q_filter_terms = Q()

        # "+term" must appear in `raw`, "-term" must not (both case-insensitive).
        # Terms without either prefix are ignored.
        for term in explicit_word_terms:
            if term.startswith("+"):
                q_filter_terms &= Q(raw__icontains=term[1:])
            elif term.startswith("-"):
                q_filter_terms &= ~Q(raw__icontains=term[1:])

        # A row passes the file filter when its path matches ANY of the patterns.
        if file_filters:
            q_file_filter_terms = Q()
            for pattern in file_filters:
                q_file_filter_terms |= Q(file_path__regex=pattern)
            q_filter_terms &= q_file_filter_terms

        if date_filters:
            # Each bound is a timestamp accepted by date.fromtimestamp, or None for "unbounded".
            min_date, max_date = date_filters
            if min_date is not None:
                # Convert the min_date timestamp to yyyy-mm-dd format
                formatted_min_date = date.fromtimestamp(min_date).strftime("%Y-%m-%d")
                q_filter_terms &= Q(embeddings_dates__date__gte=formatted_min_date)
            if max_date is not None:
                # Convert the max_date timestamp to yyyy-mm-dd format
                formatted_max_date = date.fromtimestamp(max_date).strftime("%Y-%m-%d")
                q_filter_terms &= Q(embeddings_dates__date__lte=formatted_max_date)

        return relevant_embeddings.filter(q_filter_terms)

    @staticmethod
    def search_with_embeddings(
        user: KhojUser, embeddings: Tensor, max_results: int = 10, file_type_filter: str = None, raw_query: str = None
    ):
        """Return the user's rows closest to ``embeddings`` by cosine distance.

        Args:
            user: Only rows owned by this user are searched.
            embeddings: Vector compared against each row's ``embeddings`` field.
            max_results: Upper bound on the number of rows in the returned slice.
            file_type_filter: When truthy, restrict rows to this ``file_type``.
            raw_query: Optional query text whose filter terms narrow the candidates.

        Returns:
            A queryset slice ordered by ascending ``distance`` annotation.
        """
        # apply_filters already scopes by user and file type, so neither is repeated here.
        relevant_embeddings = EmbeddingsAdapters.apply_filters(user, raw_query, file_type_filter)
        ranked_embeddings = relevant_embeddings.annotate(
            distance=CosineDistance("embeddings", embeddings)
        ).order_by("distance")
        return ranked_embeddings[:max_results]

    @staticmethod
    def get_unique_file_types(user: KhojUser):
        """Return a flat, distinct ``values_list`` queryset of the user's ``file_type`` values."""
        return Embeddings.objects.filter(user=user).values_list("file_type", flat=True).distinct()
|
||||
|
||||
Reference in New Issue
Block a user