Files
khoj/src/khoj/routers/helpers.py

2114 lines
78 KiB
Python

import asyncio
import base64
import hashlib
import json
import logging
import math
import os
import re
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta, timezone
from enum import Enum
from functools import partial
from random import random
from typing import (
Annotated,
Any,
AsyncGenerator,
Callable,
Dict,
Iterator,
List,
Optional,
Tuple,
Union,
)
from urllib.parse import parse_qs, quote, unquote, urljoin, urlparse
import cron_descriptor
import pytz
import requests
from apscheduler.job import Job
from apscheduler.triggers.cron import CronTrigger
from asgiref.sync import sync_to_async
from fastapi import Depends, Header, HTTPException, Request, UploadFile
from pydantic import BaseModel
from starlette.authentication import has_required_scope
from starlette.requests import URL
from khoj.database import adapters
from khoj.database.adapters import (
LENGTH_OF_FREE_TRIAL,
AgentAdapters,
AutomationAdapters,
ConversationAdapters,
EntryAdapters,
FileObjectAdapters,
ais_user_subscribed,
create_khoj_token,
get_khoj_tokens,
get_user_name,
get_user_notion_config,
get_user_subscription_state,
run_with_process_lock,
)
from khoj.database.models import (
Agent,
ChatModelOptions,
ClientApplication,
Conversation,
GithubConfig,
KhojUser,
NotionConfig,
ProcessLock,
Subscription,
TextToImageModelConfig,
UserRequests,
)
from khoj.processor.content.docx.docx_to_entries import DocxToEntries
from khoj.processor.content.github.github_to_entries import GithubToEntries
from khoj.processor.content.images.image_to_entries import ImageToEntries
from khoj.processor.content.markdown.markdown_to_entries import MarkdownToEntries
from khoj.processor.content.notion.notion_to_entries import NotionToEntries
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries
from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries
from khoj.processor.conversation import prompts
from khoj.processor.conversation.anthropic.anthropic_chat import (
anthropic_send_message_to_model,
converse_anthropic,
)
from khoj.processor.conversation.google.gemini_chat import (
converse_gemini,
gemini_send_message_to_model,
)
from khoj.processor.conversation.offline.chat_model import (
converse_offline,
send_message_to_model_offline,
)
from khoj.processor.conversation.openai.gpt import converse, send_message_to_model
from khoj.processor.conversation.utils import (
ChatEvent,
ThreadedGenerator,
clean_json,
construct_chat_history,
generate_chatml_messages_with_context,
save_to_conversation_log,
)
from khoj.processor.speech.text_to_speech import is_eleven_labs_enabled
from khoj.routers.email import is_resend_enabled, send_task_email
from khoj.routers.twilio import is_twilio_enabled
from khoj.search_type import text_search
from khoj.utils import state
from khoj.utils.config import OfflineChatProcessorModel
from khoj.utils.helpers import (
LRU,
ConversationCommand,
get_file_type,
is_none_or_empty,
is_valid_url,
log_telemetry,
mode_descriptions_for_llm,
timer,
tool_descriptions_for_llm,
)
from khoj.utils.rawconfig import ChatRequestBody, FileAttachment, FileData, LocationData
logger = logging.getLogger(__name__)
executor = ThreadPoolExecutor(max_workers=1)
NOTION_OAUTH_CLIENT_ID = os.getenv("NOTION_OAUTH_CLIENT_ID")
NOTION_OAUTH_CLIENT_SECRET = os.getenv("NOTION_OAUTH_CLIENT_SECRET")
NOTION_REDIRECT_URI = os.getenv("NOTION_REDIRECT_URI")
def is_query_empty(query: str) -> bool:
return is_none_or_empty(query.strip())
def validate_conversation_config(user: KhojUser):
default_config = ConversationAdapters.get_default_conversation_config(user)
if default_config is None:
raise HTTPException(status_code=500, detail="Contact the server administrator to add a chat model.")
if default_config.model_type == "openai" and not default_config.openai_config:
raise HTTPException(status_code=500, detail="Contact the server administrator to add a chat model.")
async def is_ready_to_chat(user: KhojUser):
user_conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
if user_conversation_config == None:
user_conversation_config = await ConversationAdapters.aget_default_conversation_config(user)
if user_conversation_config and user_conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
chat_model = user_conversation_config.chat_model
max_tokens = user_conversation_config.max_prompt_size
if state.offline_chat_processor_config is None:
logger.info("Loading Offline Chat Model...")
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
return True
if (
user_conversation_config
and (
user_conversation_config.model_type
in [
ChatModelOptions.ModelType.OPENAI,
ChatModelOptions.ModelType.ANTHROPIC,
ChatModelOptions.ModelType.GOOGLE,
]
)
and user_conversation_config.openai_config
):
return True
raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.")
def get_file_content(file: UploadFile):
file_content = file.file.read()
file_type, encoding = get_file_type(file.content_type, file_content)
return FileData(name=file.filename, content=file_content, file_type=file_type, encoding=encoding)
def update_telemetry_state(
request: Request,
telemetry_type: str,
api: str,
client: Optional[str] = None,
user_agent: Optional[str] = None,
referer: Optional[str] = None,
host: Optional[str] = None,
metadata: Optional[dict] = None,
):
user: KhojUser = request.user.object if request.user.is_authenticated else None
client_app: ClientApplication = request.user.client_app if request.user.is_authenticated else None
subscription: Subscription = user.subscription if user and hasattr(user, "subscription") else None
user_state = {
"client_host": request.client.host if request.client else None,
"user_agent": user_agent or "unknown",
"referer": referer or "unknown",
"host": host or "unknown",
"server_id": str(user.uuid) if user else None,
"subscription_type": subscription.type if subscription else None,
"is_recurring": subscription.is_recurring if subscription else None,
"client_id": str(client_app.name) if client_app else "default",
}
if metadata:
user_state.update(metadata)
state.telemetry += [
log_telemetry(
telemetry_type=telemetry_type, api=api, client=client, app_config=state.config.app, properties=user_state
)
]
def get_next_url(request: Request) -> str:
"Construct next url relative to current domain from request"
next_url_param = urlparse(request.query_params.get("next", "/"))
next_path = "/" # default next path
# If relative path or absolute path to current domain
if is_none_or_empty(next_url_param.scheme) or next_url_param.netloc == request.base_url.netloc:
# Use path in next query param
next_path = next_url_param.path
# Construct absolute url using current domain and next path from request
return urljoin(str(request.base_url).rstrip("/"), next_path)
def get_conversation_command(query: str, any_references: bool = False) -> ConversationCommand:
if query.startswith("/notes"):
return ConversationCommand.Notes
elif query.startswith("/help"):
return ConversationCommand.Help
elif query.startswith("/general"):
return ConversationCommand.General
elif query.startswith("/online"):
return ConversationCommand.Online
elif query.startswith("/webpage"):
return ConversationCommand.Webpage
elif query.startswith("/image"):
return ConversationCommand.Image
elif query.startswith("/automated_task"):
return ConversationCommand.AutomatedTask
elif query.startswith("/summarize"):
return ConversationCommand.Summarize
elif query.startswith("/diagram"):
return ConversationCommand.Diagram
elif query.startswith("/code"):
return ConversationCommand.Code
elif query.startswith("/research"):
return ConversationCommand.Research
# If no relevant notes found for the given query
elif not any_references:
return ConversationCommand.General
else:
return ConversationCommand.Default
async def agenerate_chat_response(*args):
loop = asyncio.get_event_loop()
return await loop.run_in_executor(executor, generate_chat_response, *args)
def gather_raw_attached_files(
attached_files: Dict[str, str],
):
"""
Gather contextual data from the given (raw) files
"""
if len(attached_files) == 0:
return ""
contextual_data = " ".join(
[f"File: {file_name}\n\n{file_content}\n\n" for file_name, file_content in attached_files.items()]
)
return f"I have attached the following files:\n\n{contextual_data}"
async def acreate_title_from_history(
user: KhojUser,
conversation: Conversation,
):
"""
Create a title from the given conversation history
"""
chat_history = construct_chat_history(conversation.conversation_log)
title_generation_prompt = prompts.conversation_title_generation.format(chat_history=chat_history)
with timer("Chat actor: Generate title from conversation history", logger):
response = await send_message_to_model_wrapper(title_generation_prompt, user=user)
return response.strip()
async def acreate_title_from_query(query: str, user: KhojUser = None) -> str:
"""
Create a title from the given query
"""
title_generation_prompt = prompts.subject_generation.format(query=query)
with timer("Chat actor: Generate title from query", logger):
response = await send_message_to_model_wrapper(title_generation_prompt, user=user)
return response.strip()
async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None) -> Tuple[bool, str]:
"""
Check if the system prompt is safe to use
"""
safe_prompt_check = prompts.personality_prompt_safety_expert.format(prompt=system_prompt)
is_safe = True
reason = ""
with timer("Chat actor: Check if safe prompt", logger):
response = await send_message_to_model_wrapper(safe_prompt_check, user=user)
response = response.strip()
try:
response = json.loads(response)
is_safe = response.get("safe", "True") == "True"
if not is_safe:
reason = response.get("reason", "")
except Exception:
logger.error(f"Invalid response for checking safe prompt: {response}")
if not is_safe:
logger.error(f"Unsafe prompt: {system_prompt}. Reason: {reason}")
return is_safe, reason
async def aget_relevant_information_sources(
query: str,
conversation_history: dict,
is_task: bool,
user: KhojUser,
query_images: List[str] = None,
agent: Agent = None,
attached_files: str = None,
tracer: dict = {},
):
"""
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
"""
tool_options = dict()
tool_options_str = ""
agent_tools = agent.input_tools if agent else []
for tool, description in tool_descriptions_for_llm.items():
tool_options[tool.value] = description
if len(agent_tools) == 0 or tool.value in agent_tools:
tool_options_str += f'- "{tool.value}": "{description}"\n'
chat_history = construct_chat_history(conversation_history)
if query_images:
query = f"[placeholder for {len(query_images)} user attached images]\n{query}"
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)
relevant_tools_prompt = prompts.pick_relevant_information_collection_tools.format(
query=query,
tools=tool_options_str,
chat_history=chat_history,
personality_context=personality_context,
)
with timer("Chat actor: Infer information sources to refer", logger):
response = await send_message_to_model_wrapper(
relevant_tools_prompt,
response_type="json_object",
user=user,
attached_files=attached_files,
tracer=tracer,
)
try:
response = clean_json(response)
response = json.loads(response)
response = [q.strip() for q in response["source"] if q.strip()]
if not isinstance(response, list) or not response or len(response) == 0:
logger.error(f"Invalid response for determining relevant tools: {response}")
return tool_options
final_response = [] if not is_task else [ConversationCommand.AutomatedTask]
for llm_suggested_tool in response:
# Add a double check to verify it's in the agent list, because the LLM sometimes gets confused by the tool options.
if llm_suggested_tool in tool_options.keys() and (
len(agent_tools) == 0 or llm_suggested_tool in agent_tools
):
# Check whether the tool exists as a valid ConversationCommand
final_response.append(ConversationCommand(llm_suggested_tool))
if is_none_or_empty(final_response):
if len(agent_tools) == 0:
final_response = [ConversationCommand.Default]
else:
final_response = [ConversationCommand.General]
except Exception:
logger.error(f"Invalid response for determining relevant tools: {response}")
if len(agent_tools) == 0:
final_response = [ConversationCommand.Default]
else:
final_response = agent_tools
return final_response
async def aget_relevant_output_modes(
query: str,
conversation_history: dict,
is_task: bool = False,
user: KhojUser = None,
query_images: List[str] = None,
agent: Agent = None,
tracer: dict = {},
):
"""
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
"""
mode_options = dict()
mode_options_str = ""
output_modes = agent.output_modes if agent else []
for mode, description in mode_descriptions_for_llm.items():
# Do not allow tasks to schedule another task
if is_task and mode == ConversationCommand.Automation:
continue
mode_options[mode.value] = description
if len(output_modes) == 0 or mode.value in output_modes:
mode_options_str += f'- "{mode.value}": "{description}"\n'
chat_history = construct_chat_history(conversation_history)
if query_images:
query = f"[placeholder for {len(query_images)} user attached images]\n{query}"
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)
relevant_mode_prompt = prompts.pick_relevant_output_mode.format(
query=query,
modes=mode_options_str,
chat_history=chat_history,
personality_context=personality_context,
)
with timer("Chat actor: Infer output mode for chat response", logger):
response = await send_message_to_model_wrapper(
relevant_mode_prompt, response_type="json_object", user=user, tracer=tracer
)
try:
response = clean_json(response)
response = json.loads(response)
if is_none_or_empty(response):
return ConversationCommand.Text
output_mode = response["output"]
# Add a double check to verify it's in the agent list, because the LLM sometimes gets confused by the tool options.
if output_mode in mode_options.keys() and (len(output_modes) == 0 or output_mode in output_modes):
# Check whether the tool exists as a valid ConversationCommand
return ConversationCommand(output_mode)
logger.error(f"Invalid output mode selected: {output_mode}. Defaulting to text.")
return ConversationCommand.Text
except Exception:
logger.error(f"Invalid response for determining output mode: {response}")
return ConversationCommand.Text
async def infer_webpage_urls(
q: str,
conversation_history: dict,
location_data: LocationData,
user: KhojUser,
query_images: List[str] = None,
agent: Agent = None,
attached_files: str = None,
tracer: dict = {},
) -> List[str]:
"""
Infer webpage links from the given query
"""
location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
chat_history = construct_chat_history(conversation_history)
utc_date = datetime.utcnow().strftime("%Y-%m-%d")
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)
online_queries_prompt = prompts.infer_webpages_to_read.format(
current_date=utc_date,
query=q,
chat_history=chat_history,
location=location,
username=username,
personality_context=personality_context,
)
with timer("Chat actor: Infer webpage urls to read", logger):
response = await send_message_to_model_wrapper(
online_queries_prompt,
query_images=query_images,
response_type="json_object",
user=user,
attached_files=attached_files,
tracer=tracer,
)
# Validate that the response is a non-empty, JSON-serializable list of URLs
try:
response = clean_json(response)
urls = json.loads(response)
valid_unique_urls = {str(url).strip() for url in urls["links"] if is_valid_url(url)}
if is_none_or_empty(valid_unique_urls):
raise ValueError(f"Invalid list of urls: {response}")
if len(valid_unique_urls) == 0:
logger.error(f"No valid URLs found in response: {response}")
return []
return list(valid_unique_urls)
except Exception:
raise ValueError(f"Invalid list of urls: {response}")
async def generate_online_subqueries(
q: str,
conversation_history: dict,
location_data: LocationData,
user: KhojUser,
query_images: List[str] = None,
agent: Agent = None,
attached_files: str = None,
tracer: dict = {},
) -> List[str]:
"""
Generate subqueries from the given query
"""
location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
chat_history = construct_chat_history(conversation_history)
utc_date = datetime.utcnow().strftime("%Y-%m-%d")
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)
online_queries_prompt = prompts.online_search_conversation_subqueries.format(
current_date=utc_date,
query=q,
chat_history=chat_history,
location=location,
username=username,
personality_context=personality_context,
)
with timer("Chat actor: Generate online search subqueries", logger):
response = await send_message_to_model_wrapper(
online_queries_prompt,
query_images=query_images,
response_type="json_object",
user=user,
attached_files=attached_files,
tracer=tracer,
)
# Validate that the response is a non-empty, JSON-serializable list
try:
response = clean_json(response)
response = json.loads(response)
response = [q.strip() for q in response["queries"] if q.strip()]
if not isinstance(response, list) or not response or len(response) == 0:
logger.error(f"Invalid response for constructing subqueries: {response}. Returning original query: {q}")
return [q]
return response
except Exception as e:
logger.error(f"Invalid response for constructing subqueries: {response}. Returning original query: {q}")
return [q]
async def schedule_query(
q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None, tracer: dict = {}
) -> Tuple[str, ...]:
"""
Schedule the date, time to run the query. Assume the server timezone is UTC.
"""
chat_history = construct_chat_history(conversation_history)
crontime_prompt = prompts.crontime_prompt.format(
query=q,
chat_history=chat_history,
)
raw_response = await send_message_to_model_wrapper(
crontime_prompt, query_images=query_images, response_type="json_object", user=user, tracer=tracer
)
# Validate that the response is a non-empty, JSON-serializable list
try:
raw_response = raw_response.strip()
response: Dict[str, str] = json.loads(raw_response)
if not response or not isinstance(response, Dict) or len(response) != 3:
raise AssertionError(f"Invalid response for scheduling query : {response}")
return response.get("crontime"), response.get("query"), response.get("subject")
except Exception:
raise AssertionError(f"Invalid response for scheduling query: {raw_response}")
async def extract_relevant_info(
qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None, tracer: dict = {}
) -> Union[str, None]:
"""
Extract relevant information for a given query from the target corpus
"""
if is_none_or_empty(corpus) or is_none_or_empty(qs):
return None
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)
extract_relevant_information = prompts.extract_relevant_information.format(
query=", ".join(qs),
corpus=corpus.strip(),
personality_context=personality_context,
)
response = await send_message_to_model_wrapper(
extract_relevant_information,
prompts.system_prompt_extract_relevant_information,
user=user,
tracer=tracer,
)
return response.strip()
async def extract_relevant_summary(
q: str,
corpus: str,
conversation_history: dict,
query_images: List[str] = None,
user: KhojUser = None,
agent: Agent = None,
tracer: dict = {},
) -> Union[str, None]:
"""
Extract relevant information for a given query from the target corpus
"""
if is_none_or_empty(corpus) or is_none_or_empty(q):
return None
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)
chat_history = construct_chat_history(conversation_history)
extract_relevant_information = prompts.extract_relevant_summary.format(
query=q,
chat_history=chat_history,
corpus=corpus.strip(),
personality_context=personality_context,
)
with timer("Chat actor: Extract relevant information from data", logger):
response = await send_message_to_model_wrapper(
extract_relevant_information,
prompts.system_prompt_extract_relevant_summary,
user=user,
query_images=query_images,
tracer=tracer,
)
return response.strip()
async def generate_summary_from_files(
q: str,
user: KhojUser,
file_filters: List[str],
meta_log: dict,
query_images: List[str] = None,
agent: Agent = None,
send_status_func: Optional[Callable] = None,
attached_files: str = None,
tracer: dict = {},
):
try:
file_objects = None
if await EntryAdapters.aagent_has_entries(agent):
file_names = await EntryAdapters.aget_agent_entry_filepaths(agent)
if len(file_names) > 0:
file_objects = await FileObjectAdapters.async_get_file_objects_by_name(None, file_names.pop(), agent)
if (file_objects and len(file_objects) == 0 and not attached_files) or (
not file_objects and not attached_files
):
response_log = "Sorry, I couldn't find anything to summarize."
yield response_log
return
contextual_data = " ".join([f"File: {file.file_name}\n\n{file.raw_text}" for file in file_objects])
if attached_files:
contextual_data += f"\n\n{attached_files}"
if not q:
q = "Create a general summary of the file"
file_names = [file.file_name for file in file_objects]
file_names.extend(file_filters)
all_file_names = ""
for file_name in file_names:
all_file_names += f"- {file_name}\n"
async for result in send_status_func(f"**Constructing Summary Using:**\n{all_file_names}"):
yield {ChatEvent.STATUS: result}
response = await extract_relevant_summary(
q,
contextual_data,
conversation_history=meta_log,
query_images=query_images,
user=user,
agent=agent,
tracer=tracer,
)
yield str(response)
except Exception as e:
response_log = "Error summarizing file. Please try again, or contact support."
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
yield result
async def generate_excalidraw_diagram(
q: str,
conversation_history: Dict[str, Any],
location_data: LocationData,
note_references: List[Dict[str, Any]],
online_results: Optional[dict] = None,
query_images: List[str] = None,
user: KhojUser = None,
agent: Agent = None,
send_status_func: Optional[Callable] = None,
attached_files: str = None,
tracer: dict = {},
):
if send_status_func:
async for event in send_status_func("**Enhancing the Diagramming Prompt**"):
yield {ChatEvent.STATUS: event}
better_diagram_description_prompt = await generate_better_diagram_description(
q=q,
conversation_history=conversation_history,
location_data=location_data,
note_references=note_references,
online_results=online_results,
query_images=query_images,
user=user,
agent=agent,
attached_files=attached_files,
tracer=tracer,
)
if send_status_func:
async for event in send_status_func(f"**Diagram to Create:**:\n{better_diagram_description_prompt}"):
yield {ChatEvent.STATUS: event}
excalidraw_diagram_description = await generate_excalidraw_diagram_from_description(
q=better_diagram_description_prompt,
user=user,
agent=agent,
tracer=tracer,
)
yield better_diagram_description_prompt, excalidraw_diagram_description
async def generate_better_diagram_description(
q: str,
conversation_history: Dict[str, Any],
location_data: LocationData,
note_references: List[Dict[str, Any]],
online_results: Optional[dict] = None,
query_images: List[str] = None,
user: KhojUser = None,
agent: Agent = None,
attached_files: str = None,
tracer: dict = {},
) -> str:
"""
Generate a diagram description from the given query and context
"""
today_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d, %A")
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)
location = f"{location_data}" if location_data else "Unknown"
user_references = "\n\n".join([f"# {item['compiled']}" for item in note_references])
chat_history = construct_chat_history(conversation_history)
simplified_online_results = {}
if online_results:
for result in online_results:
if online_results[result].get("answerBox"):
simplified_online_results[result] = online_results[result]["answerBox"]
elif online_results[result].get("webpages"):
simplified_online_results[result] = online_results[result]["webpages"]
improve_diagram_description_prompt = prompts.improve_diagram_description_prompt.format(
query=q,
chat_history=chat_history,
location=location,
current_date=today_date,
references=user_references,
online_results=simplified_online_results,
personality_context=personality_context,
)
with timer("Chat actor: Generate better diagram description", logger):
response = await send_message_to_model_wrapper(
improve_diagram_description_prompt,
query_images=query_images,
user=user,
attached_files=attached_files,
tracer=tracer,
)
response = response.strip()
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
response = response[1:-1]
return response
async def generate_excalidraw_diagram_from_description(
q: str,
user: KhojUser = None,
agent: Agent = None,
tracer: dict = {},
) -> str:
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)
excalidraw_diagram_generation = prompts.excalidraw_diagram_generation_prompt.format(
personality_context=personality_context,
query=q,
)
with timer("Chat actor: Generate excalidraw diagram", logger):
raw_response = await send_message_to_model_wrapper(
query=excalidraw_diagram_generation, user=user, tracer=tracer
)
raw_response = clean_json(raw_response)
response: Dict[str, str] = json.loads(raw_response)
if not response or not isinstance(response, List) or not isinstance(response[0], Dict):
# TODO Some additional validation here that it's a valid Excalidraw diagram
raise AssertionError(f"Invalid response for improving diagram description: {response}")
return response
async def generate_better_image_prompt(
q: str,
conversation_history: str,
location_data: LocationData,
note_references: List[Dict[str, Any]],
online_results: Optional[dict] = None,
model_type: Optional[str] = None,
query_images: Optional[List[str]] = None,
user: KhojUser = None,
agent: Agent = None,
attached_files: str = "",
tracer: dict = {},
) -> str:
"""
Generate a better image prompt from the given query
"""
today_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d, %A")
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)
model_type = model_type or TextToImageModelConfig.ModelType.OPENAI
location = f"{location_data}" if location_data else "Unknown"
user_references = "\n\n".join([f"# {item['compiled']}" for item in note_references])
simplified_online_results = {}
if online_results:
for result in online_results:
if online_results[result].get("answerBox"):
simplified_online_results[result] = online_results[result]["answerBox"]
elif online_results[result].get("webpages"):
simplified_online_results[result] = online_results[result]["webpages"]
if model_type == TextToImageModelConfig.ModelType.OPENAI:
image_prompt = prompts.image_generation_improve_prompt_dalle.format(
query=q,
chat_history=conversation_history,
location=location,
current_date=today_date,
references=user_references,
online_results=simplified_online_results,
personality_context=personality_context,
)
elif model_type in [TextToImageModelConfig.ModelType.STABILITYAI, TextToImageModelConfig.ModelType.REPLICATE]:
image_prompt = prompts.image_generation_improve_prompt_sd.format(
query=q,
chat_history=conversation_history,
location=location,
current_date=today_date,
references=user_references,
online_results=simplified_online_results,
personality_context=personality_context,
)
with timer("Chat actor: Generate contextual image prompt", logger):
response = await send_message_to_model_wrapper(
image_prompt, query_images=query_images, user=user, attached_files=attached_files, tracer=tracer
)
response = response.strip()
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
response = response[1:-1]
return response
async def send_message_to_model_wrapper(
query: str,
system_message: str = "",
response_type: str = "text",
user: KhojUser = None,
query_images: List[str] = None,
context: str = "",
attached_files: str = None,
tracer: dict = {},
):
conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
vision_available = conversation_config.vision_enabled
if not vision_available and query_images:
logger.warning(f"Vision is not enabled for default model: {conversation_config.chat_model}.")
vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
if vision_enabled_config:
conversation_config = vision_enabled_config
vision_available = True
if vision_available and query_images:
logger.info(f"Using {conversation_config.chat_model} model to understand {len(query_images)} images.")
subscribed = await ais_user_subscribed(user)
chat_model = conversation_config.chat_model
max_tokens = (
conversation_config.subscribed_max_prompt_size
if subscribed and conversation_config.subscribed_max_prompt_size
else conversation_config.max_prompt_size
)
tokenizer = conversation_config.tokenizer
model_type = conversation_config.model_type
vision_available = conversation_config.vision_enabled
if model_type == ChatModelOptions.ModelType.OFFLINE:
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
loaded_model = state.offline_chat_processor_config.loaded_model
truncated_messages = generate_chatml_messages_with_context(
user_message=query,
context_message=context,
system_message=system_message,
model_name=chat_model,
loaded_model=loaded_model,
tokenizer_name=tokenizer,
max_prompt_size=max_tokens,
vision_enabled=vision_available,
model_type=conversation_config.model_type,
attached_files=attached_files,
)
return send_message_to_model_offline(
messages=truncated_messages,
loaded_model=loaded_model,
model=chat_model,
max_prompt_size=max_tokens,
streaming=False,
response_type=response_type,
tracer=tracer,
)
elif model_type == ChatModelOptions.ModelType.OPENAI:
openai_chat_config = conversation_config.openai_config
api_key = openai_chat_config.api_key
api_base_url = openai_chat_config.api_base_url
truncated_messages = generate_chatml_messages_with_context(
user_message=query,
context_message=context,
system_message=system_message,
model_name=chat_model,
max_prompt_size=max_tokens,
tokenizer_name=tokenizer,
vision_enabled=vision_available,
query_images=query_images,
model_type=conversation_config.model_type,
attached_files=attached_files,
)
return send_message_to_model(
messages=truncated_messages,
api_key=api_key,
model=chat_model,
response_type=response_type,
api_base_url=api_base_url,
tracer=tracer,
)
elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
api_key = conversation_config.openai_config.api_key
truncated_messages = generate_chatml_messages_with_context(
user_message=query,
context_message=context,
system_message=system_message,
model_name=chat_model,
max_prompt_size=max_tokens,
tokenizer_name=tokenizer,
vision_enabled=vision_available,
query_images=query_images,
model_type=conversation_config.model_type,
attached_files=attached_files,
)
return anthropic_send_message_to_model(
messages=truncated_messages,
api_key=api_key,
model=chat_model,
response_type=response_type,
tracer=tracer,
)
elif model_type == ChatModelOptions.ModelType.GOOGLE:
api_key = conversation_config.openai_config.api_key
truncated_messages = generate_chatml_messages_with_context(
user_message=query,
context_message=context,
system_message=system_message,
model_name=chat_model,
max_prompt_size=max_tokens,
tokenizer_name=tokenizer,
vision_enabled=vision_available,
query_images=query_images,
model_type=conversation_config.model_type,
attached_files=attached_files,
)
return gemini_send_message_to_model(
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type, tracer=tracer
)
else:
raise HTTPException(status_code=500, detail="Invalid conversation config")
def send_message_to_model_wrapper_sync(
message: str,
system_message: str = "",
response_type: str = "text",
user: KhojUser = None,
attached_files: str = "",
tracer: dict = {},
):
conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user)
if conversation_config is None:
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
chat_model = conversation_config.chat_model
max_tokens = conversation_config.max_prompt_size
vision_available = conversation_config.vision_enabled
if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
loaded_model = state.offline_chat_processor_config.loaded_model
truncated_messages = generate_chatml_messages_with_context(
user_message=message,
system_message=system_message,
model_name=chat_model,
loaded_model=loaded_model,
max_prompt_size=max_tokens,
vision_enabled=vision_available,
model_type=conversation_config.model_type,
attached_files=attached_files,
)
return send_message_to_model_offline(
messages=truncated_messages,
loaded_model=loaded_model,
model=chat_model,
max_prompt_size=max_tokens,
streaming=False,
response_type=response_type,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
api_key = conversation_config.openai_config.api_key
truncated_messages = generate_chatml_messages_with_context(
user_message=message,
system_message=system_message,
model_name=chat_model,
max_prompt_size=max_tokens,
vision_enabled=vision_available,
model_type=conversation_config.model_type,
attached_files=attached_files,
)
openai_response = send_message_to_model(
messages=truncated_messages,
api_key=api_key,
model=chat_model,
response_type=response_type,
tracer=tracer,
)
return openai_response
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
api_key = conversation_config.openai_config.api_key
truncated_messages = generate_chatml_messages_with_context(
user_message=message,
system_message=system_message,
model_name=chat_model,
max_prompt_size=max_tokens,
vision_enabled=vision_available,
model_type=conversation_config.model_type,
attached_files=attached_files,
)
return anthropic_send_message_to_model(
messages=truncated_messages,
api_key=api_key,
model=chat_model,
response_type=response_type,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
api_key = conversation_config.openai_config.api_key
truncated_messages = generate_chatml_messages_with_context(
user_message=message,
system_message=system_message,
model_name=chat_model,
max_prompt_size=max_tokens,
vision_enabled=vision_available,
model_type=conversation_config.model_type,
attached_files=attached_files,
)
return gemini_send_message_to_model(
messages=truncated_messages,
api_key=api_key,
model=chat_model,
response_type=response_type,
tracer=tracer,
)
else:
raise HTTPException(status_code=500, detail="Invalid conversation config")
def generate_chat_response(
q: str,
meta_log: dict,
conversation: Conversation,
compiled_references: List[Dict] = [],
online_results: Dict[str, Dict] = {},
code_results: Dict[str, Dict] = {},
inferred_queries: List[str] = [],
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
user: KhojUser = None,
client_application: ClientApplication = None,
conversation_id: str = None,
location_data: LocationData = None,
user_name: Optional[str] = None,
meta_research: str = "",
query_images: Optional[List[str]] = None,
train_of_thought: List[Any] = [],
attached_files: str = None,
raw_attached_files: List[FileAttachment] = None,
tracer: dict = {},
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
# Initialize Variables
chat_response = None
logger.debug(f"Conversation Types: {conversation_commands}")
metadata = {}
agent = AgentAdapters.get_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None
query_to_run = q
if meta_research:
query_to_run = f"AI Research: {meta_research} {q}"
try:
partial_completion = partial(
save_to_conversation_log,
q,
user=user,
meta_log=meta_log,
compiled_references=compiled_references,
online_results=online_results,
code_results=code_results,
inferred_queries=inferred_queries,
client_application=client_application,
conversation_id=conversation_id,
query_images=query_images,
train_of_thought=train_of_thought,
raw_attached_files=raw_attached_files,
tracer=tracer,
)
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
vision_available = conversation_config.vision_enabled
if not vision_available and query_images:
vision_enabled_config = ConversationAdapters.get_vision_enabled_config()
if vision_enabled_config:
conversation_config = vision_enabled_config
vision_available = True
if conversation_config.model_type == "offline":
loaded_model = state.offline_chat_processor_config.loaded_model
chat_response = converse_offline(
user_query=query_to_run,
references=compiled_references,
online_results=online_results,
loaded_model=loaded_model,
conversation_log=meta_log,
completion_func=partial_completion,
conversation_commands=conversation_commands,
model=conversation_config.chat_model,
max_prompt_size=conversation_config.max_prompt_size,
tokenizer_name=conversation_config.tokenizer,
location_data=location_data,
user_name=user_name,
agent=agent,
attached_files=attached_files,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
openai_chat_config = conversation_config.openai_config
api_key = openai_chat_config.api_key
chat_model = conversation_config.chat_model
chat_response = converse(
compiled_references,
query_to_run,
query_images=query_images,
online_results=online_results,
code_results=code_results,
conversation_log=meta_log,
model=chat_model,
api_key=api_key,
api_base_url=openai_chat_config.api_base_url,
completion_func=partial_completion,
conversation_commands=conversation_commands,
max_prompt_size=conversation_config.max_prompt_size,
tokenizer_name=conversation_config.tokenizer,
location_data=location_data,
user_name=user_name,
agent=agent,
vision_available=vision_available,
attached_files=attached_files,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
api_key = conversation_config.openai_config.api_key
chat_response = converse_anthropic(
compiled_references,
query_to_run,
query_images=query_images,
online_results=online_results,
code_results=code_results,
conversation_log=meta_log,
model=conversation_config.chat_model,
api_key=api_key,
completion_func=partial_completion,
conversation_commands=conversation_commands,
max_prompt_size=conversation_config.max_prompt_size,
tokenizer_name=conversation_config.tokenizer,
location_data=location_data,
user_name=user_name,
agent=agent,
vision_available=vision_available,
attached_files=attached_files,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
api_key = conversation_config.openai_config.api_key
chat_response = converse_gemini(
compiled_references,
query_to_run,
online_results,
code_results,
meta_log,
model=conversation_config.chat_model,
api_key=api_key,
completion_func=partial_completion,
conversation_commands=conversation_commands,
max_prompt_size=conversation_config.max_prompt_size,
tokenizer_name=conversation_config.tokenizer,
location_data=location_data,
user_name=user_name,
agent=agent,
query_images=query_images,
vision_available=vision_available,
attached_files=attached_files,
tracer=tracer,
)
metadata.update({"chat_model": conversation_config.chat_model})
except Exception as e:
logger.error(e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
return chat_response, metadata
class DeleteMessageRequestBody(BaseModel):
conversation_id: str
turn_id: str
class FeedbackData(BaseModel):
uquery: str
kquery: str
sentiment: str
class ApiUserRateLimiter:
def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str):
self.requests = requests
self.subscribed_requests = subscribed_requests
self.window = window
self.slug = slug
def __call__(self, request: Request):
# Rate limiting disabled if billing is disabled
if state.billing_enabled is False:
return
# Rate limiting is disabled if user unauthenticated.
# Other systems handle authentication
if not request.user.is_authenticated:
return
user: KhojUser = request.user.object
subscribed = has_required_scope(request, ["premium"])
# Remove requests outside of the time window
cutoff = datetime.now(tz=timezone.utc) - timedelta(seconds=self.window)
count_requests = UserRequests.objects.filter(user=user, created_at__gte=cutoff, slug=self.slug).count()
# Check if the user has exceeded the rate limit
if subscribed and count_requests >= self.subscribed_requests:
logger.info(
f"Rate limit: {count_requests} requests in {self.window} seconds for user: {user}. Limit is {self.subscribed_requests} requests."
)
raise HTTPException(status_code=429, detail="Slow down! Too Many Requests")
if not subscribed and count_requests >= self.requests:
if self.requests >= self.subscribed_requests:
logger.info(
f"Rate limit: {count_requests} requests in {self.window} seconds for user: {user}. Limit is {self.subscribed_requests} requests."
)
raise HTTPException(
status_code=429,
detail="Slow down! Too Many Requests",
)
logger.info(
f"Rate limit: {count_requests} requests in {self.window} seconds for user: {user}. Limit is {self.subscribed_requests} requests."
)
raise HTTPException(
status_code=429,
detail="I'm glad you're enjoying interacting with me! But you've exceeded your usage limit for today. Come back tomorrow or subscribe to increase your usage limit via [your settings](https://app.khoj.dev/settings).",
)
# Add the current request to the cache
UserRequests.objects.create(user=user, slug=self.slug)
class ApiImageRateLimiter:
def __init__(self, max_images: int = 10, max_combined_size_mb: float = 10):
self.max_images = max_images
self.max_combined_size_mb = max_combined_size_mb
def __call__(self, request: Request, body: ChatRequestBody):
if state.billing_enabled is False:
return
# Rate limiting is disabled if user unauthenticated.
# Other systems handle authentication
if not request.user.is_authenticated:
return
if not body.images:
return
# Check number of images
if len(body.images) > self.max_images:
raise HTTPException(
status_code=429,
detail=f"Those are way too many images for me! I can handle up to {self.max_images} images per message.",
)
# Check total size of images
total_size_mb = 0.0
for image in body.images:
# Unquote the image in case it's URL encoded
image = unquote(image)
# Assuming the image is a base64 encoded string
# Remove the data:image/jpeg;base64, part if present
if "," in image:
image = image.split(",", 1)[1]
# Decode base64 to get the actual size
image_bytes = base64.b64decode(image)
total_size_mb += len(image_bytes) / (1024 * 1024) # Convert bytes to MB
if total_size_mb > self.max_combined_size_mb:
raise HTTPException(
status_code=429,
detail=f"Those images are way too large for me! I can handle up to {self.max_combined_size_mb}MB of images per message.",
)
class ConversationCommandRateLimiter:
def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str):
self.slug = slug
self.trial_rate_limit = trial_rate_limit
self.subscribed_rate_limit = subscribed_rate_limit
self.restricted_commands = [ConversationCommand.Research]
async def update_and_check_if_valid(self, request: Request, conversation_command: ConversationCommand):
if state.billing_enabled is False:
return
if not request.user.is_authenticated:
return
if conversation_command not in self.restricted_commands:
return
user: KhojUser = request.user.object
subscribed = has_required_scope(request, ["premium"])
# Remove requests outside of the 24-hr time window
cutoff = datetime.now(tz=timezone.utc) - timedelta(seconds=60 * 60 * 24)
command_slug = f"{self.slug}_{conversation_command.value}"
count_requests = await UserRequests.objects.filter(
user=user, created_at__gte=cutoff, slug=command_slug
).acount()
if subscribed and count_requests >= self.subscribed_rate_limit:
logger.info(
f"Rate limit: {count_requests} requests in 24 hours for user: {user}. Limit is {self.subscribed_rate_limit} requests."
)
raise HTTPException(status_code=429, detail="Slow down! Too Many Requests")
if not subscribed and count_requests >= self.trial_rate_limit:
raise HTTPException(
status_code=429,
detail=f"We're glad you're enjoying Khoj! You've exceeded your `/{conversation_command.value}` command usage limit for today. Subscribe to increase your usage limit via [your settings](https://app.khoj.dev/settings).",
)
await UserRequests.objects.acreate(user=user, slug=command_slug)
return
class ApiIndexedDataLimiter:
def __init__(
self,
incoming_entries_size_limit: float,
subscribed_incoming_entries_size_limit: float,
total_entries_size_limit: float,
subscribed_total_entries_size_limit: float,
):
self.num_entries_size = incoming_entries_size_limit
self.subscribed_num_entries_size = subscribed_incoming_entries_size_limit
self.total_entries_size_limit = total_entries_size_limit
self.subscribed_total_entries_size = subscribed_total_entries_size_limit
def __call__(self, request: Request, files: List[UploadFile] = None):
if state.billing_enabled is False:
return
subscribed = has_required_scope(request, ["premium"])
incoming_data_size_mb = 0.0
deletion_file_names = set()
if not request.user.is_authenticated or not files:
return
user: KhojUser = request.user.object
for file in files:
if file.size == 0:
deletion_file_names.add(file.filename)
incoming_data_size_mb += file.size / 1024 / 1024
num_deleted_entries = 0
for file_path in deletion_file_names:
deleted_count = EntryAdapters.delete_entry_by_file(user, file_path)
num_deleted_entries += deleted_count
logger.info(f"Deleted {num_deleted_entries} entries for user: {user}.")
if subscribed and incoming_data_size_mb >= self.subscribed_num_entries_size:
raise HTTPException(status_code=429, detail="Too much data indexed.")
if not subscribed and incoming_data_size_mb >= self.num_entries_size:
raise HTTPException(
status_code=429, detail="Too much data indexed. Subscribe to increase your data index limit."
)
user_size_data = EntryAdapters.get_size_of_indexed_data_in_mb(user)
if subscribed and user_size_data + incoming_data_size_mb >= self.subscribed_total_entries_size:
raise HTTPException(status_code=429, detail="Too much data indexed.")
if not subscribed and user_size_data + incoming_data_size_mb >= self.total_entries_size_limit:
raise HTTPException(
status_code=429, detail="Too much data indexed. Subscribe to increase your data index limit."
)
class CommonQueryParamsClass:
def __init__(
self,
client: Optional[str] = None,
user_agent: Optional[str] = Header(None),
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
):
self.client = client
self.user_agent = user_agent
self.referer = referer
self.host = host
CommonQueryParams = Annotated[CommonQueryParamsClass, Depends()]
def should_notify(original_query: str, executed_query: str, ai_response: str, user: KhojUser) -> bool:
"""
Decide whether to notify the user of the AI response.
Default to notifying the user for now.
"""
if any(is_none_or_empty(message) for message in [original_query, executed_query, ai_response]):
return False
to_notify_or_not = prompts.to_notify_or_not.format(
original_query=original_query,
executed_query=executed_query,
response=ai_response,
)
with timer("Chat actor: Decide to notify user of automation response", logger):
try:
# TODO Replace with async call so we don't have to maintain a sync version
response = send_message_to_model_wrapper_sync(to_notify_or_not, user)
should_notify_result = "no" not in response.lower()
logger.info(f'Decided to {"not " if not should_notify_result else ""}notify user of automation response.')
return should_notify_result
except:
logger.warning(f"Fallback to notify user of automation response as failed to infer should notify or not.")
return True
def scheduled_chat(
query_to_run: str,
scheduling_request: str,
subject: str,
user: KhojUser,
calling_url: URL,
job_id: str = None,
conversation_id: str = None,
):
logger.info(f"Processing scheduled_chat: {query_to_run}")
if job_id:
# Get the job object and check whether the time is valid for it to run. This helps avoid race conditions that cause the same job to be run multiple times.
job = AutomationAdapters.get_automation(user, job_id)
last_run_time = AutomationAdapters.get_job_last_run(user, job)
# Convert last_run_time from %Y-%m-%d %I:%M %p %Z to datetime object
if last_run_time:
last_run_time = datetime.strptime(last_run_time, "%Y-%m-%d %I:%M %p %Z").replace(tzinfo=timezone.utc)
# If the last run time was within the last 6 hours, don't run it again. This helps avoid multithreading issues and rate limits.
if (datetime.now(timezone.utc) - last_run_time).total_seconds() < 21600:
logger.info(f"Skipping scheduled chat {job_id} as the next run time is in the future.")
return
# Extract relevant params from the original URL
scheme = "http" if not calling_url.is_secure else "https"
query_dict = parse_qs(calling_url.query)
# Pop the stream value from query_dict if it exists
query_dict.pop("stream", None)
# Replace the original scheduling query with the scheduled query
query_dict["q"] = [query_to_run]
# Replace the original conversation_id with the conversation_id
if conversation_id:
# encode the conversation_id to avoid any issues with special characters
query_dict["conversation_id"] = [quote(str(conversation_id))]
# Restructure the original query_dict into a valid JSON payload for the chat API
json_payload = {key: values[0] for key, values in query_dict.items()}
# Construct the URL to call the chat API with the scheduled query string
url = f"{scheme}://{calling_url.netloc}/api/chat?client=khoj"
# Construct the Headers for the chat API
headers = {"User-Agent": "Khoj", "Content-Type": "application/json"}
if not state.anonymous_mode:
# Add authorization request header in non-anonymous mode
token = get_khoj_tokens(user)
if is_none_or_empty(token):
token = create_khoj_token(user).token
else:
token = token[0].token
headers["Authorization"] = f"Bearer {token}"
# Call the chat API endpoint with authenticated user token and query
raw_response = requests.post(url, headers=headers, json=json_payload, allow_redirects=False)
# Handle redirect manually if necessary
if raw_response.status_code in [301, 302]:
redirect_url = raw_response.headers["Location"]
logger.info(f"Redirecting to {redirect_url}")
raw_response = requests.post(redirect_url, headers=headers, json=json_payload)
# Stop if the chat API call was not successful
if raw_response.status_code != 200:
logger.error(f"Failed to run schedule chat: {raw_response.text}, user: {user}, query: {query_to_run}")
return None
# Extract the AI response from the chat API response
cleaned_query = re.sub(r"^/automated_task\s*", "", query_to_run).strip()
is_image = False
if raw_response.headers.get("Content-Type") == "application/json":
response_map = raw_response.json()
ai_response = response_map.get("response") or response_map.get("image")
is_image = False
if type(ai_response) == dict:
is_image = ai_response.get("image") is not None
else:
ai_response = raw_response.text
# Notify user if the AI response is satisfactory
if should_notify(
original_query=scheduling_request, executed_query=cleaned_query, ai_response=ai_response, user=user
):
if is_resend_enabled():
send_task_email(user.get_short_name(), user.email, cleaned_query, ai_response, subject, is_image)
else:
return raw_response
async def create_automation(
q: str,
timezone: str,
user: KhojUser,
calling_url: URL,
meta_log: dict = {},
conversation_id: str = None,
tracer: dict = {},
):
crontime, query_to_run, subject = await schedule_query(q, meta_log, user, tracer=tracer)
job = await schedule_automation(query_to_run, subject, crontime, timezone, q, user, calling_url, conversation_id)
return job, crontime, query_to_run, subject
async def schedule_automation(
query_to_run: str,
subject: str,
crontime: str,
timezone: str,
scheduling_request: str,
user: KhojUser,
calling_url: URL,
conversation_id: str,
):
# Disable minute level automation recurrence
minute_value = crontime.split(" ")[0]
if not minute_value.isdigit():
# Run automation at some random minute (to distribute request load) instead of running every X minutes
crontime = " ".join([str(math.floor(random() * 60))] + crontime.split(" ")[1:])
user_timezone = pytz.timezone(timezone)
trigger = CronTrigger.from_crontab(crontime, user_timezone)
trigger.jitter = 60
# Generate id and metadata used by task scheduler and process locks for the task runs
job_metadata = json.dumps(
{
"query_to_run": query_to_run,
"scheduling_request": scheduling_request,
"subject": subject,
"crontime": crontime,
"conversation_id": str(conversation_id),
}
)
query_id = hashlib.md5(f"{query_to_run}_{crontime}".encode("utf-8")).hexdigest()
job_id = f"automation_{user.uuid}_{query_id}"
job = await sync_to_async(state.scheduler.add_job)(
run_with_process_lock,
trigger=trigger,
args=(
scheduled_chat,
f"{ProcessLock.Operation.SCHEDULED_JOB}_{user.uuid}_{query_id}",
),
kwargs={
"query_to_run": query_to_run,
"scheduling_request": scheduling_request,
"subject": subject,
"user": user,
"calling_url": calling_url,
"job_id": job_id,
"conversation_id": conversation_id,
},
id=job_id,
name=job_metadata,
max_instances=2, # Allow second instance to kill any previous instance with stale lock
)
return job
def construct_automation_created_message(automation: Job, crontime: str, query_to_run: str, subject: str):
# Display next run time in user timezone instead of UTC
schedule = f'{cron_descriptor.get_description(crontime)} {automation.next_run_time.strftime("%Z")}'
next_run_time = automation.next_run_time.strftime("%Y-%m-%d %I:%M %p %Z")
# Remove /automated_task prefix from inferred_query
unprefixed_query_to_run = re.sub(r"^\/automated_task\s*", "", query_to_run)
# Create the automation response
automation_icon_url = f"/static/assets/icons/automation.svg"
return f"""
### ![]({automation_icon_url}) Created Automation
- Subject: **{subject}**
- Query to Run: "{unprefixed_query_to_run}"
- Schedule: `{schedule}`
- Next Run At: {next_run_time}
Manage your automations [here](/automations).
""".strip()
class MessageProcessor:
def __init__(self):
self.references = {}
self.raw_response = ""
def convert_message_chunk_to_json(self, raw_chunk: str) -> Dict[str, Any]:
if raw_chunk.startswith("{") and raw_chunk.endswith("}"):
try:
json_chunk = json.loads(raw_chunk)
if "type" not in json_chunk:
json_chunk = {"type": "message", "data": json_chunk}
return json_chunk
except json.JSONDecodeError:
return {"type": "message", "data": raw_chunk}
elif raw_chunk:
return {"type": "message", "data": raw_chunk}
return {"type": "", "data": ""}
def process_message_chunk(self, raw_chunk: str) -> None:
chunk = self.convert_message_chunk_to_json(raw_chunk)
if not chunk or not chunk["type"]:
return
chunk_type = ChatEvent(chunk["type"])
if chunk_type == ChatEvent.REFERENCES:
self.references = chunk["data"]
elif chunk_type == ChatEvent.MESSAGE:
chunk_data = chunk["data"]
if isinstance(chunk_data, dict):
self.raw_response = self.handle_json_response(chunk_data)
elif (
isinstance(chunk_data, str) and chunk_data.strip().startswith("{") and chunk_data.strip().endswith("}")
):
try:
json_data = json.loads(chunk_data.strip())
self.raw_response = self.handle_json_response(json_data)
except json.JSONDecodeError:
self.raw_response += chunk_data
else:
self.raw_response += chunk_data
def handle_json_response(self, json_data: Dict[str, str]) -> str | Dict[str, str]:
if "image" in json_data or "details" in json_data:
return json_data
if "response" in json_data:
return json_data["response"]
return json_data
async def read_chat_stream(response_iterator: AsyncGenerator[str, None]) -> Dict[str, Any]:
processor = MessageProcessor()
event_delimiter = "␃🔚␗"
buffer = ""
async for chunk in response_iterator:
# Start buffering chunks until complete event is received
buffer += chunk
# Once the buffer contains a complete event
while event_delimiter in buffer:
# Extract the event from the buffer
event, buffer = buffer.split(event_delimiter, 1)
# Process the event
if event:
processor.process_message_chunk(event)
# Process any remaining data in the buffer
if buffer:
processor.process_message_chunk(buffer)
return {"response": processor.raw_response, "references": processor.references}
def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False):
user_picture = request.session.get("user", {}).get("picture")
is_active = has_required_scope(request, ["premium"])
has_documents = EntryAdapters.user_has_entries(user=user)
if not is_detailed:
return {
"request": request,
"username": user.username if user else None,
"user_photo": user_picture,
"is_active": is_active,
"has_documents": has_documents,
"khoj_version": state.khoj_version,
}
user_subscription_state = get_user_subscription_state(user.email)
user_subscription = adapters.get_user_subscription(user.email)
subscription_renewal_date = (
user_subscription.renewal_date.strftime("%d %b %Y")
if user_subscription and user_subscription.renewal_date
else None
)
subscription_enabled_trial_at = (
user_subscription.enabled_trial_at.strftime("%d %b %Y")
if user_subscription and user_subscription.enabled_trial_at
else None
)
given_name = get_user_name(user)
enabled_content_sources_set = set(EntryAdapters.get_unique_file_sources(user))
enabled_content_sources = {
"computer": ("computer" in enabled_content_sources_set),
"github": ("github" in enabled_content_sources_set),
"notion": ("notion" in enabled_content_sources_set),
}
notion_oauth_url = get_notion_auth_url(user)
current_notion_config = get_user_notion_config(user)
notion_token = current_notion_config.token if current_notion_config else ""
selected_chat_model_config = ConversationAdapters.get_conversation_config(
user
) or ConversationAdapters.get_default_conversation_config(user)
chat_models = ConversationAdapters.get_conversation_processor_options().all()
chat_model_options = list()
for chat_model in chat_models:
chat_model_options.append({"name": chat_model.chat_model, "id": chat_model.id})
selected_paint_model_config = ConversationAdapters.get_user_text_to_image_model_config(user)
paint_model_options = ConversationAdapters.get_text_to_image_model_options().all()
all_paint_model_options = list()
for paint_model in paint_model_options:
all_paint_model_options.append({"name": paint_model.model_name, "id": paint_model.id})
voice_models = ConversationAdapters.get_voice_model_options()
voice_model_options = list()
for voice_model in voice_models:
voice_model_options.append({"name": voice_model.name, "id": voice_model.model_id})
if len(voice_model_options) == 0:
eleven_labs_enabled = False
else:
eleven_labs_enabled = is_eleven_labs_enabled()
selected_voice_model_config = ConversationAdapters.get_voice_model_config(user)
return {
"request": request,
# user info
"username": user.username if user else None,
"user_photo": user_picture,
"is_active": is_active,
"given_name": given_name,
"phone_number": str(user.phone_number) if user.phone_number else "",
"is_phone_number_verified": user.verified_phone_number,
# user content settings
"enabled_content_source": enabled_content_sources,
"has_documents": has_documents,
"notion_token": notion_token,
# user model settings
"chat_model_options": chat_model_options,
"selected_chat_model_config": selected_chat_model_config.id if selected_chat_model_config else None,
"paint_model_options": all_paint_model_options,
"selected_paint_model_config": selected_paint_model_config.id if selected_paint_model_config else None,
"voice_model_options": voice_model_options,
"selected_voice_model_config": selected_voice_model_config.model_id if selected_voice_model_config else None,
# user billing info
"subscription_state": user_subscription_state,
"subscription_renewal_date": subscription_renewal_date,
"subscription_enabled_trial_at": subscription_enabled_trial_at,
# server settings
"khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"),
"billing_enabled": state.billing_enabled,
"is_eleven_labs_enabled": eleven_labs_enabled,
"is_twilio_enabled": is_twilio_enabled(),
"khoj_version": state.khoj_version,
"anonymous_mode": state.anonymous_mode,
"notion_oauth_url": notion_oauth_url,
"length_of_free_trial": LENGTH_OF_FREE_TRIAL,
}
def configure_content(
files: Optional[dict[str, dict[str, str]]],
regenerate: bool = False,
t: Optional[state.SearchType] = state.SearchType.All,
user: KhojUser = None,
) -> bool:
success = True
if t == None:
t = state.SearchType.All
if t is not None and t in [type.value for type in state.SearchType]:
t = state.SearchType(t)
if t is not None and not t.value in [type.value for type in state.SearchType]:
logger.warning(f"🚨 Invalid search type: {t}")
return False
search_type = t.value if t else None
no_documents = all([not files.get(file_type) for file_type in files])
if files is None:
logger.warning(f"🚨 No files to process for {search_type} search.")
return True
try:
# Initialize Org Notes Search
if (search_type == state.SearchType.All.value or search_type == state.SearchType.Org.value) and files["org"]:
logger.info("🦄 Setting up search for orgmode notes")
# Extract Entries, Generate Notes Embeddings
text_search.setup(
OrgToEntries,
files.get("org"),
regenerate=regenerate,
user=user,
)
except Exception as e:
logger.error(f"🚨 Failed to setup org: {e}", exc_info=True)
success = False
try:
# Initialize Markdown Search
if (search_type == state.SearchType.All.value or search_type == state.SearchType.Markdown.value) and files[
"markdown"
]:
logger.info("💎 Setting up search for markdown notes")
# Extract Entries, Generate Markdown Embeddings
text_search.setup(
MarkdownToEntries,
files.get("markdown"),
regenerate=regenerate,
user=user,
)
except Exception as e:
logger.error(f"🚨 Failed to setup markdown: {e}", exc_info=True)
success = False
try:
# Initialize PDF Search
if (search_type == state.SearchType.All.value or search_type == state.SearchType.Pdf.value) and files["pdf"]:
logger.info("🖨️ Setting up search for pdf")
# Extract Entries, Generate PDF Embeddings
text_search.setup(
PdfToEntries,
files.get("pdf"),
regenerate=regenerate,
user=user,
)
except Exception as e:
logger.error(f"🚨 Failed to setup PDF: {e}", exc_info=True)
success = False
try:
# Initialize Plaintext Search
if (search_type == state.SearchType.All.value or search_type == state.SearchType.Plaintext.value) and files[
"plaintext"
]:
logger.info("📄 Setting up search for plaintext")
# Extract Entries, Generate Plaintext Embeddings
text_search.setup(
PlaintextToEntries,
files.get("plaintext"),
regenerate=regenerate,
user=user,
)
except Exception as e:
logger.error(f"🚨 Failed to setup plaintext: {e}", exc_info=True)
success = False
try:
if no_documents:
github_config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
if (
search_type == state.SearchType.All.value or search_type == state.SearchType.Github.value
) and github_config is not None:
logger.info("🐙 Setting up search for github")
# Extract Entries, Generate Github Embeddings
text_search.setup(
GithubToEntries,
None,
regenerate=regenerate,
user=user,
config=github_config,
)
except Exception as e:
logger.error(f"🚨 Failed to setup GitHub: {e}", exc_info=True)
success = False
try:
if no_documents:
# Initialize Notion Search
notion_config = NotionConfig.objects.filter(user=user).first()
if (
search_type == state.SearchType.All.value or search_type == state.SearchType.Notion.value
) and notion_config:
logger.info("🔌 Setting up search for notion")
text_search.setup(
NotionToEntries,
None,
regenerate=regenerate,
user=user,
config=notion_config,
)
except Exception as e:
logger.error(f"🚨 Failed to setup Notion: {e}", exc_info=True)
success = False
try:
# Initialize Image Search
if (search_type == state.SearchType.All.value or search_type == state.SearchType.Image.value) and files[
"image"
]:
logger.info("🖼️ Setting up search for images")
# Extract Entries, Generate Image Embeddings
text_search.setup(
ImageToEntries,
files.get("image"),
regenerate=regenerate,
user=user,
)
except Exception as e:
logger.error(f"🚨 Failed to setup images: {e}", exc_info=True)
success = False
try:
if (search_type == state.SearchType.All.value or search_type == state.SearchType.Docx.value) and files["docx"]:
logger.info("📄 Setting up search for docx")
text_search.setup(
DocxToEntries,
files.get("docx"),
regenerate=regenerate,
user=user,
)
except Exception as e:
logger.error(f"🚨 Failed to setup docx: {e}", exc_info=True)
success = False
# Invalidate Query Cache
if user:
state.query_cache[user.uuid] = LRU()
return success
def get_notion_auth_url(user: KhojUser):
if not NOTION_OAUTH_CLIENT_ID or not NOTION_OAUTH_CLIENT_SECRET or not NOTION_REDIRECT_URI:
return None
return f"https://api.notion.com/v1/oauth/authorize?client_id={NOTION_OAUTH_CLIENT_ID}&redirect_uri={NOTION_REDIRECT_URI}&response_type=code&state={user.uuid}"