Files
khoj/src/database/adapters/__init__.py
sabaimran 609d358b1a Use sql datetime comparison for detecting validity of subscription renewal date
- Update the unsubscribe endpoint to use query params
- Use subscription id to process unsubscribe endpoint, rather than the customer id
2023-11-07 19:17:36 -08:00

421 lines
15 KiB
Python

from typing import Type, TypeVar, List
from datetime import date, datetime, timedelta
import secrets
from typing import Type, TypeVar, List
from datetime import date, timezone
from django.db import models
from django.contrib.sessions.backends.db import SessionStore
from pgvector.django import CosineDistance
from django.db.models.manager import BaseManager
from django.db.models import Q
from torch import Tensor
from pgvector.django import CosineDistance
from django.db.models.manager import BaseManager
from django.db.models import Q
from torch import Tensor
# Import sync_to_async from asgiref (the ASGI utility library Django depends on)
from asgiref.sync import sync_to_async
from fastapi import HTTPException
from database.models import (
KhojUser,
GoogleUser,
KhojApiUser,
NotionConfig,
GithubConfig,
Entry,
GithubRepoConfig,
Conversation,
ChatModelOptions,
UserConversationConfig,
OpenAIProcessorConversationConfig,
OfflineChatProcessorConversationConfig,
)
from khoj.utils.helpers import generate_random_name
from khoj.search_filter.word_filter import WordFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.date_filter import DateFilter
# Type variable bound to Django's models.Model, so a helper annotated with
# Type[ModelType] -> ModelType returns the same model type it was handed.
ModelType = TypeVar("ModelType", bound=models.Model)
async def retrieve_object(model_class: Type[ModelType], id: int) -> ModelType:
    """Fetch the first `model_class` row whose `id` field equals `id`.

    Raises:
        HTTPException: 404 when no matching row exists.
    """
    found = await model_class.objects.filter(id=id).afirst()
    if found:
        return found
    raise HTTPException(status_code=404, detail=f"{model_class.__name__} not found")
async def set_notion_config(token: str, user: KhojUser):
    """Store `token` as the user's Notion token.

    Updates the user's existing NotionConfig when there is one, otherwise
    creates it. Returns the saved config.
    """
    existing_config = await NotionConfig.objects.filter(user=user).afirst()
    if existing_config:
        existing_config.token = token
        await existing_config.asave()
        return existing_config
    return await NotionConfig.objects.acreate(token=token, user=user)
async def create_khoj_token(user: KhojUser, name=None):
    """Create a Khoj API key for the user.

    The key is "kk-" followed by a URL-safe random token. When `name` is
    falsy, a title-cased random name is generated for the key.
    """
    api_key = f"kk-{secrets.token_urlsafe(32)}"
    if not name:
        name = generate_random_name().title()
    created_key = await KhojApiUser.objects.acreate(token=api_key, user=user, name=name)
    await created_key.asave()
    return created_key
def get_khoj_tokens(user: KhojUser):
    """Return all of the user's Khoj API keys as a list."""
    user_keys = KhojApiUser.objects.filter(user=user)
    return list(user_keys)
async def delete_khoj_token(user: KhojUser, token: str):
    """Delete the user's Khoj API key whose token equals `token`."""
    matching_keys = KhojApiUser.objects.filter(token=token, user=user)
    await matching_keys.adelete()
async def get_or_create_user(token: dict) -> KhojUser:
    """Return the user linked to `token`, creating a Google-backed user if none is found."""
    existing_user = await get_user_by_token(token)
    if existing_user:
        return existing_user
    return await create_google_user(token)
async def create_google_user(token: dict) -> KhojUser:
    """Create a KhojUser and its linked GoogleUser from the values in `token`.

    The token's "email" is used as both username and email of the new user.
    Returns the new KhojUser.
    """
    email = token.get("email")
    user = await KhojUser.objects.acreate(username=email, email=email)
    await user.asave()

    # Fields copied verbatim from the token onto the GoogleUser row.
    copied_keys = ("sub", "azp", "email", "name", "given_name", "family_name", "picture", "locale")
    google_fields = {key: token.get(key) for key in copied_keys}
    await GoogleUser.objects.acreate(**google_fields, user=user)
    return user
async def set_user_subscribed(email: str, type="standard") -> KhojUser:
    """Mark the user with `email` as subscribed and push their renewal date out 30 days.

    The 30 days are added to the user's existing renewal date when one is set,
    otherwise to the current UTC time.

    Args:
        email: Email of the user to update.
        type: Value stored in the user's `subscription_type`.

    Returns:
        The updated user, or None when no user has that email.
    """
    user = await KhojUser.objects.filter(email=email).afirst()
    if not user:
        return None

    user.subscription_type = type
    # Use a timezone-aware "now" so the stored renewal date is comparable with
    # the aware UTC datetime that is_user_subscribed filters against.
    start_date = user.subscription_renewal_date or datetime.now(tz=timezone.utc)
    user.subscription_renewal_date = start_date + timedelta(days=30)
    await user.asave()
    return user
def is_user_subscribed(email: str, type="standard") -> bool:
    """Report whether a user with `email` has a `type` subscription whose renewal date is not yet past.

    The comparison against the current UTC time is done in the database query.
    """
    now_utc = datetime.now(tz=timezone.utc)
    active_subscribers = KhojUser.objects.filter(
        email=email,
        subscription_type=type,
        subscription_renewal_date__gte=now_utc,
    )
    return active_subscribers.exists()
async def get_user_by_token(token: dict) -> KhojUser:
    """Return the KhojUser linked to the GoogleUser whose `sub` matches the token's "sub".

    Returns None when no such GoogleUser exists.
    """
    candidates = GoogleUser.objects.filter(sub=token.get("sub")).select_related("user")
    google_user = await candidates.afirst()
    return google_user.user if google_user else None
async def retrieve_user(session_id: str) -> KhojUser:
    """Resolve a session key to the user stored in that session.

    Raises:
        HTTPException: 401 when the session does not exist, or when the user id
            stored in it matches no KhojUser.
    """
    session = SessionStore(session_key=session_id)
    # SessionStore is synchronous, so its calls are wrapped for use from async code.
    session_found = await sync_to_async(session.exists)(session_key=session_id)
    if not session_found:
        raise HTTPException(status_code=401, detail="Invalid session")

    session_data = await sync_to_async(session.load)()
    user_id = session_data.get("_auth_user_id")
    user = await KhojUser.objects.filter(id=user_id).afirst()
    if user:
        return user
    raise HTTPException(status_code=401, detail="Invalid user")
def get_all_users() -> BaseManager[KhojUser]:
    """Return a queryset over every KhojUser."""
    every_user = KhojUser.objects.all()
    return every_user
def get_user_github_config(user: KhojUser):
    """Return the user's first GithubConfig with its repo configs prefetched, or None."""
    return GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
def get_user_notion_config(user: KhojUser):
    """Return the user's first NotionConfig, or None when they have none."""
    return NotionConfig.objects.filter(user=user).first()
async def set_text_content_config(user: KhojUser, object: Type[models.Model], updated_config):
    """Replace the user's rows of the `object` model with one built from `updated_config`.

    `input_files` and `input_filter` are de-duplicated; a falsy value for
    either is stored as None. Existing rows for the user are deleted before
    the new row is created.
    """

    def _dedupe(values):
        # Order is not preserved: duplicates are dropped via a set.
        return list(set(values)) if values else None

    new_fields = {
        "input_files": _dedupe(updated_config.input_files),
        "input_filter": _dedupe(updated_config.input_filter),
    }
    await object.objects.filter(user=user).adelete()
    await object.objects.acreate(
        **new_fields,
        index_heading_entries=updated_config.index_heading_entries,
        user=user,
    )
async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
    """Save the user's GitHub token and replace their configured repositories.

    An existing GithubConfig gets its token updated and its repo configs
    deleted; otherwise a new GithubConfig is created. One GithubRepoConfig is
    then created per item in `repos`, reading its "name", "owner" and "branch"
    keys. Returns the GithubConfig.
    """
    config = await GithubConfig.objects.filter(user=user).afirst()
    if config:
        config.pat_token = pat_token
        await config.asave()
        # Drop the previous repo list; it is rebuilt from `repos` below.
        await config.githubrepoconfig.all().adelete()
    else:
        config = await GithubConfig.objects.acreate(pat_token=pat_token, user=user)

    for repo in repos:
        repo_fields = {"name": repo["name"], "owner": repo["owner"], "branch": repo["branch"]}
        await GithubRepoConfig.objects.acreate(**repo_fields, github_config=config)
    return config
class ConversationAdapters:
    """Database helpers for conversations and chat model configuration."""

    @staticmethod
    def get_conversation_by_user(user: KhojUser):
        """Return the user's first conversation, creating one when they have none."""
        user_conversations = Conversation.objects.filter(user=user)
        if not user_conversations.exists():
            return Conversation.objects.create(user=user)
        return user_conversations.first()

    @staticmethod
    async def aget_conversation_by_user(user: KhojUser):
        """Async variant: return the user's first conversation, creating one when they have none."""
        user_conversations = Conversation.objects.filter(user=user)
        if not await user_conversations.aexists():
            return await Conversation.objects.acreate(user=user)
        return await user_conversations.afirst()

    @staticmethod
    def has_any_conversation_config(user: KhojUser):
        """Return True when a ChatModelOptions row matches `user`."""
        matching_options = ChatModelOptions.objects.filter(user=user)
        return matching_options.exists()

    @staticmethod
    def get_openai_conversation_config():
        """Return the first OpenAI processor config, or None."""
        openai_configs = OpenAIProcessorConversationConfig.objects.filter()
        return openai_configs.first()

    @staticmethod
    def get_offline_chat_conversation_config():
        """Return the first offline chat processor config, or None."""
        offline_configs = OfflineChatProcessorConversationConfig.objects.filter()
        return offline_configs.first()

    @staticmethod
    def has_valid_offline_conversation_config():
        """Return True when an enabled offline chat processor config exists."""
        enabled_configs = OfflineChatProcessorConversationConfig.objects.filter(enabled=True)
        return enabled_configs.exists()

    @staticmethod
    def has_valid_openai_conversation_config():
        """Return True when any OpenAI processor config exists."""
        openai_configs = OpenAIProcessorConversationConfig.objects.filter()
        return openai_configs.exists()

    @staticmethod
    async def aset_user_conversation_processor(user: KhojUser, conversation_processor_config_id: int):
        """Point the user's conversation config at the chat model with the given id.

        Returns None when no ChatModelOptions has that id; otherwise returns
        the result of `aupdate_or_create` unchanged.
        """
        chat_model = await ChatModelOptions.objects.filter(id=conversation_processor_config_id).afirst()
        if chat_model:
            return await UserConversationConfig.objects.aupdate_or_create(
                user=user, defaults={"setting": chat_model}
            )
        return None

    @staticmethod
    def get_conversation_config(user: KhojUser):
        """Return the `setting` of the user's conversation config, or None when they have none."""
        user_config = UserConversationConfig.objects.filter(user=user).first()
        return user_config.setting if user_config else None

    @staticmethod
    def get_default_conversation_config():
        """Return the first ChatModelOptions row, or None."""
        all_options = ChatModelOptions.objects.filter()
        return all_options.first()

    @staticmethod
    def save_conversation(user: KhojUser, conversation_log: dict):
        """Write `conversation_log` to the user's conversations, creating one when none exist."""
        user_conversations = Conversation.objects.filter(user=user)
        if not user_conversations.exists():
            Conversation.objects.create(user=user, conversation_log=conversation_log)
            return
        # update() applies to every conversation row matched for the user.
        user_conversations.update(conversation_log=conversation_log)

    @staticmethod
    def get_conversation_processor_options():
        """Return a queryset over all ChatModelOptions."""
        every_option = ChatModelOptions.objects.all()
        return every_option

    @staticmethod
    def set_conversation_processor_config(user: KhojUser, new_config: ChatModelOptions):
        """Set `new_config` as the user's chat model, creating their config row if needed."""
        config_row, _created = UserConversationConfig.objects.get_or_create(user=user)
        config_row.setting = new_config
        config_row.save()

    @staticmethod
    def has_offline_chat():
        """Return True when an enabled offline chat processor config exists."""
        enabled_configs = OfflineChatProcessorConversationConfig.objects.filter(enabled=True)
        return enabled_configs.exists()

    @staticmethod
    async def ahas_offline_chat():
        """Async variant: return True when an enabled offline chat processor config exists."""
        enabled_configs = OfflineChatProcessorConversationConfig.objects.filter(enabled=True)
        return await enabled_configs.aexists()

    @staticmethod
    async def get_offline_chat():
        """Return the first ChatModelOptions with model_type "offline", or None."""
        offline_models = ChatModelOptions.objects.filter(model_type="offline")
        return await offline_models.afirst()

    @staticmethod
    async def aget_user_conversation_config(user: KhojUser):
        """Async variant: return the `setting` of the user's conversation config, or None."""
        user_configs = UserConversationConfig.objects.filter(user=user).prefetch_related("setting")
        user_config = await user_configs.afirst()
        return user_config.setting if user_config else None

    @staticmethod
    async def has_openai_chat():
        """Return True when any OpenAI processor config exists."""
        openai_configs = OpenAIProcessorConversationConfig.objects.filter()
        return await openai_configs.aexists()

    @staticmethod
    async def get_openai_chat():
        """Return the first ChatModelOptions with model_type "openai", or None."""
        openai_models = ChatModelOptions.objects.filter(model_type="openai")
        return await openai_models.afirst()

    @staticmethod
    async def get_openai_chat_config():
        """Return the first OpenAI processor config, or None."""
        openai_configs = OpenAIProcessorConversationConfig.objects.filter()
        return await openai_configs.afirst()

    @staticmethod
    async def aget_default_conversation_config():
        """Async variant: return the first ChatModelOptions row, or None."""
        all_options = ChatModelOptions.objects.filter()
        return await all_options.afirst()
class EntryAdapters:
    """Database helpers for querying, filtering and deleting a user's entries."""

    # NOTE: the attribute is spelled "word_filer" (sic); the name is kept
    # unchanged because code outside this class may reference it.
    word_filer = WordFilter()
    file_filter = FileFilter()
    date_filter = DateFilter()

    @staticmethod
    def does_entry_exist(user: KhojUser, hashed_value: str) -> bool:
        """Return True when the user has an entry with the given hash."""
        return Entry.objects.filter(user=user, hashed_value=hashed_value).exists()

    @staticmethod
    def delete_entry_by_file(user: KhojUser, file_path: str):
        """Delete the user's entries for `file_path`; return the count reported by `delete()`."""
        deleted_count, _ = Entry.objects.filter(user=user, file_path=file_path).delete()
        return deleted_count

    @staticmethod
    def delete_all_entries_by_type(user: KhojUser, file_type: str = None):
        """Delete the user's entries, limited to `file_type` when it is given.

        Returns the count reported by `delete()`.
        """
        entries = Entry.objects.filter(user=user)
        if file_type is not None:
            entries = entries.filter(file_type=file_type)
        deleted_count, _ = entries.delete()
        return deleted_count

    @staticmethod
    def delete_all_entries_by_source(user: KhojUser, file_source: str = None):
        """Delete the user's entries, limited to `file_source` when it is given.

        Returns the count reported by `delete()`.
        """
        entries = Entry.objects.filter(user=user)
        if file_source is not None:
            entries = entries.filter(file_source=file_source)
        deleted_count, _ = entries.delete()
        return deleted_count

    @staticmethod
    def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str):
        """Return a flat queryset of `hashed_value` for the user's entries in `file_path`."""
        return Entry.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True)

    @staticmethod
    def delete_entry_by_hash(user: KhojUser, hashed_values: List[str]):
        """Delete the user's entries whose hash is in `hashed_values`."""
        Entry.objects.filter(user=user, hashed_value__in=hashed_values).delete()

    @staticmethod
    def get_entries_by_date_filter(entry: BaseManager[Entry], start_date: date, end_date: date):
        """Narrow `entry` to rows with a related date between `start_date` and `end_date`, inclusive."""
        # NOTE(review): apply_filters reaches the date relation via "embeddings_dates";
        # confirm that "entrydates" is also a valid lookup path on Entry.
        return entry.filter(
            entrydates__date__gte=start_date,
            entrydates__date__lte=end_date,
        )

    @staticmethod
    async def user_has_entries(user: KhojUser):
        """Return True when the user has at least one entry."""
        return await Entry.objects.filter(user=user).aexists()

    @staticmethod
    async def adelete_entry_by_file(user: KhojUser, file_path: str):
        """Async delete of the user's entries for `file_path`; returns the `adelete()` result unchanged."""
        return await Entry.objects.filter(user=user, file_path=file_path).adelete()

    @staticmethod
    def aget_all_filenames_by_source(user: KhojUser, file_source: str):
        """Return a flat queryset of distinct `file_path` values for the user's entries from `file_source`.

        Despite the "a" prefix this is a regular function: it returns the
        queryset without evaluating or awaiting it.
        """
        return (
            Entry.objects.filter(user=user, file_source=file_source)
            .distinct("file_path")
            .values_list("file_path", flat=True)
        )

    @staticmethod
    async def adelete_all_entries(user: KhojUser):
        """Async delete of all the user's entries; returns the `adelete()` result unchanged."""
        return await Entry.objects.filter(user=user).adelete()

    @staticmethod
    def apply_filters(user: KhojUser, query: str, file_type_filter: str = None):
        """Return the user's entries that satisfy the filters extracted from `query`.

        Word terms starting with "+" must appear in the entry's raw text
        (case-insensitive) and terms starting with "-" must not. File filter
        terms are regexes on the file path and are OR-ed together. The date
        range bounds, when present, are inclusive.

        Args:
            user: Owner of the entries to search.
            query: Raw query handed to the word, file and date filters.
            file_type_filter: When truthy, restrict results to this file type.
                This is applied even when `query` yields no filter terms.

        Returns:
            A queryset of matching entries.
        """
        explicit_word_terms = EntryAdapters.word_filer.get_filter_terms(query)
        file_filters = EntryAdapters.file_filter.get_filter_terms(query)
        date_filters = EntryAdapters.date_filter.get_query_date_range(query)

        relevant_entries = Entry.objects.filter(user=user)

        if explicit_word_terms or file_filters or date_filters:
            q_filter_terms = Q()

            for term in explicit_word_terms:
                if term.startswith("+"):
                    q_filter_terms &= Q(raw__icontains=term[1:])
                elif term.startswith("-"):
                    q_filter_terms &= ~Q(raw__icontains=term[1:])

            if file_filters:
                # An entry matches when its path matches any one of the file regexes.
                q_file_filter_terms = Q()
                for term in file_filters:
                    q_file_filter_terms |= Q(file_path__regex=term)
                q_filter_terms &= q_file_filter_terms

            if date_filters:
                min_date, max_date = date_filters
                if min_date is not None:
                    # Convert the min_date timestamp to yyyy-mm-dd format
                    formatted_min_date = date.fromtimestamp(min_date).strftime("%Y-%m-%d")
                    q_filter_terms &= Q(embeddings_dates__date__gte=formatted_min_date)
                if max_date is not None:
                    # Convert the max_date timestamp to yyyy-mm-dd format
                    formatted_max_date = date.fromtimestamp(max_date).strftime("%Y-%m-%d")
                    q_filter_terms &= Q(embeddings_dates__date__lte=formatted_max_date)

            relevant_entries = relevant_entries.filter(q_filter_terms)

        # Applied on every path, including a query without filter terms, so the
        # requested file type is never silently ignored.
        if file_type_filter:
            relevant_entries = relevant_entries.filter(file_type=file_type_filter)
        return relevant_entries

    @staticmethod
    def search_with_embeddings(
        user: KhojUser, embeddings: Tensor, max_results: int = 10, file_type_filter: str = None, raw_query: str = None
    ):
        """Return up to `max_results` of the user's entries closest to `embeddings`.

        Entries are first narrowed by `apply_filters` using `raw_query` and
        `file_type_filter`, then ordered by ascending cosine distance between
        their `embeddings` field and the given `embeddings`.
        """
        # apply_filters already scopes the queryset to `user` and `file_type_filter`.
        relevant_entries = EntryAdapters.apply_filters(user, raw_query, file_type_filter)
        relevant_entries = relevant_entries.annotate(distance=CosineDistance("embeddings", embeddings))
        relevant_entries = relevant_entries.order_by("distance")
        return relevant_entries[:max_results]

    @staticmethod
    def get_unique_file_types(user: KhojUser):
        """Return a flat queryset of the distinct file types among the user's entries."""
        return Entry.objects.filter(user=user).values_list("file_type", flat=True).distinct()

    @staticmethod
    def get_unique_file_source(user: KhojUser):
        """Return a flat queryset of the distinct file sources among the user's entries."""
        return Entry.objects.filter(user=user).values_list("file_source", flat=True).distinct()