Files
khoj/src/khoj/database/adapters/__init__.py
2024-08-05 19:57:21 +05:30

1279 lines
49 KiB
Python

import json
import logging
import math
import random
import re
import secrets
import sys
from datetime import date, datetime, timedelta, timezone
from enum import Enum
from typing import Callable, Iterable, List, Optional, Type
import cron_descriptor
from apscheduler.job import Job
from asgiref.sync import sync_to_async
from django.contrib.sessions.backends.db import SessionStore
from django.db import models
from django.db.models import Q
from django.db.models.manager import BaseManager
from django.db.utils import IntegrityError
from django_apscheduler.models import DjangoJob, DjangoJobExecution
from fastapi import HTTPException
from pgvector.django import CosineDistance
from torch import Tensor
from khoj.database.models import (
Agent,
ChatModelOptions,
ClientApplication,
Conversation,
DataStore,
Entry,
FileObject,
GithubConfig,
GithubRepoConfig,
GoogleUser,
KhojApiUser,
KhojUser,
NotionConfig,
OpenAIProcessorConversationConfig,
ProcessLock,
PublicConversation,
ReflectiveQuestion,
SearchModelConfig,
ServerChatSettings,
SpeechToTextModelOptions,
Subscription,
TextToImageModelConfig,
UserConversationConfig,
UserRequests,
UserSearchModelConfig,
UserTextToImageModelConfig,
UserVoiceModelConfig,
VoiceModelOption,
)
from khoj.processor.conversation import prompts
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.utils import state
from khoj.utils.config import OfflineChatProcessorModel
from khoj.utils.helpers import generate_random_name, is_none_or_empty, timer
# Module-level logger used by the adapters below (e.g. process-lock and agent setup messages).
logger = logging.getLogger(__name__)
class SubscriptionState(Enum):
    """States a user's subscription can be reported in.

    The string values are what `subscription_to_state` hands back to callers.
    """

    TRIAL = "trial"  # within the free trial window
    SUBSCRIBED = "subscribed"  # recurring plan with a renewal date not yet passed
    UNSUBSCRIBED = "unsubscribed"  # cancelled, but paid period not over yet
    EXPIRED = "expired"  # trial or paid period has lapsed
    INVALID = "invalid"  # no subscription, or an unrecognized combination
async def set_notion_config(token: str, user: KhojUser):
    """Create or update `user`'s Notion integration token and return the NotionConfig row."""
    existing = await NotionConfig.objects.filter(user=user).afirst()
    if existing is None:
        return await NotionConfig.objects.acreate(token=token, user=user)
    existing.token = token
    await existing.asave()
    return existing
def create_khoj_token(user: KhojUser, name=None):
    """Create and persist a new `kk-`-prefixed Khoj API key for `user`.

    When `name` is falsy, a random title-cased name is generated for the key.
    """
    api_key = "kk-" + secrets.token_urlsafe(32)
    key_name = name if name else generate_random_name().title()
    return KhojApiUser.objects.create(token=api_key, user=user, name=key_name)
async def acreate_khoj_token(user: KhojUser, name=None):
    """Async variant of `create_khoj_token`: persist a new `kk-`-prefixed API key for `user`.

    When `name` is falsy, a random title-cased name is generated for the key.
    """
    api_key = "kk-" + secrets.token_urlsafe(32)
    key_name = name if name else generate_random_name().title()
    return await KhojApiUser.objects.acreate(token=api_key, user=user, name=key_name)
def get_khoj_tokens(user: KhojUser):
    """Return every Khoj API key row belonging to `user`, materialized as a list."""
    user_keys = KhojApiUser.objects.filter(user=user)
    return list(user_keys)
async def delete_khoj_token(user: KhojUser, token: str):
    """Delete API key `token` only if it belongs to `user`; a non-matching token is a no-op."""
    owned_keys = KhojApiUser.objects.filter(token=token, user=user)
    await owned_keys.adelete()
async def get_or_create_user(token: dict) -> KhojUser:
    """Resolve the user linked to a Google identity `token`, creating one if none is linked yet."""
    existing_user = await get_user_by_token(token)
    if existing_user:
        return existing_user
    return await create_user_by_google_token(token)
async def aget_or_create_user_by_phone_number(phone_number: str) -> KhojUser:
    """Return the user for `phone_number`, creating one when no eligible user is found.

    Returns None when `phone_number` is empty.
    """
    if is_none_or_empty(phone_number):
        return None
    existing_user = await aget_user_by_phone_number(phone_number)
    if existing_user:
        return existing_user
    return await acreate_user_by_phone_number(phone_number)
async def aset_user_phone_number(user: KhojUser, phone_number: str) -> KhojUser:
    """Attach `phone_number` to `user`, normalized to start with "+".

    If a different account already resolves to that number:
    - when it has no email, its conversations are moved to `user` and the account is deleted;
    - otherwise an HTTPException(400) is raised.

    Returns None if `phone_number` is empty, else the saved `user`.
    """
    if is_none_or_empty(phone_number):
        return None

    normalized = phone_number.strip()
    normalized = normalized if normalized.startswith("+") else f"+{normalized}"

    other_user = await aget_user_by_phone_number(normalized)
    if other_user and other_user.id != user.id:
        if not is_none_or_empty(other_user.email):
            raise HTTPException(status_code=400, detail="Phone number already exists")
        # An email-less account is effectively anonymous, so fold its chat history into this user.
        async for conversation in Conversation.objects.filter(user=other_user).aiterator():
            conversation.user = user
            await conversation.asave()
        await other_user.adelete()

    user.phone_number = normalized
    await user.asave()
    return user
async def aremove_phone_number(user: KhojUser) -> KhojUser:
    """Detach the phone number from `user`, reset its verified flag, and save."""
    user.verified_phone_number = False
    user.phone_number = None
    await user.asave()
    return user
async def acreate_user_by_phone_number(phone_number: str) -> KhojUser:
    """Create (or update) the user keyed by `phone_number` and make sure it has a subscription.

    Returns None when `phone_number` is empty.

    `aupdate_or_create` may match an existing user with this number. So a trial subscription
    is only created when the user has none yet, mirroring `aget_or_create_user_by_email`,
    instead of unconditionally inserting another subscription row.
    """
    if is_none_or_empty(phone_number):
        return None
    user, _ = await KhojUser.objects.filter(phone_number=phone_number).aupdate_or_create(
        defaults={"username": phone_number, "phone_number": phone_number}
    )
    await user.asave()
    if not await Subscription.objects.filter(user=user).aexists():
        await Subscription.objects.acreate(user=user, type="trial")
    return user
async def aget_or_create_user_by_email(email: str) -> KhojUser:
    """Get or create the user for `email`, issue a fresh email verification code, and ensure a subscription.

    A new verification code is generated on every call. A trial subscription is created
    only if the user does not already have one.
    """
    user, _ = await KhojUser.objects.filter(email=email).aupdate_or_create(defaults={"username": email, "email": email})
    # aupdate_or_create always returns a saved instance, so there is nothing to null-check.
    # A single save persists the new verification code.
    user.email_verification_code = secrets.token_urlsafe(18)
    await user.asave()

    if not await Subscription.objects.filter(user=user).aexists():
        await Subscription.objects.acreate(user=user, type="trial")
    return user
async def aget_user_validated_by_email_verification_code(code: str) -> KhojUser:
    """Mark the user owning verification `code` as email-verified, consuming the code.

    Returns the updated user, or None if `code` is empty or matches no user.
    """
    # Django turns `email_verification_code=None` into an IS NULL lookup. Without this guard,
    # an empty/None code would match (and verify) an arbitrary user whose code was already consumed.
    if is_none_or_empty(code):
        return None
    user = await KhojUser.objects.filter(email_verification_code=code).afirst()
    if not user:
        return None
    user.email_verification_code = None
    user.verified_email = True
    await user.asave()
    return user
async def create_user_by_google_token(token: dict) -> KhojUser:
    """Create (or update) the user for a Google identity `token` and record their Google profile.

    The user's email is marked verified. A trial subscription is created only if the user
    does not already have one. Without that check, an account that already existed under
    the same email (e.g. created via email sign-in) would get a duplicate subscription row.
    """
    email = token.get("email")
    user, _ = await KhojUser.objects.filter(email=email).aupdate_or_create(
        defaults={"username": email, "email": email}
    )
    user.verified_email = True
    await user.asave()

    await GoogleUser.objects.acreate(
        sub=token.get("sub"),
        azp=token.get("azp"),
        email=email,
        name=token.get("name"),
        given_name=token.get("given_name"),
        family_name=token.get("family_name"),
        picture=token.get("picture"),
        locale=token.get("locale"),
        user=user,
    )

    if not await Subscription.objects.filter(user=user).aexists():
        await Subscription.objects.acreate(user=user, type="trial")
    return user
def set_user_name(user: KhojUser, first_name: str, last_name: str) -> KhojUser:
    """Update and persist `user`'s first and last name, returning the same user."""
    user.first_name, user.last_name = first_name, last_name
    user.save()
    return user
def get_user_name(user: KhojUser):
    """Return the user's full name, else the given name from their Google profile, else None."""
    full_name = user.get_full_name()
    if is_none_or_empty(full_name):
        profile: GoogleUser = GoogleUser.objects.filter(user=user).first()
        return profile.given_name if profile else None
    return full_name
def get_user_photo(user: KhojUser):
    """Return the `picture` field of the user's Google profile, or None if they have no Google profile."""
    profile: GoogleUser = GoogleUser.objects.filter(user=user).first()
    return profile.picture if profile else None
def get_user_subscription(email: str) -> Optional[Subscription]:
    """Return the subscription of the user whose email is `email`, or None."""
    matching = Subscription.objects.filter(user__email=email)
    return matching.first()
async def set_user_subscription(
    email: str, is_recurring=None, renewal_date=None, type="standard"
) -> Optional[Subscription]:
    """Update the subscription of the user with `email`, creating the user if needed.

    Args:
        email: email of the user whose subscription is updated.
        is_recurring: new recurrence flag; left unchanged when None.
        renewal_date: new renewal date; left unchanged when None, cleared when False.
        type: subscription type to store.

    Returns:
        The saved Subscription.
    """
    user = await aget_or_create_user_by_email(email)
    subscription = await Subscription.objects.filter(user=user).afirst()

    subscription.type = type
    if is_recurring is not None:
        subscription.is_recurring = is_recurring
    if renewal_date is not None:
        # `False` is a sentinel meaning "clear the renewal date".
        subscription.renewal_date = None if renewal_date is False else renewal_date

    await subscription.asave()
    return subscription
def subscription_to_state(subscription: Subscription) -> str:
    """Map a subscription row to a `SubscriptionState` value string.

    - No subscription -> "invalid".
    - Trial: valid for 14 days after `created_at`, then "expired".
    - Recurring with a renewal date at or after now -> "subscribed".
    - Non-recurring: no renewal date, or a renewal date in the past -> "expired";
      a renewal date at or after now -> "unsubscribed".
    - Anything else (e.g. recurring with a past or missing renewal date) -> "invalid".
    """
    if not subscription:
        return SubscriptionState.INVALID.value

    # Take "now" once so every comparison below uses the same instant.
    now = datetime.now(tz=timezone.utc)

    if subscription.type == Subscription.Type.TRIAL:
        # Trial subscription is valid for 14 days
        if now - subscription.created_at > timedelta(days=14):
            return SubscriptionState.EXPIRED.value
        return SubscriptionState.TRIAL.value

    renewal_date = subscription.renewal_date
    if subscription.is_recurring:
        # A recurring subscription without a renewal date used to raise TypeError on the
        # comparison; treat it (like a lapsed recurring one) as invalid instead.
        if renewal_date is not None and renewal_date >= now:
            return SubscriptionState.SUBSCRIBED.value
        return SubscriptionState.INVALID.value

    if renewal_date is None or renewal_date < now:
        return SubscriptionState.EXPIRED.value
    return SubscriptionState.UNSUBSCRIBED.value
def get_user_subscription_state(email: str) -> str:
    """Return the subscription state string for the user with `email`.

    Valid state transitions: trial -> subscribed <-> unsubscribed OR expired
    """
    subscription = Subscription.objects.filter(user__email=email).first()
    state_value = subscription_to_state(subscription)
    return state_value
async def aget_user_subscription_state(user: KhojUser) -> str:
    """Async: return the subscription state string for `user`.

    Valid state transitions: trial -> subscribed <-> unsubscribed OR expired
    """
    subscription = await Subscription.objects.filter(user=user).afirst()
    state_value = subscription_to_state(subscription)
    return state_value
async def get_user_by_email(email: str) -> KhojUser:
    """Return the first user whose email equals `email`, or None."""
    matches = KhojUser.objects.filter(email=email)
    return await matches.afirst()
async def aget_user_by_uuid(uuid: str) -> KhojUser:
    """Return the user whose `uuid` field equals `uuid`, or None."""
    matches = KhojUser.objects.filter(uuid=uuid)
    return await matches.afirst()
async def get_user_by_token(token: dict) -> KhojUser:
    """Return the Khoj user linked to the Google account whose `sub` matches `token["sub"]`, or None."""
    google_account = await GoogleUser.objects.filter(sub=token.get("sub")).select_related("user").afirst()
    return google_account.user if google_account else None
async def aget_user_by_phone_number(phone_number: str) -> KhojUser:
    """Return the user owning `phone_number` if they may be identified by it, else None.

    A match is returned when the user has no email on file, or when their phone number
    is verified. Returns None for an empty `phone_number`.
    """
    if is_none_or_empty(phone_number):
        return None
    candidate = await KhojUser.objects.filter(phone_number=phone_number).prefetch_related("subscription").afirst()
    if candidate is None:
        return None
    # Phone-only accounts are trusted as-is; accounts with an email need a verified number.
    trusted = is_none_or_empty(candidate.email) or candidate.verified_phone_number
    return candidate if trusted else None
async def retrieve_user(session_id: str) -> KhojUser:
    """Load the user bound to the Django session `session_id`.

    Raises:
        HTTPException(401): if the session does not exist, or no user matches the
            session's `_auth_user_id`.
    """
    store = SessionStore(session_key=session_id)
    session_exists = await sync_to_async(store.exists)(session_key=session_id)
    if not session_exists:
        raise HTTPException(status_code=401, detail="Invalid session")

    payload = await sync_to_async(store.load)()
    user_id = payload.get("_auth_user_id")
    user = await KhojUser.objects.filter(id=user_id).afirst()
    if user is None:
        raise HTTPException(status_code=401, detail="Invalid user")
    return user
def get_all_users() -> BaseManager[KhojUser]:
    """Return a lazy queryset over every Khoj user."""
    everyone = KhojUser.objects.all()
    return everyone
def get_user_github_config(user: KhojUser):
    """Return `user`'s GithubConfig with its repo configs prefetched, or None."""
    return GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
def get_user_notion_config(user: KhojUser):
    """Return `user`'s NotionConfig, or None."""
    return NotionConfig.objects.filter(user=user).first()
def delete_user_requests(window: timedelta = timedelta(days=1)):
    """Delete UserRequests rows created at or before `window` ago; return Django's delete() result."""
    cutoff = datetime.now(tz=timezone.utc) - window
    return UserRequests.objects.filter(created_at__lte=cutoff).delete()
async def aget_user_name(user: KhojUser):
    """Async: return the user's full name, else their Google given name, else None."""
    full_name = user.get_full_name()
    if is_none_or_empty(full_name):
        profile: GoogleUser = await GoogleUser.objects.filter(user=user).afirst()
        return profile.given_name if profile else None
    return full_name
async def set_text_content_config(user: KhojUser, object: Type[models.Model], updated_config):
    """Replace `user`'s content config rows of model `object` with one built from `updated_config`.

    `input_files` and `input_filter` are de-duplicated (order not preserved); empty values are stored as None.
    """

    def _dedupe(values):
        return list(set(values)) if values else None

    # Compute the new values before deleting, so a bad `updated_config` fails before any rows are removed.
    files = _dedupe(updated_config.input_files)
    filters = _dedupe(updated_config.input_filter)

    await object.objects.filter(user=user).adelete()
    await object.objects.acreate(
        input_files=files,
        input_filter=filters,
        index_heading_entries=updated_config.index_heading_entries,
        user=user,
    )
async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
    """Store `user`'s GitHub personal access token and replace their repo list with `repos`.

    Each item of `repos` must provide "name", "owner" and "branch" keys.
    Returns the GithubConfig row.
    """
    github_config = await GithubConfig.objects.filter(user=user).afirst()
    if github_config is None:
        github_config = await GithubConfig.objects.acreate(pat_token=pat_token, user=user)
    else:
        github_config.pat_token = pat_token
        await github_config.asave()
        # Drop the previous repo list; it is rebuilt from `repos` below.
        await github_config.githubrepoconfig.all().adelete()

    for repo in repos:
        await GithubRepoConfig.objects.acreate(
            name=repo["name"], owner=repo["owner"], branch=repo["branch"], github_config=github_config
        )
    return github_config
def get_user_search_model_or_default(user=None):
    """Return `user`'s chosen search model, else the one named "default".

    If neither exists, a search model row is created and the first search model is returned.

    Each lookup fetches its row once instead of an `exists()` check followed by `first()`.
    This halves the queries and removes the window where the row could vanish between
    the two calls (which made `.first().setting` fail on None).
    """
    if user:
        user_config = UserSearchModelConfig.objects.filter(user=user).first()
        if user_config:
            return user_config.setting

    default_config = SearchModelConfig.objects.filter(name="default").first()
    if default_config:
        return default_config

    # NOTE(review): returns the first search model, not necessarily the row just created.
    SearchModelConfig.objects.create()
    return SearchModelConfig.objects.first()
def get_or_create_search_models():
    """Return a queryset of all search model configs, creating a default row first if there are none."""
    if not SearchModelConfig.objects.exists():
        SearchModelConfig.objects.create()
    return SearchModelConfig.objects.all()
async def aset_user_search_model(user: KhojUser, search_model_config_id: int):
    """Point `user` at the search model with id `search_model_config_id`.

    Returns the UserSearchModelConfig row, or None if no such search model exists.
    """
    chosen = await SearchModelConfig.objects.filter(id=search_model_config_id).afirst()
    if chosen is None:
        return None
    user_config, _created = await UserSearchModelConfig.objects.aupdate_or_create(
        user=user, defaults={"setting": chosen}
    )
    return user_config
async def aget_user_search_model(user: KhojUser):
    """Return the SearchModelConfig `user` selected, or None if they have not selected one."""
    user_config = await UserSearchModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
    return user_config.setting if user_config else None
class ProcessLockAdapters:
    """Helpers around named, time-limited process locks stored as ProcessLock rows."""

    @staticmethod
    def get_process_lock(process_name: str):
        """Return the ProcessLock named `process_name`, or None."""
        matching_locks = ProcessLock.objects.filter(name=process_name)
        return matching_locks.first()

    @staticmethod
    def set_process_lock(process_name: str, max_duration_in_seconds: int = 600):
        """Create and return a ProcessLock named `process_name` valid for `max_duration_in_seconds`.

        NOTE(review): `run_with_lock` treats IntegrityError from this call as "lock already
        taken" — presumably lock names are unique; confirm against the model.
        """
        return ProcessLock.objects.create(name=process_name, max_duration_in_seconds=max_duration_in_seconds)

    @staticmethod
    def is_process_locked(process_name: str):
        """Return True if a non-expired lock named `process_name` exists.

        A lock older than its `max_duration_in_seconds` is stale: it is deleted and False is returned.
        """
        lock = ProcessLock.objects.filter(name=process_name).first()
        if lock is None:
            return False
        expires_at = lock.started_at + timedelta(seconds=lock.max_duration_in_seconds)
        if expires_at >= datetime.now(tz=timezone.utc):
            return True
        lock.delete()
        logger.info(f"🔓 Deleted stale {process_name} process lock on timeout")
        return False

    @staticmethod
    def remove_process_lock(process_lock: ProcessLock):
        """Delete `process_lock` and return Django's delete() result."""
        return process_lock.delete()

    @staticmethod
    def run_with_lock(func: Callable, operation: ProcessLock.Operation, max_duration_in_seconds: int = 600, **kwargs):
        """Run `func(**kwargs)` while holding the `operation` process lock.

        Skips execution when the lock is already held. Errors raised while acquiring the lock
        or running `func` are logged, not propagated. An acquired lock is always released.
        """
        if ProcessLockAdapters.is_process_locked(operation):
            logger.debug(f"🔒 Skip executing {func} as {operation} lock is already taken")
            return

        acquired_lock = None
        succeeded = False
        try:
            acquired_lock = ProcessLockAdapters.set_process_lock(operation, max_duration_in_seconds)
            logger.info(f"🔐 Locked {operation} to execute {func}")
            with timer(f"🔒 Run {func} with {operation} process lock", logger):
                func(**kwargs)
            succeeded = True
        except IntegrityError as e:
            logger.debug(f"⚠️ Unable to create the process lock for {func} with {operation}: {e}")
        except Exception as e:
            logger.error(f"🚨 Error executing {func} with {operation} process lock: {e}", exc_info=True)
        finally:
            if not acquired_lock:
                logger.debug(f"Skip removing {operation} process lock as it was not set")
            else:
                ProcessLockAdapters.remove_process_lock(acquired_lock)
                logger.info(
                    f"🔓 Unlocked {operation} process after executing {func} {'Succeeded' if succeeded else 'Failed'}"
                )
def run_with_process_lock(*args, **kwargs):
    """Module-level wrapper around `ProcessLockAdapters.run_with_lock`, used for scheduling jobs.

    Required because APScheduler can't discover the `ProcessLockAdapters.run_with_lock` method on its own.
    """
    runner = ProcessLockAdapters.run_with_lock
    return runner(*args, **kwargs)
class ClientApplicationAdapters:
    """Lookups for registered client applications."""

    @staticmethod
    async def aget_client_application_by_id(client_id: str, client_secret: str):
        """Return the ClientApplication matching both `client_id` and `client_secret`, or None."""
        credentials_match = ClientApplication.objects.filter(client_id=client_id, client_secret=client_secret)
        return await credentials_match.afirst()
class AgentAdapters:
    """Queries and management helpers for chat agents, including the built-in default agent."""

    DEFAULT_AGENT_NAME = "Khoj"
    DEFAULT_AGENT_AVATAR = "https://assets.khoj.dev/lamp-128.png"
    DEFAULT_AGENT_SLUG = "khoj"

    @staticmethod
    async def aget_agent_by_slug(agent_slug: str, user: KhojUser):
        """Return the agent with slug `agent_slug` (case-insensitive) visible to `user`, or None.

        Visible means public, or created by `user`. When `user` is None only public agents are
        considered, matching `get_agent_by_slug`. Filtering on `creator=None` would be an
        IS NULL lookup that also matches non-public agents without a creator.
        """
        slug_filter = Q(slug__iexact=agent_slug.lower())
        if user:
            return await Agent.objects.filter(slug_filter & (Q(public=True) | Q(creator=user))).afirst()
        return await Agent.objects.filter(slug_filter, public=True).afirst()

    @staticmethod
    def get_agent_by_slug(slug: str, user: KhojUser = None):
        """Return the agent with `slug` (case-insensitive) that is public or created by `user`, or None."""
        if user:
            return Agent.objects.filter((Q(slug__iexact=slug.lower())) & (Q(public=True) | Q(creator=user))).first()
        return Agent.objects.filter(slug__iexact=slug.lower(), public=True).first()

    @staticmethod
    def get_all_accessible_agents(user: KhojUser = None):
        """Return a queryset of agents visible to `user` (public or self-created), oldest first."""
        if user:
            return Agent.objects.filter(Q(public=True) | Q(creator=user)).distinct().order_by("created_at")
        return Agent.objects.filter(public=True).order_by("created_at")

    @staticmethod
    async def aget_all_accessible_agents(user: KhojUser = None) -> List[Agent]:
        """Async: return the agents visible to `user` as a list, oldest first."""
        agents = await sync_to_async(AgentAdapters.get_all_accessible_agents)(user)
        return await sync_to_async(list)(agents)

    @staticmethod
    def get_conversation_agent_by_id(agent_id: int):
        """Return the agent with id `agent_id`, or None if it is missing or is the default agent."""
        agent = Agent.objects.filter(id=agent_id).first()
        if agent == AgentAdapters.get_default_agent():
            # If the agent is set to the default agent, then return None and let the default application code be used
            return None
        return agent

    @staticmethod
    def get_default_agent():
        """Return the agent named DEFAULT_AGENT_NAME, or None if it has not been created."""
        return Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).first()

    @staticmethod
    def create_default_agent():
        """Create or refresh the default agent from the default chat model and personality prompt.

        Returns the default agent, or None when no default conversation config exists.
        When the agent is first created, conversations without an agent are assigned to it.
        """
        default_conversation_config = ConversationAdapters.get_default_conversation_config()
        if default_conversation_config is None:
            logger.info("No default conversation config found, skipping default agent creation")
            return None
        default_personality = prompts.personality.format(current_date="placeholder", day_of_week="placeholder")

        agent = Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).first()

        if agent:
            agent.personality = default_personality
            agent.chat_model = default_conversation_config
            agent.slug = AgentAdapters.DEFAULT_AGENT_SLUG
            agent.name = AgentAdapters.DEFAULT_AGENT_NAME
            agent.save()
        else:
            # The default agent is public and managed by the admin. It's handled a little differently than other agents.
            agent = Agent.objects.create(
                name=AgentAdapters.DEFAULT_AGENT_NAME,
                public=True,
                managed_by_admin=True,
                chat_model=default_conversation_config,
                personality=default_personality,
                tools=["*"],
                avatar=AgentAdapters.DEFAULT_AGENT_AVATAR,
                slug=AgentAdapters.DEFAULT_AGENT_SLUG,
            )
            Conversation.objects.filter(agent=None).update(agent=agent)

        return agent

    @staticmethod
    async def aget_default_agent():
        """Async: return the agent named DEFAULT_AGENT_NAME, or None."""
        return await Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).afirst()
class PublicConversationAdapters:
    """Lookups and URL helpers for publicly shared conversations."""

    @staticmethod
    def get_public_conversation_by_slug(slug: str):
        """Return the PublicConversation with `slug`, or None."""
        matches = PublicConversation.objects.filter(slug=slug)
        return matches.first()

    @staticmethod
    def get_public_conversation_url(public_conversation: PublicConversation):
        """Return the relative share URL for `public_conversation`.

        Public conversations are viewable by anyone, but not editable.
        """
        slug = public_conversation.slug
        return f"/share/chat/{slug}/"
class DataStoreAdapters:
    """Key/value storage of user-owned data in DataStore rows."""

    @staticmethod
    async def astore_data(data: dict, key: str, user: KhojUser, private: bool = True):
        """Store `data` under `key` unless that key already exists; always return `key`.

        NOTE(review): an existing entry is left untouched regardless of its owner or `private` flag.
        """
        key_taken = await DataStore.objects.filter(key=key).aexists()
        if not key_taken:
            await DataStore.objects.acreate(value=data, key=key, owner=user, private=private)
        return key

    @staticmethod
    async def aretrieve_public_data(key: str):
        """Return the non-private DataStore entry stored under `key`, or None."""
        public_entries = DataStore.objects.filter(key=key, private=False)
        return await public_entries.afirst()
class ConversationAdapters:
    """Queries and mutations for conversations and the chat, voice and image models they use."""

    @staticmethod
    def make_public_conversation_copy(conversation: Conversation):
        """Create and return a PublicConversation snapshot of `conversation`."""
        return PublicConversation.objects.create(
            source_owner=conversation.user,
            agent=conversation.agent,
            conversation_log=conversation.conversation_log,
            slug=conversation.slug,
            title=conversation.title,
        )

    @staticmethod
    def get_conversation_by_user(
        user: KhojUser, client_application: ClientApplication = None, conversation_id: int = None
    ) -> Optional[Conversation]:
        """Return a conversation of `user` for `client_application`.

        With `conversation_id`, return that conversation (or None if the user/client do not own it).
        Without it, return the most recently updated conversation, creating one with the
        default agent if the user has none.
        """
        if conversation_id:
            return (
                Conversation.objects.filter(user=user, client=client_application, id=conversation_id)
                .order_by("-updated_at")
                .first()
            )
        conversation = Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").first()
        if conversation:
            return conversation
        # Only look up the default agent when a new conversation is actually needed.
        agent = AgentAdapters.get_default_agent()
        return Conversation.objects.create(user=user, client=client_application, agent=agent)

    @staticmethod
    def get_conversation_sessions(user: KhojUser, client_application: ClientApplication = None):
        """Return a queryset of `user`'s conversations for the client, most recently updated first."""
        return (
            Conversation.objects.filter(user=user, client=client_application)
            .prefetch_related("agent")
            .order_by("-updated_at")
        )

    @staticmethod
    async def aset_conversation_title(
        user: KhojUser, client_application: ClientApplication, conversation_id: int, title: str
    ):
        """Set the title of `user`'s conversation `conversation_id`; return it, or None if not found."""
        conversation = await Conversation.objects.filter(
            user=user, client=client_application, id=conversation_id
        ).afirst()
        if conversation:
            conversation.title = title
            await conversation.asave()
            return conversation
        return None

    @staticmethod
    def get_conversation_by_id(conversation_id: int):
        """Return the conversation with `conversation_id` regardless of owner, or None."""
        return Conversation.objects.filter(id=conversation_id).first()

    @staticmethod
    async def acreate_conversation_session(
        user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None
    ):
        """Create a new conversation for `user`, with the agent named by `agent_slug` or the default agent.

        Raises:
            HTTPException(400): if `agent_slug` is given but no agent visible to `user` has that slug.
        """
        if agent_slug:
            agent = await AgentAdapters.aget_agent_by_slug(agent_slug, user)
            if agent is None:
                raise HTTPException(status_code=400, detail="No such agent currently exists.")
        else:
            agent = await AgentAdapters.aget_default_agent()
        return await Conversation.objects.acreate(user=user, client=client_application, agent=agent)

    @staticmethod
    async def aget_conversation_by_user(
        user: KhojUser, client_application: ClientApplication = None, conversation_id: int = None, title: str = None
    ) -> Optional[Conversation]:
        """Return a conversation of `user` for `client_application`.

        Looks up by `conversation_id`, else by `title` (both return None if not found).
        Otherwise returns the most recently updated conversation, creating one if none exist.
        """
        query = Conversation.objects.filter(user=user, client=client_application).prefetch_related("agent")
        if conversation_id:
            return await query.filter(id=conversation_id).afirst()
        elif title:
            return await query.filter(title=title).afirst()
        conversation = await query.order_by("-updated_at").afirst()
        if conversation:
            return conversation
        # Match get_conversation_by_user: newly created conversations get the default agent.
        agent = await AgentAdapters.aget_default_agent()
        return await Conversation.objects.acreate(user=user, client=client_application, agent=agent)

    @staticmethod
    async def adelete_conversation_by_user(
        user: KhojUser, client_application: ClientApplication = None, conversation_id: int = None
    ):
        """Delete `user`'s conversation `conversation_id`, or all their conversations for the client if no id is given."""
        if conversation_id:
            return await Conversation.objects.filter(user=user, client=client_application, id=conversation_id).adelete()
        return await Conversation.objects.filter(user=user, client=client_application).adelete()

    @staticmethod
    def has_any_conversation_config(user: KhojUser):
        """Return True if any ChatModelOptions row is associated with `user`."""
        return ChatModelOptions.objects.filter(user=user).exists()

    @staticmethod
    def get_openai_conversation_config():
        """Return the first OpenAIProcessorConversationConfig, or None."""
        return OpenAIProcessorConversationConfig.objects.filter().first()

    @staticmethod
    def has_valid_openai_conversation_config():
        """Return True if any OpenAIProcessorConversationConfig exists."""
        return OpenAIProcessorConversationConfig.objects.filter().exists()

    @staticmethod
    async def aset_user_conversation_processor(user: KhojUser, conversation_processor_config_id: int):
        """Set `user`'s chat model to ChatModelOptions `conversation_processor_config_id`.

        Returns the UserConversationConfig row (not the `(obj, created)` tuple), consistent with
        the other `aset_user_*` setters, or None if no such chat model exists.
        """
        config = await ChatModelOptions.objects.filter(id=conversation_processor_config_id).afirst()
        if not config:
            return None
        new_config, _ = await UserConversationConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
        return new_config

    @staticmethod
    async def aset_user_voice_model(user: KhojUser, model_id: str):
        """Set `user`'s voice model to the VoiceModelOption with `model_id`.

        Returns the UserVoiceModelConfig row, or None if no such voice model exists.
        """
        config = await VoiceModelOption.objects.filter(model_id=model_id).afirst()
        if not config:
            return None
        new_config, _ = await UserVoiceModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
        return new_config

    @staticmethod
    def get_conversation_config(user: KhojUser):
        """Return the chat model `user` selected, or None."""
        config = UserConversationConfig.objects.filter(user=user).first()
        if not config:
            return None
        return config.setting

    @staticmethod
    async def aget_conversation_config(user: KhojUser):
        """Async: return the chat model `user` selected, or None."""
        config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst()
        if not config:
            return None
        return config.setting

    @staticmethod
    async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]:
        """Async: return `user`'s voice model, else the first available one (or None)."""
        voice_model_config = await UserVoiceModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
        if voice_model_config:
            return voice_model_config.setting
        return await VoiceModelOption.objects.afirst()

    @staticmethod
    def get_voice_model_options():
        """Return a queryset of all voice model options."""
        return VoiceModelOption.objects.all()

    @staticmethod
    def get_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]:
        """Return `user`'s voice model, else the first available one (or None)."""
        voice_model_config = UserVoiceModelConfig.objects.filter(user=user).prefetch_related("setting").first()
        if voice_model_config:
            return voice_model_config.setting
        return VoiceModelOption.objects.first()

    @staticmethod
    def get_default_conversation_config():
        """Return the server's default chat model, else the first ChatModelOptions row (or None)."""
        server_chat_settings = ServerChatSettings.objects.first()
        if server_chat_settings is None or server_chat_settings.default_model is None:
            return ChatModelOptions.objects.filter().first()
        return server_chat_settings.default_model

    @staticmethod
    async def aget_default_conversation_config():
        """Async: return the server's default chat model, else the first ChatModelOptions row (or None)."""
        server_chat_settings: ServerChatSettings = (
            await ServerChatSettings.objects.filter()
            .prefetch_related("default_model", "default_model__openai_config")
            .afirst()
        )
        if server_chat_settings is None or server_chat_settings.default_model is None:
            return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst()
        return server_chat_settings.default_model

    @staticmethod
    async def aget_summarizer_conversation_config():
        """Async: return the server's summarizer model, else its default model, else the first ChatModelOptions row."""
        server_chat_settings: ServerChatSettings = (
            await ServerChatSettings.objects.filter()
            .prefetch_related(
                "summarizer_model", "default_model", "default_model__openai_config", "summarizer_model__openai_config"
            )
            .afirst()
        )
        if server_chat_settings is None or (
            server_chat_settings.summarizer_model is None and server_chat_settings.default_model is None
        ):
            return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst()
        return server_chat_settings.summarizer_model or server_chat_settings.default_model

    @staticmethod
    def create_conversation_from_public_conversation(
        user: KhojUser, public_conversation: PublicConversation, client_app: ClientApplication
    ):
        """Create a private copy of `public_conversation` owned by `user`."""
        scrubbed_title = public_conversation.title if public_conversation.title else public_conversation.slug
        if scrubbed_title:
            scrubbed_title = scrubbed_title.replace("-", " ")
        return Conversation.objects.create(
            user=user,
            conversation_log=public_conversation.conversation_log,
            client=client_app,
            slug=scrubbed_title,
            title=public_conversation.title,
            agent=public_conversation.agent,
        )

    @staticmethod
    def save_conversation(
        user: KhojUser,
        conversation_log: dict,
        client_application: ClientApplication = None,
        conversation_id: int = None,
        user_message: str = None,
    ):
        """Persist `conversation_log` to the given (or most recent) conversation, creating one if needed.

        The conversation slug is set to the first 200 characters of the stripped `user_message`.
        """
        slug = user_message.strip()[:200] if user_message else None
        if conversation_id:
            conversation = Conversation.objects.filter(user=user, client=client_application, id=conversation_id).first()
        else:
            conversation = (
                Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").first()
            )

        if conversation:
            conversation.conversation_log = conversation_log
            conversation.slug = slug
            conversation.updated_at = datetime.now(tz=timezone.utc)
            conversation.save()
        else:
            Conversation.objects.create(
                user=user, conversation_log=conversation_log, client=client_application, slug=slug
            )

    @staticmethod
    def get_conversation_processor_options():
        """Return a queryset of all chat model options."""
        return ChatModelOptions.objects.all()

    @staticmethod
    def set_conversation_processor_config(user: KhojUser, new_config: ChatModelOptions):
        """Set `user`'s chat model to `new_config`, creating their UserConversationConfig if needed."""
        user_conversation_config, _ = UserConversationConfig.objects.get_or_create(user=user)
        user_conversation_config.setting = new_config
        user_conversation_config.save()

    @staticmethod
    async def aget_user_conversation_config(user: KhojUser):
        """Async: return `user`'s chat model with its openai_config prefetched, or None."""
        config = (
            await UserConversationConfig.objects.filter(user=user).prefetch_related("setting__openai_config").afirst()
        )
        if not config:
            return None
        return config.setting

    @staticmethod
    async def get_speech_to_text_config():
        """Async: return the first SpeechToTextModelOptions row, or None."""
        return await SpeechToTextModelOptions.objects.filter().afirst()

    @staticmethod
    async def aget_conversation_starters(user: KhojUser, max_results=3):
        """Return up to `max_results` randomly sampled reflective questions.

        Questions personalized for `user` are used when any exist; otherwise the shared
        questions (rows with no user) are used.
        """
        # Prefer the user's own questions; fall back to the shared set only when they have none.
        has_personal_questions = await ReflectiveQuestion.objects.filter(user=user).aexists()
        owner = user if has_personal_questions else None
        # Building the queryset is lazy (no DB access); it is evaluated inside sync_to_async.
        questions_query = ReflectiveQuestion.objects.filter(user=owner).values_list("question", flat=True)
        all_questions = await sync_to_async(list)(questions_query)
        if len(all_questions) < max_results:
            return all_questions
        return random.sample(all_questions, max_results)

    @staticmethod
    def get_valid_conversation_config(user: KhojUser, conversation: Conversation):
        """Resolve the chat model for `conversation`: agent's model, else user's, else the server default.

        Loads the offline chat model into `state` on first use of an offline model.

        Raises:
            ValueError: if no chat model is configured, or the resolved model is neither a usable
                offline model nor an openai/anthropic model with an openai_config.
        """
        agent: Agent = conversation.agent if AgentAdapters.get_default_agent() != conversation.agent else None
        if agent and agent.chat_model:
            conversation_config = agent.chat_model
        else:
            conversation_config = ConversationAdapters.get_conversation_config(user)

        if conversation_config is None:
            conversation_config = ConversationAdapters.get_default_conversation_config()

        # With no chat model configured anywhere, fail with a clear ValueError instead of an AttributeError below.
        if conversation_config is None:
            raise ValueError("Invalid conversation config - either configure offline chat or openai chat")

        if conversation_config.model_type == "offline":
            if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
                chat_model = conversation_config.chat_model
                max_tokens = conversation_config.max_prompt_size
                state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
            return conversation_config

        if (
            conversation_config.model_type == "openai" or conversation_config.model_type == "anthropic"
        ) and conversation_config.openai_config:
            return conversation_config

        raise ValueError("Invalid conversation config - either configure offline chat or openai chat")

    @staticmethod
    async def aget_text_to_image_model_config():
        """Async: return the first text-to-image model (openai_config prefetched), or None."""
        return await TextToImageModelConfig.objects.filter().prefetch_related("openai_config").afirst()

    @staticmethod
    def get_text_to_image_model_config():
        """Return the first text-to-image model, or None."""
        return TextToImageModelConfig.objects.filter().first()

    @staticmethod
    def get_text_to_image_model_options():
        """Return a queryset of all text-to-image models."""
        return TextToImageModelConfig.objects.all()

    @staticmethod
    def get_user_text_to_image_model_config(user: KhojUser):
        """Return `user`'s text-to-image model, else the first available one (or None)."""
        config = UserTextToImageModelConfig.objects.filter(user=user).first()
        if config:
            return config.setting
        return ConversationAdapters.get_text_to_image_model_config()

    @staticmethod
    async def aget_user_text_to_image_model(user: KhojUser) -> Optional[TextToImageModelConfig]:
        """Async: return `user`'s text-to-image model, else the first available one (or None)."""
        config = await UserTextToImageModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
        if config:
            return config.setting
        return await ConversationAdapters.aget_text_to_image_model_config()

    @staticmethod
    async def aset_user_text_to_image_model(user: KhojUser, text_to_image_model_config_id: int):
        """Set `user`'s text-to-image model; return the UserTextToImageModelConfig, or None if the model doesn't exist."""
        config = await TextToImageModelConfig.objects.filter(id=text_to_image_model_config_id).afirst()
        if not config:
            return None
        new_config, _ = await UserTextToImageModelConfig.objects.aupdate_or_create(
            user=user, defaults={"setting": config}
        )
        return new_config

    @staticmethod
    def add_files_to_filter(user: KhojUser, conversation_id: int, files: List[str]):
        """Add `files` to the conversation's file filters and return the updated filter list.

        Only files the user has indexed from the "computer" source are accepted. Existing
        filters for files no longer indexed are dropped. The conversation is saved once.
        """
        conversation = ConversationAdapters.get_conversation_by_user(user, conversation_id=conversation_id)
        # A set makes the repeated membership checks O(1).
        indexed_files = set(EntryAdapters.get_all_filenames_by_source(user, "computer"))
        for filename in files:
            if filename in indexed_files and filename not in conversation.file_filters:
                conversation.file_filters.append(filename)
        # remove files from conversation.file_filters that are not in file_list
        conversation.file_filters = [file for file in conversation.file_filters if file in indexed_files]
        conversation.save()
        return conversation.file_filters

    @staticmethod
    def remove_files_from_filter(user: KhojUser, conversation_id: int, files: List[str]):
        """Remove `files` from the conversation's file filters and return the updated filter list.

        Filters for files no longer indexed from the "computer" source are also dropped.
        The conversation is saved once.
        """
        conversation = ConversationAdapters.get_conversation_by_user(user, conversation_id=conversation_id)
        for filename in files:
            if filename in conversation.file_filters:
                conversation.file_filters.remove(filename)
        # remove files from conversation.file_filters that are not in file_list
        indexed_files = set(EntryAdapters.get_all_filenames_by_source(user, "computer"))
        conversation.file_filters = [file for file in conversation.file_filters if file in indexed_files]
        conversation.save()
        return conversation.file_filters
class FileObjectAdapters:
    """CRUD helpers for the ``FileObject`` rows (raw file text) owned by a user.

    Every operation has a synchronous form and an ``async_``-prefixed form.
    """

    @staticmethod
    def _owned(user: KhojUser, **lookups):
        """Return ``user``'s file objects, narrowed by any extra field ``lookups``."""
        return FileObject.objects.filter(user=user, **lookups)

    @staticmethod
    def update_raw_text(file_object: FileObject, new_raw_text: str):
        """Replace ``file_object``'s raw text and persist the change."""
        file_object.raw_text = new_raw_text
        file_object.save()

    @staticmethod
    def create_file_object(user: KhojUser, file_name: str, raw_text: str):
        """Create and return a new file object for ``user``."""
        return FileObject.objects.create(user=user, file_name=file_name, raw_text=raw_text)

    @staticmethod
    def get_file_object_by_name(user: KhojUser, file_name: str):
        """Return the first of ``user``'s file objects named ``file_name``, or ``None``."""
        return FileObjectAdapters._owned(user, file_name=file_name).first()

    @staticmethod
    def get_all_file_objects(user: KhojUser):
        """Return a queryset of all of ``user``'s file objects."""
        return FileObjectAdapters._owned(user)

    @staticmethod
    def delete_file_object_by_name(user: KhojUser, file_name: str):
        """Delete ``user``'s file objects named ``file_name``; return Django's delete() result."""
        return FileObjectAdapters._owned(user, file_name=file_name).delete()

    @staticmethod
    def delete_all_file_objects(user: KhojUser):
        """Delete every file object of ``user``; return Django's delete() result."""
        return FileObjectAdapters._owned(user).delete()

    @staticmethod
    async def async_update_raw_text(file_object: FileObject, new_raw_text: str):
        """Async: replace ``file_object``'s raw text and persist the change."""
        file_object.raw_text = new_raw_text
        await file_object.asave()

    @staticmethod
    async def async_create_file_object(user: KhojUser, file_name: str, raw_text: str):
        """Async: create and return a new file object for ``user``."""
        return await FileObject.objects.acreate(user=user, file_name=file_name, raw_text=raw_text)

    @staticmethod
    async def async_get_file_objects_by_name(user: KhojUser, file_name: str):
        """Async: return a list of ``user``'s file objects named ``file_name``."""
        matching = FileObjectAdapters._owned(user, file_name=file_name)
        # Evaluate the queryset in a worker thread so the event loop is not blocked
        return await sync_to_async(list)(matching)

    @staticmethod
    async def async_get_all_file_objects(user: KhojUser):
        """Async: return a list of all of ``user``'s file objects."""
        everything = FileObjectAdapters._owned(user)
        return await sync_to_async(list)(everything)

    @staticmethod
    async def async_delete_file_object_by_name(user: KhojUser, file_name: str):
        """Async: delete ``user``'s file objects named ``file_name``; return adelete()'s result."""
        return await FileObjectAdapters._owned(user, file_name=file_name).adelete()

    @staticmethod
    async def async_delete_all_file_objects(user: KhojUser):
        """Async: delete every file object of ``user``; return adelete()'s result."""
        return await FileObjectAdapters._owned(user).adelete()
class EntryAdapters:
    """Database helpers for a user's indexed ``Entry`` rows.

    Covers lookups, filtered and embedding-based search, and batched deletion.
    """

    # Parsers that extract filter terms from a raw search query.
    # NOTE: ``word_filer`` is misspelled but kept unchanged because other modules may reference it.
    word_filer = WordFilter()
    file_filter = FileFilter()
    date_filter = DateFilter()

    @staticmethod
    def does_entry_exist(user: KhojUser, hashed_value: str) -> bool:
        """Return True if ``user`` has an entry whose ``hashed_value`` matches."""
        return Entry.objects.filter(user=user, hashed_value=hashed_value).exists()

    @staticmethod
    def delete_entry_by_file(user: KhojUser, file_path: str):
        """Delete ``user``'s entries for ``file_path``.

        Returns the total count reported by Django's ``delete()``, which includes
        rows removed by cascade.
        """
        deleted_count, _ = Entry.objects.filter(user=user, file_path=file_path).delete()
        return deleted_count

    @staticmethod
    def get_filtered_entries(user: KhojUser, file_type: Optional[str] = None, file_source: Optional[str] = None):
        """Return ``user``'s entries, optionally narrowed by file type and/or file source."""
        queryset = Entry.objects.filter(user=user)
        if file_type is not None:
            queryset = queryset.filter(file_type=file_type)
        if file_source is not None:
            queryset = queryset.filter(file_source=file_source)
        return queryset

    @staticmethod
    def delete_all_entries(
        user: KhojUser, file_type: Optional[str] = None, file_source: Optional[str] = None, batch_size=1000
    ):
        """Delete matching entries, ``batch_size`` ids at a time; return the total deleted count.

        Each DELETE statement covers at most ``batch_size`` entry ids.
        """
        deleted_count = 0
        queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source)
        # The queryset is never evaluated, so exists() re-queries on each pass and
        # the loop ends once no matching entries remain
        while queryset.exists():
            batch_ids = list(queryset.values_list("id", flat=True)[:batch_size])
            batch = Entry.objects.filter(id__in=batch_ids, user=user)
            count, _ = batch.delete()
            deleted_count += count
        return deleted_count

    @staticmethod
    async def adelete_all_entries(
        user: KhojUser, file_type: Optional[str] = None, file_source: Optional[str] = None, batch_size=1000
    ):
        """Async version of ``delete_all_entries``; returns the total deleted count."""
        deleted_count = 0
        queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source)
        while await queryset.aexists():
            # Evaluate the id slice in a worker thread to avoid blocking the event loop
            batch_ids = await sync_to_async(list)(queryset.values_list("id", flat=True)[:batch_size])
            batch = Entry.objects.filter(id__in=batch_ids, user=user)
            count, _ = await batch.adelete()
            deleted_count += count
        return deleted_count

    @staticmethod
    def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str):
        """Return a lazy ``values_list`` of the hashes of ``user``'s entries for ``file_path``."""
        return Entry.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True)

    @staticmethod
    def delete_entry_by_hash(user: KhojUser, hashed_values: List[str]):
        """Delete ``user``'s entries whose ``hashed_value`` is in ``hashed_values``."""
        Entry.objects.filter(user=user, hashed_value__in=hashed_values).delete()

    @staticmethod
    def get_entries_by_date_filter(entry: BaseManager[Entry], start_date: date, end_date: date):
        """Narrow ``entry`` to entries having a date within [start_date, end_date], inclusive.

        Uses the ``embeddings_dates`` reverse relation, the same one ``apply_filters`` queries.
        """
        return entry.filter(
            embeddings_dates__date__gte=start_date,
            embeddings_dates__date__lte=end_date,
        )

    @staticmethod
    def user_has_entries(user: KhojUser):
        """Return True if ``user`` has at least one indexed entry."""
        return Entry.objects.filter(user=user).exists()

    @staticmethod
    async def auser_has_entries(user: KhojUser):
        """Async: return True if ``user`` has at least one indexed entry."""
        return await Entry.objects.filter(user=user).aexists()

    @staticmethod
    async def adelete_entry_by_file(user: KhojUser, file_path: str):
        """Async: delete ``user``'s entries for ``file_path``; return ``adelete()``'s (count, per-model) tuple."""
        return await Entry.objects.filter(user=user, file_path=file_path).adelete()

    @staticmethod
    async def adelete_entries_by_filenames(user: KhojUser, filenames: List[str], batch_size=1000):
        """Async: delete ``user``'s entries for ``filenames``, ``batch_size`` paths per query.

        Returns the total deleted count.
        """
        deleted_count = 0
        for i in range(0, len(filenames), batch_size):
            batch = filenames[i : i + batch_size]
            count, _ = await Entry.objects.filter(user=user, file_path__in=batch).adelete()
            deleted_count += count
        return deleted_count

    @staticmethod
    def get_all_filenames_by_source(user: KhojUser, file_source: str):
        """Return the distinct file paths of ``user``'s entries from ``file_source``."""
        # distinct() with a field name is PostgreSQL-only (DISTINCT ON)
        return (
            Entry.objects.filter(user=user, file_source=file_source)
            .distinct("file_path")
            .values_list("file_path", flat=True)
        )

    @staticmethod
    def get_size_of_indexed_data_in_mb(user: KhojUser):
        """Approximate the size, in MiB, of ``user``'s compiled entry text.

        Sums ``sys.getsizeof`` of the Python string objects, so this is an
        in-memory estimate rather than bytes stored in the database.
        """
        # Fetch only the ``compiled`` column. Loading whole Entry rows would also
        # pull every ``embeddings`` vector just to discard it.
        compiled_texts = Entry.objects.filter(user=user).values_list("compiled", flat=True).iterator()
        total_size = sum(sys.getsizeof(compiled) for compiled in compiled_texts)
        return total_size / 1024 / 1024

    @staticmethod
    def apply_filters(user: KhojUser, query: str, file_type_filter: Optional[str] = None):
        """Return ``user``'s entries that satisfy the filters embedded in ``query``.

        Supported filters:
        - explicit word terms: ``+word`` must appear in / ``-word`` must not appear in ``raw``
        - file filters: ``file_path`` must match at least one of the extracted regexes
        - date range: entry must have a date between the extracted min/max dates

        ``file_type_filter``, when given, is applied whether or not the query has filters.
        """
        explicit_word_terms = EntryAdapters.word_filer.get_filter_terms(query)
        file_filters = EntryAdapters.file_filter.get_filter_terms(query)
        date_filters = EntryAdapters.date_filter.get_query_date_range(query)

        user_entries = Entry.objects.filter(user=user)
        # Apply the file type restriction up front so the early return below honours it too
        if file_type_filter:
            user_entries = user_entries.filter(file_type=file_type_filter)

        if len(explicit_word_terms) == 0 and len(file_filters) == 0 and len(date_filters) == 0:
            return user_entries

        q_filter_terms = Q()
        for term in explicit_word_terms:
            if term.startswith("+"):
                q_filter_terms &= Q(raw__icontains=term[1:])
            elif term.startswith("-"):
                q_filter_terms &= ~Q(raw__icontains=term[1:])

        if len(file_filters) > 0:
            # Any one matching file pattern is enough (OR), combined with the other filters (AND)
            q_file_filter_terms = Q()
            for term in file_filters:
                q_file_filter_terms |= Q(file_path__regex=term)
            q_filter_terms &= q_file_filter_terms

        if len(date_filters) > 0:
            # Unpacked as a (min, max) pair of POSIX timestamps; either bound may be None
            min_date, max_date = date_filters
            if min_date is not None:
                # date.fromtimestamp interprets the timestamp in the server's local time
                formatted_min_date = date.fromtimestamp(min_date).strftime("%Y-%m-%d")
                q_filter_terms &= Q(embeddings_dates__date__gte=formatted_min_date)
            if max_date is not None:
                formatted_max_date = date.fromtimestamp(max_date).strftime("%Y-%m-%d")
                q_filter_terms &= Q(embeddings_dates__date__lte=formatted_max_date)

        return user_entries.filter(q_filter_terms)

    @staticmethod
    def search_with_embeddings(
        user: KhojUser,
        embeddings: Tensor,
        max_results: int = 10,
        file_type_filter: Optional[str] = None,
        raw_query: Optional[str] = None,
        max_distance: float = math.inf,
    ):
        """Return up to ``max_results`` of ``user``'s entries nearest to ``embeddings``.

        Entries are first narrowed by the filters in ``raw_query`` and by
        ``file_type_filter``, then ranked by ascending cosine distance. Entries
        farther than ``max_distance`` are excluded. Each result is annotated with
        ``distance``.
        """
        # apply_filters already scopes to ``user`` and applies ``file_type_filter``
        relevant_entries = EntryAdapters.apply_filters(user, raw_query, file_type_filter)
        relevant_entries = relevant_entries.annotate(distance=CosineDistance("embeddings", embeddings))
        relevant_entries = relevant_entries.filter(distance__lte=max_distance)
        relevant_entries = relevant_entries.order_by("distance")
        return relevant_entries[:max_results]

    @staticmethod
    def get_unique_file_types(user: KhojUser):
        """Return the distinct ``file_type`` values among ``user``'s entries."""
        return Entry.objects.filter(user=user).values_list("file_type", flat=True).distinct()

    @staticmethod
    def get_unique_file_sources(user: KhojUser):
        """Return the distinct ``file_source`` values among ``user``'s entries."""
        return Entry.objects.filter(user=user).values_list("file_source", flat=True).distinct().all()
class AutomationAdapters:
@staticmethod
def get_automations(user: KhojUser) -> Iterable[Job]:
    """Lazily yield every scheduled job that belongs to ``user``.

    Ownership is encoded in the job id, which begins with ``automation_<user.uuid>_``.
    """
    owner_prefix = f"automation_{user.uuid}_"
    yield from (job for job in state.scheduler.get_jobs() if job.id.startswith(owner_prefix))
@staticmethod
def get_automation_metadata(user: KhojUser, automation: Job):
    """Describe one of ``user``'s automations as a dict for display.

    The automation's settings are read from the JSON stored in the job's name.
    When the job has no next run time (APScheduler uses ``None`` for paused
    jobs), ``schedule`` omits the timezone and ``next`` is ``None``.

    Raises:
        ValueError: if ``automation`` does not belong to ``user``.
    """
    # Only describe automations owned by this user (ownership is encoded in the id prefix)
    if not automation.id.startswith(f"automation_{user.uuid}_"):
        raise ValueError("Invalid automation id")

    automation_metadata = json.loads(automation.name)
    crontime = automation_metadata["crontime"]
    description = cron_descriptor.get_description(crontime)

    next_run_time = automation.next_run_time
    if next_run_time is not None:
        # Named tz_name so it does not shadow the module-level datetime.timezone import
        tz_name = next_run_time.strftime("%Z")
        schedule = f"{description} {tz_name}"
        next_run = next_run_time.strftime("%Y-%m-%d %I:%M %p %Z")
    else:
        schedule = description
        next_run = None

    return {
        "id": automation.id,
        "subject": automation_metadata["subject"],
        # Hide the internal /automated_task command prefix from the displayed query
        "query_to_run": re.sub(r"^/automated_task\s*", "", automation_metadata["query_to_run"]),
        "scheduling_request": automation_metadata["scheduling_request"],
        "schedule": schedule,
        "crontime": crontime,
        "next": next_run,
    }
@staticmethod
def get_job_last_run(user: KhojUser, automation: Job):
    """Return when ``user``'s automation last ran successfully, formatted for display.

    Returns ``None`` if no execution with status "Executed" is recorded.

    Raises:
        ValueError: if ``automation`` does not belong to ``user``.
    """
    # Only report on automations owned by this user (ownership is encoded in the id prefix)
    if not automation.id.startswith(f"automation_{user.uuid}_"):
        raise ValueError("Invalid automation id")

    # Filter executions by the job's primary key and fetch only the newest run_time.
    # This is one query instead of three (DjangoJob lookup, exists(), latest()).
    last_run_time = (
        DjangoJobExecution.objects.filter(job_id=automation.id, status="Executed")
        .order_by("-run_time")
        .values_list("run_time", flat=True)
        .first()
    )
    return last_run_time.strftime("%Y-%m-%d %I:%M %p %Z") if last_run_time else None
@staticmethod
def get_automations_metadata(user: KhojUser):
    """Lazily yield the display metadata of each automation owned by ``user``."""
    yield from (
        AutomationAdapters.get_automation_metadata(user, job) for job in AutomationAdapters.get_automations(user)
    )
@staticmethod
def get_automation(user: KhojUser, automation_id: str) -> Job:
    """Fetch the scheduled job ``automation_id`` after verifying that ``user`` owns it.

    Raises:
        ValueError: if the id is not owned by ``user`` or no job is scheduled under it.
    """
    # Ownership is encoded in the id prefix; reject ids belonging to anyone else
    if not automation_id.startswith(f"automation_{user.uuid}_"):
        raise ValueError("Invalid automation id")

    job: Job = state.scheduler.get_job(job_id=automation_id)
    if job:
        return job
    # Id is well-formed for this user but nothing is scheduled under it
    raise ValueError("Invalid automation id")
@staticmethod
def delete_automation(user: KhojUser, automation_id: str):
    """Remove ``user``'s automation from the scheduler and return its metadata.

    The metadata is captured before removal so the caller can report what was deleted.

    Raises:
        ValueError: from ``get_automation`` when the id is not owned by ``user``
            or no job is scheduled under it.
    """
    doomed_job: Job = AutomationAdapters.get_automation(user, automation_id)
    deleted_metadata = AutomationAdapters.get_automation_metadata(user, doomed_job)
    doomed_job.remove()
    return deleted_metadata