Files
khoj/src/khoj/routers/helpers.py
Debanjum a79025ee93 Limit max queries allowed per doc search tool call. Improve prompt
Reduce usage of boolean operators like "hello OR bye OR see you", which
don't work and reduce search quality. Models use them to stuff multiple
different queries into a single search query.
2025-08-09 12:29:35 -07:00

3151 lines
118 KiB
Python

import asyncio
import base64
import fnmatch
import hashlib
import json
import logging
import math
import os
import re
import time
from datetime import datetime, timedelta, timezone
from random import random
from typing import (
Annotated,
Any,
AsyncGenerator,
Callable,
Dict,
List,
Optional,
Set,
Tuple,
Union,
)
from urllib.parse import parse_qs, quote, unquote, urljoin, urlparse
import cron_descriptor
import pyjson5
import pytz
import requests
from apscheduler.job import Job
from apscheduler.triggers.cron import CronTrigger
from asgiref.sync import sync_to_async
from django.utils import timezone as django_timezone
from fastapi import Depends, Header, HTTPException, Request, UploadFile, WebSocket
from pydantic import BaseModel, EmailStr, Field
from starlette.authentication import has_required_scope
from starlette.requests import URL
from khoj.database import adapters
from khoj.database.adapters import (
LENGTH_OF_FREE_TRIAL,
AgentAdapters,
AutomationAdapters,
ConversationAdapters,
EntryAdapters,
FileObjectAdapters,
aget_user_by_email,
create_khoj_token,
get_default_search_model,
get_khoj_tokens,
get_user_name,
get_user_notion_config,
get_user_subscription_state,
run_with_process_lock,
)
from khoj.database.models import (
Agent,
ChatMessageModel,
ChatModel,
ClientApplication,
Conversation,
GithubConfig,
KhojUser,
NotionConfig,
ProcessLock,
RateLimitRecord,
Subscription,
TextToImageModelConfig,
UserRequests,
)
from khoj.processor.content.docx.docx_to_entries import DocxToEntries
from khoj.processor.content.github.github_to_entries import GithubToEntries
from khoj.processor.content.images.image_to_entries import ImageToEntries
from khoj.processor.content.markdown.markdown_to_entries import MarkdownToEntries
from khoj.processor.content.notion.notion_to_entries import NotionToEntries
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries
from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries
from khoj.processor.conversation import prompts
from khoj.processor.conversation.anthropic.anthropic_chat import (
anthropic_send_message_to_model,
converse_anthropic,
)
from khoj.processor.conversation.google.gemini_chat import (
converse_gemini,
gemini_send_message_to_model,
)
from khoj.processor.conversation.openai.gpt import (
converse_openai,
send_message_to_model,
)
from khoj.processor.conversation.utils import (
ChatEvent,
OperatorRun,
ResearchIteration,
ResponseWithThought,
clean_json,
clean_mermaidjs,
construct_chat_history,
construct_question_history,
defilter_query,
generate_chatml_messages_with_context,
)
from khoj.processor.speech.text_to_speech import is_eleven_labs_enabled
from khoj.routers.email import is_resend_enabled, send_task_email
from khoj.routers.twilio import is_twilio_enabled
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.search_type import text_search
from khoj.utils import state
from khoj.utils.helpers import (
LRU,
ConversationCommand,
ToolDefinition,
get_file_type,
in_debug_mode,
is_none_or_empty,
is_operator_enabled,
is_valid_url,
log_telemetry,
mode_descriptions_for_llm,
timer,
tool_descriptions_for_llm,
)
from khoj.utils.rawconfig import (
ChatRequestBody,
FileAttachment,
FileData,
LocationData,
SearchResponse,
)
from khoj.utils.state import SearchType
# Logger named after this module
logger = logging.getLogger(__name__)

# Notion OAuth settings, read from the environment once at import time (None when unset)
NOTION_OAUTH_CLIENT_ID = os.environ.get("NOTION_OAUTH_CLIENT_ID")
NOTION_OAUTH_CLIENT_SECRET = os.environ.get("NOTION_OAUTH_CLIENT_SECRET")
NOTION_REDIRECT_URI = os.environ.get("NOTION_REDIRECT_URI")
def is_query_empty(query: str) -> bool:
    """Return True when `query` has no content once surrounding whitespace is trimmed."""
    trimmed_query = query.strip()
    return is_none_or_empty(trimmed_query)
def validate_chat_model(user: KhojUser):
    """
    Ensure a usable default chat model is configured for `user`.

    Raises an HTTP 500 error when no default chat model exists, or when the
    default is an "openai" type model without an AI model API attached.
    """
    chat_model = ConversationAdapters.get_default_chat_model(user)
    model_missing = chat_model is None
    openai_api_missing = (
        not model_missing and chat_model.model_type == "openai" and not chat_model.ai_model_api
    )
    if model_missing or openai_api_missing:
        raise HTTPException(status_code=500, detail="Contact the server administrator to add a chat model.")
async def is_ready_to_chat(user: KhojUser):
    """
    Check that `user` can chat with a supported, API-backed chat model.

    Uses the user's selected chat model, falling back to the default chat model.
    Returns True when that model is an OpenAI, Anthropic or Google model with an
    AI model API configured; otherwise raises an HTTP 500 error.
    """
    chat_model = await ConversationAdapters.aget_user_chat_model(user)
    if chat_model is None:
        chat_model = await ConversationAdapters.aget_default_chat_model(user)

    supported_model_types = (
        ChatModel.ModelType.OPENAI,
        ChatModel.ModelType.ANTHROPIC,
        ChatModel.ModelType.GOOGLE,
    )
    if not chat_model or chat_model.model_type not in supported_model_types or not chat_model.ai_model_api:
        raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.")
    return True
def get_file_content(file: UploadFile):
    """Read the whole uploaded file and wrap it, with its detected type and encoding, in a FileData."""
    uploaded_content = file.file.read()
    detected_type, detected_encoding = get_file_type(file.content_type, uploaded_content)
    return FileData(
        name=file.filename,
        content=uploaded_content,
        file_type=detected_type,
        encoding=detected_encoding,
    )
def update_telemetry_state(
    request: Request,
    telemetry_type: str,
    api: str,
    client: Optional[str] = None,
    user_agent: Optional[str] = None,
    referer: Optional[str] = None,
    host: Optional[str] = None,
    metadata: Optional[dict] = None,
):
    """
    Append a telemetry event describing the current request to `state.telemetry`.

    The event properties include the client host, user agent, referer, host,
    user uuid, subscription type/recurrence and client application name, with
    any `metadata` entries merged over them. The event is built by `log_telemetry`,
    which also receives `state.telemetry_disabled`.
    """
    is_authenticated = request.user.is_authenticated
    user: KhojUser = request.user.object if is_authenticated else None
    client_app: ClientApplication = request.user.client_app if is_authenticated else None

    subscription: Subscription = None
    if user and hasattr(user, "subscription"):
        subscription = user.subscription

    user_state = {
        "client_host": request.client.host if request.client else None,
        "user_agent": user_agent or "unknown",
        "referer": referer or "unknown",
        "host": host or "unknown",
        "server_id": str(user.uuid) if user else None,
        "subscription_type": subscription.type if subscription else None,
        "is_recurring": subscription.is_recurring if subscription else None,
        "client_id": str(client_app.name) if client_app else "default",
    }

    if metadata:
        user_state.update(metadata)

    telemetry_event = log_telemetry(
        telemetry_type=telemetry_type,
        api=api,
        client=client,
        disable_telemetry_env=state.telemetry_disabled,
        properties=user_state,
    )
    state.telemetry += [telemetry_event]
def get_next_url(request: Request) -> str:
    """Construct next url relative to current domain from request"""
    parsed_next = urlparse(request.query_params.get("next", "/"))

    # Only honour the path from `next` when it is relative or on this same domain;
    # otherwise fall back to the root path so we never redirect off-site.
    is_relative = is_none_or_empty(parsed_next.scheme)
    on_same_domain = parsed_next.netloc == request.base_url.netloc
    next_path = parsed_next.path if (is_relative or on_same_domain) else "/"

    # Build an absolute url on the current domain
    base = str(request.base_url).rstrip("/")
    return urljoin(base, next_path)
def get_conversation_command(query: str) -> ConversationCommand:
    """
    Map a slash-command prefix at the start of `query` to its ConversationCommand.

    Matching is by prefix, checked in a fixed order. `/operator` is only honoured
    when `is_operator_enabled()` is true. Anything else maps to ConversationCommand.Default.
    """
    prefix_commands = (
        ("/notes", ConversationCommand.Notes),
        ("/general", ConversationCommand.General),
        ("/online", ConversationCommand.Online),
        ("/webpage", ConversationCommand.Webpage),
        ("/image", ConversationCommand.Image),
        ("/automated_task", ConversationCommand.AutomatedTask),
        ("/diagram", ConversationCommand.Diagram),
        ("/code", ConversationCommand.Code),
        ("/research", ConversationCommand.Research),
    )
    for prefix, command in prefix_commands:
        if query.startswith(prefix):
            return command

    # The operator command is feature-gated
    if query.startswith("/operator") and is_operator_enabled():
        return ConversationCommand.Operator

    return ConversationCommand.Default
def gather_raw_query_files(
    query_files: Dict[str, str],
):
    """
    Render raw attached files (name -> content) as one context string.

    Returns "" when no files are given. Otherwise returns a header line followed
    by a "File: <name>" section per file, with sections joined by a single space.
    """
    if len(query_files) == 0:
        return ""

    file_sections = (f"File: {name}\n\n{content}\n\n" for name, content in query_files.items())
    return "I have attached the following files:\n\n" + " ".join(file_sections)
async def acreate_title_from_history(
    user: KhojUser,
    conversation: Conversation,
):
    """
    Ask the chat model for a title summarizing `conversation`'s messages.

    Returns the model's reply with surrounding whitespace removed.
    """
    history_text = construct_chat_history(conversation.messages)
    title_prompt = prompts.conversation_title_generation.format(chat_history=history_text)

    with timer("Chat actor: Generate title from conversation history", logger):
        title_reply = await send_message_to_model_wrapper(title_prompt, user=user)

    return title_reply.text.strip()
async def acreate_title_from_query(query: str, user: KhojUser = None) -> str:
    """
    Ask the chat model for a short subject line for `query`.

    Returns the model's reply with surrounding whitespace removed.
    """
    subject_prompt = prompts.subject_generation.format(query=query)

    with timer("Chat actor: Generate title from query", logger):
        subject_reply = await send_message_to_model_wrapper(subject_prompt, user=user)

    return subject_reply.text.strip()
async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None, lax: bool = False) -> Tuple[bool, str]:
    """
    Ask the chat model whether `system_prompt` is safe to use.

    Uses the lax safety-expert prompt when `lax` is true, else the strict one.
    Returns `(is_safe, reason)`. `reason` is only populated for unsafe prompts.
    If the model's reply cannot be parsed, the prompt is treated as safe and the
    parse failure is logged.
    """
    safety_template = prompts.personality_prompt_safety_expert_lax if lax else prompts.personality_prompt_safety_expert
    safe_prompt_check = safety_template.format(prompt=system_prompt)

    class SafetyCheck(BaseModel):
        safe: bool
        reason: str

    is_safe, reason = True, ""

    with timer("Chat actor: Check if safe prompt", logger):
        model_reply = await send_message_to_model_wrapper(
            safe_prompt_check, user=user, response_type="json_object", response_schema=SafetyCheck
        )

    # `verdict` holds the raw text until it is successfully parsed, so the error log
    # shows whichever form was last available
    verdict = model_reply.text.strip()
    try:
        verdict = json.loads(clean_json(verdict))
        is_safe = str(verdict.get("safe", "true")).lower() == "true"
        if not is_safe:
            reason = verdict.get("reason", "")
    except Exception:
        logger.error(f"Invalid response for checking safe prompt: {verdict}")

    if not is_safe:
        logger.error(f"Unsafe prompt: {system_prompt}. Reason: {reason}")

    return is_safe, reason
async def aget_data_sources_and_output_format(
    query: str,
    chat_history: list[ChatMessageModel],
    is_task: bool,
    user: KhojUser,
    query_images: List[str] = None,
    agent: Agent = None,
    query_files: str = None,
    tracer: dict = {},
) -> Dict[str, Any]:
    """
    Given a query, determine which of the available data sources and output modes the agent should use to answer appropriately.

    Args:
        query: The user's query.
        chat_history: Conversation so far, rendered via construct_chat_history(n=6).
        is_task: Not used by this function.
        user: The requesting user. The Notes source is only offered if the user has entries.
        query_images: Images attached to the query. Only their count is mentioned in the prompt.
        agent: Optional agent. Its input tools / output modes restrict what may be picked.
        query_files: Attached file context forwarded to the model.
        tracer: Tracing metadata forwarded to the model call.

    Returns:
        {"sources": [...], "output": ...}. Picks that are unknown or not enabled for the
        agent are dropped. If no source remains, falls back to Default (or General for
        agents with configured tools). If the model reply cannot be parsed, falls back
        to the agent's configured tools/modes, or to Default/Text.
    """
    source_options = dict()
    source_options_str = ""

    agent_sources = agent.input_tools if agent else []
    user_has_entries = await EntryAdapters.auser_has_entries(user)

    for source, description in tool_descriptions_for_llm.items():
        # Skip showing Notes tool as an option if user has no entries
        if source == ConversationCommand.Notes and not user_has_entries:
            continue
        if source == ConversationCommand.Operator and not is_operator_enabled():
            continue
        source_options[source.value] = description
        if len(agent_sources) == 0 or source.value in agent_sources:
            source_options_str += f'- "{source.value}": "{description}"\n'

    output_options = dict()
    output_options_str = ""

    agent_outputs = agent.output_modes if agent else []

    for output, description in mode_descriptions_for_llm.items():
        output_options[output.value] = description
        if len(agent_outputs) == 0 or output.value in agent_outputs:
            output_options_str += f'- "{output.value}": "{description}"\n'

    chat_history_str = construct_chat_history(chat_history, n=6)

    if query_images:
        query = f"[placeholder for {len(query_images)} user attached images]\n{query}"

    personality_context = (
        prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
    )

    relevant_tools_prompt = prompts.pick_relevant_tools.format(
        query=query,
        sources=source_options_str,
        outputs=output_options_str,
        chat_history=chat_history_str,
        personality_context=personality_context,
    )

    agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None

    class PickTools(BaseModel):
        source: List[str] = Field(..., min_items=1)
        output: str

    with timer("Chat actor: Infer information sources to refer", logger):
        raw_response = await send_message_to_model_wrapper(
            relevant_tools_prompt,
            response_type="json_object",
            response_schema=PickTools,
            user=user,
            query_files=query_files,
            agent_chat_model=agent_chat_model,
            tracer=tracer,
        )

    # Bound before the try so the error handler can always log it, even if parsing fails early
    raw_text = None
    try:
        raw_text = raw_response.text
        response = json.loads(clean_json(raw_text))
        raw_sources = response.get("source", [])
        # Tolerate a bare string where a list was expected, rather than iterating its characters
        if isinstance(raw_sources, str):
            raw_sources = [raw_sources]
        chosen_sources = [s.strip() for s in raw_sources if s.strip()]
        chosen_output = response.get("output", "text").strip()  # Default to text output

        if is_none_or_empty(chosen_sources):
            raise ValueError(
                f"Invalid response for determining relevant tools: {chosen_sources}. Raw Response: {response}"
            )

        output_mode = ConversationCommand.Text
        # Verify selected output mode is enabled for the agent, as the LLM can sometimes get confused by the tool options.
        if chosen_output in output_options.keys() and (len(agent_outputs) == 0 or chosen_output in agent_outputs):
            # Ensure that the chosen output mode exists as a valid ConversationCommand
            output_mode = ConversationCommand(chosen_output)

        data_sources = []
        # Verify selected data sources are enabled for the agent, as the LLM can sometimes get confused by the tool options.
        for chosen_source in chosen_sources:
            # Ensure that the chosen data source exists as a valid ConversationCommand
            if chosen_source in source_options.keys() and (len(agent_sources) == 0 or chosen_source in agent_sources):
                data_sources.append(ConversationCommand(chosen_source))

        # Fallback to default sources if the inferred data sources are unset or invalid
        if is_none_or_empty(data_sources):
            if len(agent_sources) == 0:
                data_sources = [ConversationCommand.Default]
            else:
                data_sources = [ConversationCommand.General]
    except Exception as e:
        logger.error(f"Invalid response for determining relevant tools: {raw_text}. Error: {e}", exc_info=True)
        data_sources = agent_sources if len(agent_sources) > 0 else [ConversationCommand.Default]
        output_mode = agent_outputs[0] if len(agent_outputs) > 0 else ConversationCommand.Text

    return {"sources": data_sources, "output": output_mode}
async def infer_webpage_urls(
    q: str,
    max_webpages: int,
    chat_history: List[ChatMessageModel],
    location_data: LocationData,
    user: KhojUser,
    query_images: List[str] = None,
    agent: Agent = None,
    query_files: str = None,
    tracer: dict = {},
) -> List[str]:
    """
    Infer webpage links from the given query

    Asks the chat model for up to `max_webpages` links to read for `q`. The prompt
    includes the chat history, location, current UTC date, user name and agent
    personality.

    Returns:
        Up to `max_webpages` unique, valid URLs, in the order the model listed them.

    Raises:
        ValueError: If the model reply cannot be parsed or contains no valid URL.
    """
    location = f"{location_data}" if location_data else "Unknown"
    username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
    chat_history_str = construct_chat_history(chat_history)

    utc_date = datetime.now(timezone.utc).strftime("%Y-%m-%d")
    personality_context = (
        prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
    )

    online_queries_prompt = prompts.infer_webpages_to_read.format(
        query=q,
        max_webpages=max_webpages,
        chat_history=chat_history_str,
        current_date=utc_date,
        location=location,
        username=username,
        personality_context=personality_context,
    )

    agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None

    class WebpageUrls(BaseModel):
        links: List[str] = Field(..., min_items=1, max_items=max_webpages)

    with timer("Chat actor: Infer webpage urls to read", logger):
        raw_response = await send_message_to_model_wrapper(
            online_queries_prompt,
            query_images=query_images,
            response_type="json_object",
            response_schema=WebpageUrls,
            user=user,
            query_files=query_files,
            agent_chat_model=agent_chat_model,
            tracer=tracer,
        )

    # Bound before the try so the error message never references an unset name
    raw_text = None
    try:
        raw_text = clean_json(raw_response.text)
        urls = json.loads(raw_text)
        # Strip, validate and de-duplicate while keeping the model's ordering, so that
        # truncating to max_webpages keeps the links the model listed first.
        stripped_urls = (str(url).strip() for url in urls["links"])
        valid_unique_urls = list(dict.fromkeys(url for url in stripped_urls if is_valid_url(url)))
    except Exception as e:
        raise ValueError(f"Invalid list of urls: {raw_text}") from e

    if is_none_or_empty(valid_unique_urls):
        raise ValueError(f"Invalid list of urls: {raw_text}")

    return valid_unique_urls[:max_webpages]
async def generate_online_subqueries(
    q: str,
    chat_history: List[ChatMessageModel],
    location_data: LocationData,
    user: KhojUser,
    query_images: List[str] = None,
    query_files: str = None,
    max_queries: int = 3,
    agent: Agent = None,
    tracer: dict = {},
) -> Set[str]:
    """
    Generate subqueries from the given query

    Asks the chat model for up to `max_queries` online search queries for `q`. The
    prompt includes the chat history, location, current UTC date, user name and
    agent personality.

    Returns:
        A set of at most `max_queries` stripped, non-empty queries. Falls back to
        `{q}` if the model reply cannot be parsed or yields no usable query.
    """
    location = f"{location_data}" if location_data else "Unknown"
    username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
    chat_history_str = construct_chat_history(chat_history)

    utc_date = datetime.now(timezone.utc).strftime("%Y-%m-%d")
    personality_context = (
        prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
    )

    online_queries_prompt = prompts.online_search_conversation_subqueries.format(
        query=q,
        chat_history=chat_history_str,
        max_queries=max_queries,
        current_date=utc_date,
        location=location,
        username=username,
        personality_context=personality_context,
    )

    agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None

    class OnlineQueries(BaseModel):
        queries: List[str] = Field(..., min_items=1, max_items=max_queries)

    with timer("Chat actor: Generate online search subqueries", logger):
        raw_response = await send_message_to_model_wrapper(
            online_queries_prompt,
            query_images=query_images,
            response_type="json_object",
            response_schema=OnlineQueries,
            user=user,
            query_files=query_files,
            agent_chat_model=agent_chat_model,
            tracer=tracer,
        )

    # Bound before the try so the fallback path can always log it
    raw_text = None
    try:
        raw_text = clean_json(raw_response.text)
        parsed = pyjson5.loads(raw_text)
        # De-duplicate in the model's order, then cap at max_queries in case the
        # model ignored the schema's max_items limit.
        ordered_queries = dict.fromkeys(subquery.strip() for subquery in parsed["queries"] if subquery.strip())
        subqueries = set(list(ordered_queries)[:max_queries])
    except Exception:
        logger.error(f"Invalid response for constructing online subqueries: {raw_text}. Returning original query: {q}")
        return {q}

    if not subqueries:
        logger.error(
            f"Invalid response for constructing online subqueries: {subqueries}. Returning original query: {q}"
        )
        return {q}
    return subqueries
def schedule_query(
    q: str, chat_history: List[ChatMessageModel], user: KhojUser, query_images: List[str] = None, tracer: dict = {}
) -> Tuple[str, str, str]:
    """
    Schedule the date, time to run the query. Assume the server timezone is UTC.

    Returns the `(crontime, query, subject)` values from the model's JSON reply.
    Raises AssertionError unless the reply parses to a non-empty JSON object with
    exactly three keys.
    """
    crontime_prompt = prompts.crontime_prompt.format(
        query=q,
        chat_history=construct_chat_history(chat_history),
    )

    raw_response = send_message_to_model_wrapper_sync(
        crontime_prompt, query_images=query_images, response_type="json_object", user=user, tracer=tracer
    )

    # Validate that the response is a JSON object with exactly three keys
    try:
        schedule: Dict[str, str] = json.loads(clean_json(raw_response.text))
        looks_valid = bool(schedule) and isinstance(schedule, Dict) and len(schedule) == 3
        if not looks_valid:
            raise AssertionError(f"Invalid response for scheduling query : {schedule}")
        return schedule.get("crontime"), schedule.get("query"), schedule.get("subject")
    except Exception:
        raise AssertionError(f"Invalid response for scheduling query: {raw_response.text}")
async def aschedule_query(
    q: str, chat_history: List[ChatMessageModel], user: KhojUser, query_images: List[str] = None, tracer: dict = {}
) -> Tuple[str, str, str]:
    """
    Schedule the date, time to run the query. Assume the server timezone is UTC.

    Async variant of `schedule_query`. Returns the `(crontime, query, subject)`
    values from the model's JSON reply. Raises AssertionError unless the reply
    parses to a non-empty JSON object with exactly three keys.
    """
    crontime_prompt = prompts.crontime_prompt.format(
        query=q,
        chat_history=construct_chat_history(chat_history),
    )

    raw_response = await send_message_to_model_wrapper(
        crontime_prompt, query_images=query_images, response_type="json_object", user=user, tracer=tracer
    )

    # Validate that the response is a JSON object with exactly three keys.
    # `raw_response` is rebound to the stripped text so the error below reports it.
    try:
        raw_response = raw_response.text.strip()
        schedule: Dict[str, str] = json.loads(clean_json(raw_response))
        looks_valid = bool(schedule) and isinstance(schedule, Dict) and len(schedule) == 3
        if not looks_valid:
            raise AssertionError(f"Invalid response for scheduling query : {schedule}")
        return schedule.get("crontime"), schedule.get("query"), schedule.get("subject")
    except Exception:
        raise AssertionError(f"Invalid response for scheduling query: {raw_response}")
async def extract_relevant_info(
    qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None, tracer: dict = {}
) -> Union[str, None]:
    """
    Extract relevant information for a given query from the target corpus

    Returns None when `corpus` or `qs` is empty. Otherwise returns the chat model's
    reply, stripped of surrounding whitespace. The queries are joined with ", "
    into the prompt.
    """
    if is_none_or_empty(corpus) or is_none_or_empty(qs):
        return None

    personality_context = ""
    if agent and agent.personality:
        personality_context = prompts.personality_context.format(personality=agent.personality)

    extraction_prompt = prompts.extract_relevant_information.format(
        query=", ".join(qs),
        corpus=corpus.strip(),
        personality_context=personality_context,
    )

    agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None

    model_reply = await send_message_to_model_wrapper(
        extraction_prompt,
        prompts.system_prompt_extract_relevant_information,
        user=user,
        agent_chat_model=agent_chat_model,
        tracer=tracer,
    )
    return model_reply.text.strip()
async def extract_relevant_summary(
    q: str,
    corpus: str,
    chat_history: List[ChatMessageModel] = [],
    query_images: List[str] = None,
    user: KhojUser = None,
    agent: Agent = None,
    tracer: dict = {},
) -> Union[str, None]:
    """
    Extract relevant information for a given query from the target corpus

    Summarizes `corpus` with respect to `q` and the chat history. Returns None
    when `corpus` or `q` is empty; otherwise returns the model's reply, stripped
    of surrounding whitespace.
    """
    if is_none_or_empty(corpus) or is_none_or_empty(q):
        return None

    personality_context = ""
    if agent and agent.personality:
        personality_context = prompts.personality_context.format(personality=agent.personality)

    summary_prompt = prompts.extract_relevant_summary.format(
        query=q,
        chat_history=construct_chat_history(chat_history),
        corpus=corpus.strip(),
        personality_context=personality_context,
    )

    agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None

    with timer("Chat actor: Extract relevant information from data", logger):
        model_reply = await send_message_to_model_wrapper(
            summary_prompt,
            prompts.system_prompt_extract_relevant_summary,
            user=user,
            query_images=query_images,
            agent_chat_model=agent_chat_model,
            tracer=tracer,
        )

    return model_reply.text.strip()
async def generate_summary_from_files(
    q: str,
    user: KhojUser,
    file_filters: List[str],
    chat_history: List[ChatMessageModel] = [],
    query_images: List[str] = None,
    agent: Agent = None,
    send_status_func: Optional[Callable] = None,
    query_files: str = None,
    tracer: dict = {},
):
    """
    Summarize an agent's file and/or the files attached to the query.

    Async generator. Yields `{ChatEvent.STATUS: ...}` status events (only when
    `send_status_func` is provided), then the summary as a string. Yields an
    apology message instead when there is nothing to summarize or summarization
    fails; failures are logged rather than raised.
    """
    try:
        file_objects = []
        if await EntryAdapters.aagent_has_entries(agent):
            file_names = await EntryAdapters.aget_agent_entry_filepaths(agent)
            if len(file_names) > 0:
                # Only one of the agent's files (the one popped) is summarized
                file_objects = (
                    await FileObjectAdapters.aget_file_objects_by_name(None, file_names.pop(), agent) or []
                )

        if not file_objects and not query_files:
            response_log = "Sorry, I couldn't find anything to summarize."
            yield response_log
            return

        contextual_data = " ".join([f"File: {file.file_name}\n\n{file.raw_text}" for file in file_objects])
        if query_files:
            contextual_data += f"\n\n{query_files}"

        if not q:
            q = "Create a general summary of the file"

        file_names = [file.file_name for file in file_objects]
        file_names.extend(file_filters or [])
        all_file_names = "".join(f"- {file_name}\n" for file_name in file_names)

        if send_status_func:
            async for result in send_status_func(f"**Constructing Summary Using:**\n{all_file_names}"):
                yield {ChatEvent.STATUS: result}

        response = await extract_relevant_summary(
            q,
            contextual_data,
            chat_history=chat_history,
            query_images=query_images,
            user=user,
            agent=agent,
            tracer=tracer,
        )

        yield str(response)
    except Exception as e:
        response_log = "Error summarizing file. Please try again, or contact support."
        logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
        # Yield the apology message; `result` may be undefined or a stale status here
        yield response_log
async def generate_excalidraw_diagram(
    q: str,
    chat_history: List[ChatMessageModel],
    location_data: LocationData,
    note_references: List[Dict[str, Any]],
    online_results: Optional[dict] = None,
    query_images: List[str] = None,
    user: KhojUser = None,
    agent: Agent = None,
    send_status_func: Optional[Callable] = None,
    query_files: str = None,
    tracer: dict = {},
):
    """
    Generate an Excalidraw diagram for the query.

    Async generator. First asks the model for an improved diagram description,
    then for the Excalidraw elements. Yields `{ChatEvent.STATUS: ...}` events when
    `send_status_func` is provided, then a final `(inferred_queries, elements)`
    tuple. If element generation fails, the final tuple is `(description, None)`.
    """
    if send_status_func:
        async for event in send_status_func("**Enhancing the Diagramming Prompt**"):
            yield {ChatEvent.STATUS: event}

    better_diagram_description_prompt = await generate_better_diagram_description(
        q=q,
        chat_history=chat_history,
        location_data=location_data,
        note_references=note_references,
        online_results=online_results,
        query_images=query_images,
        user=user,
        agent=agent,
        query_files=query_files,
        tracer=tracer,
    )

    if send_status_func:
        async for event in send_status_func(f"**Diagram to Create:**:\n{better_diagram_description_prompt}"):
            yield {ChatEvent.STATUS: event}
    try:
        excalidraw_diagram_description = await generate_excalidraw_diagram_from_description(
            q=better_diagram_description_prompt,
            user=user,
            agent=agent,
            tracer=tracer,
        )
    except Exception as e:
        # `user` defaults to None, so don't dereference it unconditionally in the error path
        user_email = user.email if user else None
        logger.error(f"Error generating Excalidraw diagram for {user_email}: {e}", exc_info=True)
        yield better_diagram_description_prompt, None
        return

    scratchpad = excalidraw_diagram_description.get("scratchpad")

    inferred_queries = f"Instruction: {better_diagram_description_prompt}\n\nScratchpad: {scratchpad}"

    yield inferred_queries, excalidraw_diagram_description.get("elements")
async def generate_better_diagram_description(
    q: str,
    chat_history: List[ChatMessageModel],
    location_data: LocationData,
    note_references: List[Dict[str, Any]],
    online_results: Optional[dict] = None,
    query_images: List[str] = None,
    user: KhojUser = None,
    agent: Agent = None,
    query_files: str = None,
    tracer: dict = {},
) -> str:
    """
    Generate a diagram description from the given query and context

    Builds the Excalidraw description prompt from the query, chat history,
    location, today's UTC date, compiled note references and a condensed view of
    online results. Returns the model's reply, stripped of whitespace and of one
    leading/trailing quote character when both ends are quotes.
    """
    today_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d, %A")

    personality_context = ""
    if agent and agent.personality:
        personality_context = prompts.personality_context.format(personality=agent.personality)

    location = f"{location_data}" if location_data else "Unknown"
    user_references = "\n\n".join(f"# {item['compiled']}" for item in note_references)
    chat_history_str = construct_chat_history(chat_history)

    # Keep just the answer box, or else the webpages, from each online result
    simplified_online_results = {}
    for result_key, result_value in (online_results or {}).items():
        if result_value.get("answerBox"):
            simplified_online_results[result_key] = result_value["answerBox"]
        elif result_value.get("webpages"):
            simplified_online_results[result_key] = result_value["webpages"]

    description_prompt = prompts.improve_excalidraw_diagram_description_prompt.format(
        query=q,
        chat_history=chat_history_str,
        location=location,
        current_date=today_date,
        references=user_references,
        online_results=simplified_online_results,
        personality_context=personality_context,
    )

    agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None

    with timer("Chat actor: Generate better diagram description", logger):
        model_reply = await send_message_to_model_wrapper(
            description_prompt,
            query_images=query_images,
            user=user,
            query_files=query_files,
            agent_chat_model=agent_chat_model,
            tracer=tracer,
        )

    description = model_reply.text.strip()
    quote_chars = ('"', "'")
    is_quoted = description.startswith(quote_chars) and description.endswith(quote_chars)
    return description[1:-1] if is_quoted else description
async def generate_excalidraw_diagram_from_description(
    q: str,
    user: KhojUser = None,
    agent: Agent = None,
    tracer: dict = {},
) -> Dict[str, Any]:
    """
    Ask the chat model for an Excalidraw diagram matching description `q`.

    Returns the parsed JSON reply. Raises AssertionError if the reply is not a
    JSON object with non-empty "elements" and "scratchpad" keys, or if
    "elements" is not a list whose first item is an object.
    """
    personality_context = ""
    if agent and agent.personality:
        personality_context = prompts.personality_context.format(personality=agent.personality)

    diagram_prompt = prompts.excalidraw_diagram_generation_prompt.format(
        personality_context=personality_context,
        query=q,
    )

    agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None

    with timer("Chat actor: Generate excalidraw diagram", logger):
        raw_response = await send_message_to_model_wrapper(
            query=diagram_prompt, user=user, agent_chat_model=agent_chat_model, tracer=tracer
        )
        raw_response_text = clean_json(raw_response.text)

    try:
        diagram: Dict[str, str] = json.loads(raw_response_text)
    except Exception:
        raise AssertionError(f"Invalid response for generating Excalidraw diagram: {raw_response_text}")

    # Expect response to have `elements` and `scratchpad` keys
    has_required_keys = (
        isinstance(diagram, Dict) and bool(diagram.get("elements")) and bool(diagram.get("scratchpad"))
    )
    if not has_required_keys:
        raise AssertionError(f"Invalid response for generating Excalidraw diagram: {raw_response_text}")

    elements = diagram["elements"]
    # TODO Some additional validation here that it's a valid Excalidraw diagram
    if not isinstance(elements, List) or not isinstance(elements[0], Dict):
        raise AssertionError(f"Invalid response for improving diagram description: {diagram}")

    return diagram
async def generate_mermaidjs_diagram(
    q: str,
    chat_history: List[ChatMessageModel],
    location_data: LocationData,
    note_references: List[Dict[str, Any]],
    online_results: Optional[dict] = None,
    query_images: List[str] = None,
    user: KhojUser = None,
    agent: Agent = None,
    send_status_func: Optional[Callable] = None,
    query_files: str = None,
    tracer: dict = {},
):
    """
    Generate a Mermaid.js diagram for the query.

    Async generator. Asks the model for an improved diagram description, then
    for the Mermaid.js source. Yields `{ChatEvent.STATUS: ...}` events when
    `send_status_func` is provided, then a final `(inferred_queries, diagram)` tuple.
    """

    async def relay_status(message: str):
        # Forward status events from send_status_func, if one was provided
        if not send_status_func:
            return
        async for status_event in send_status_func(message):
            yield {ChatEvent.STATUS: status_event}

    async for status_update in relay_status("**Enhancing the Diagramming Prompt**"):
        yield status_update

    diagram_description = await generate_better_mermaidjs_diagram_description(
        q=q,
        chat_history=chat_history,
        location_data=location_data,
        note_references=note_references,
        online_results=online_results,
        query_images=query_images,
        user=user,
        agent=agent,
        query_files=query_files,
        tracer=tracer,
    )

    async for status_update in relay_status(f"**Diagram to Create:**:\n{diagram_description}"):
        yield status_update

    mermaidjs_diagram = await generate_mermaidjs_diagram_from_description(
        q=diagram_description,
        user=user,
        agent=agent,
        tracer=tracer,
    )

    yield f"Instruction: {diagram_description}", mermaidjs_diagram
async def generate_better_mermaidjs_diagram_description(
    q: str,
    chat_history: List[ChatMessageModel],
    location_data: LocationData,
    note_references: List[Dict[str, Any]],
    online_results: Optional[dict] = None,
    query_images: List[str] = None,
    user: KhojUser = None,
    agent: Agent = None,
    query_files: str = None,
    tracer: dict = {},
) -> str:
    """
    Generate a diagram description from the given query and context

    Builds the Mermaid.js description prompt from the query, chat history,
    location, today's UTC date, compiled note references and a condensed view of
    online results. Returns the model's reply, stripped of whitespace and of one
    leading/trailing quote character when both ends are quotes.
    """
    today_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d, %A")

    personality_context = ""
    if agent and agent.personality:
        personality_context = prompts.personality_context.format(personality=agent.personality)

    location = f"{location_data}" if location_data else "Unknown"
    user_references = "\n\n".join(f"# {item['compiled']}" for item in note_references)
    chat_history_str = construct_chat_history(chat_history)

    # Keep just the answer box, or else the webpages, from each online result
    simplified_online_results = {}
    for result_key, result_value in (online_results or {}).items():
        if result_value.get("answerBox"):
            simplified_online_results[result_key] = result_value["answerBox"]
        elif result_value.get("webpages"):
            simplified_online_results[result_key] = result_value["webpages"]

    description_prompt = prompts.improve_mermaid_js_diagram_description_prompt.format(
        query=q,
        chat_history=chat_history_str,
        location=location,
        current_date=today_date,
        references=user_references,
        online_results=simplified_online_results,
        personality_context=personality_context,
    )

    agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None

    with timer("Chat actor: Generate better Mermaid.js diagram description", logger):
        model_reply = await send_message_to_model_wrapper(
            description_prompt,
            query_images=query_images,
            user=user,
            query_files=query_files,
            agent_chat_model=agent_chat_model,
            tracer=tracer,
        )

    description = model_reply.text.strip()
    quote_chars = ('"', "'")
    is_quoted = description.startswith(quote_chars) and description.endswith(quote_chars)
    return description[1:-1] if is_quoted else description
async def generate_mermaidjs_diagram_from_description(
    q: str,
    user: KhojUser = None,
    agent: Agent = None,
    tracer: dict = {},
) -> str:
    """
    Ask the chat model to turn a diagram description into Mermaid.js code.

    Args:
        q: Description of the diagram to generate.
        user: User making the request.
        agent: Agent whose personality and chat model should be used, if any.
        tracer: Tracing context passed through to the model call.

    Returns:
        The model output, whitespace-stripped and passed through `clean_mermaidjs`.
    """
    if agent and agent.personality:
        personality_context = prompts.personality_context.format(personality=agent.personality)
    else:
        personality_context = ""
    diagram_prompt = prompts.mermaid_js_diagram_generation_prompt.format(
        personality_context=personality_context,
        query=q,
    )
    # Prefer the agent's own chat model when an agent is in use
    chat_model_for_agent = None
    if agent:
        chat_model_for_agent = AgentAdapters.get_agent_chat_model(agent, user)
    with timer("Chat actor: Generate Mermaid.js diagram", logger):
        model_response = await send_message_to_model_wrapper(
            query=diagram_prompt,
            user=user,
            agent_chat_model=chat_model_for_agent,
            tracer=tracer,
        )
    diagram_code = model_response.text.strip()
    return clean_mermaidjs(diagram_code)
async def generate_better_image_prompt(
    q: str,
    conversation_history: str,
    location_data: LocationData,
    note_references: List[Dict[str, Any]],
    online_results: Optional[dict] = None,
    model_type: Optional[str] = None,
    query_images: Optional[List[str]] = None,
    user: KhojUser = None,
    agent: Agent = None,
    query_files: str = "",
    tracer: dict = {},
) -> str:
    """
    Generate a better image prompt from the given query

    Args:
        q: User query describing the image to generate.
        conversation_history: Prior conversation, already rendered as text.
        location_data: User location. Rendered as "Unknown" when missing.
        note_references: Document search results. Each item's "compiled" text is added to the prompt.
        online_results: Online search results keyed by query. Only each result's "answerBox"
            (preferred) or "webpages" entry is passed to the model.
        model_type: Text to image model type. Defaults to OpenAI. OpenAI uses the DALL-E prompt
            template. All other types, including unrecognised ones, use the generic template.
        query_images: Images attached to the user query.
        user: User making the request.
        agent: Agent whose personality and chat model should be used, if any.
        query_files: Contents of files attached to the user query.
        tracer: Tracing context passed through to the model call.

    Returns:
        The improved image prompt, whitespace-stripped, with one pair of matching surrounding
        quotes removed if present.
    """
    today_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d, %A")
    personality_context = (
        prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
    )
    model_type = model_type or TextToImageModelConfig.ModelType.OPENAI
    location = f"{location_data}" if location_data else "Unknown"
    user_references = "\n\n".join([f"# {item['compiled']}" for item in note_references])
    # Keep only the answer box, or else the webpages, of each online result
    simplified_online_results = {}
    for result_query, result in (online_results or {}).items():
        if result.get("answerBox"):
            simplified_online_results[result_query] = result["answerBox"]
        elif result.get("webpages"):
            simplified_online_results[result_query] = result["webpages"]
    if model_type == TextToImageModelConfig.ModelType.OPENAI:
        image_prompt_template = prompts.image_generation_improve_prompt_dalle
    else:
        # Previously an unknown model type left the prompt unbound and crashed with
        # UnboundLocalError. Fall back to the generic template used by the other providers.
        if model_type not in [
            TextToImageModelConfig.ModelType.STABILITYAI,
            TextToImageModelConfig.ModelType.REPLICATE,
            TextToImageModelConfig.ModelType.GOOGLE,
        ]:
            logger.warning(f"Unknown text to image model type: {model_type}. Using generic image prompt template.")
        image_prompt_template = prompts.image_generation_improve_prompt_sd
    image_prompt = image_prompt_template.format(
        query=q,
        chat_history=conversation_history,
        location=location,
        current_date=today_date,
        references=user_references,
        online_results=simplified_online_results,
        personality_context=personality_context,
    )
    agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
    with timer("Chat actor: Generate contextual image prompt", logger):
        response = await send_message_to_model_wrapper(
            image_prompt,
            query_images=query_images,
            user=user,
            query_files=query_files,
            agent_chat_model=agent_chat_model,
            tracer=tracer,
        )
    response_text = response.text.strip()
    # Strip only a matching pair of surrounding quotes
    if len(response_text) >= 2 and response_text[0] == response_text[-1] and response_text[0] in ('"', "'"):
        response_text = response_text[1:-1]
    return response_text
async def search_documents(
    q: str,
    n: int,
    d: float,
    user: KhojUser,
    chat_history: list[ChatMessageModel],
    conversation_id: str,
    conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
    location_data: LocationData = None,
    send_status_func: Optional[Callable] = None,
    query_images: Optional[List[str]] = None,
    previous_inferred_queries: Set = set(),
    agent: Agent = None,
    query_files: str = None,
    tracer: dict = {},
):
    """
    Search the user's and/or agent's documents for references relevant to the user message.

    This is an async generator. When `send_status_func` is given, it may first yield status
    events as `{ChatEvent.STATUS: event}`. It finally yields one tuple of
    `(compiled_references, inferred_queries, query)`.

    Args:
        q: User message. May contain search filters, which are applied to every search query.
        n: Max number of results to retrieve per inferred search query.
        d: Max distance of search results to include.
        user: User whose documents to search.
        chat_history: Prior conversation turns, used to infer search queries.
        conversation_id: Id of the conversation. Its file filters are added to each search.
        conversation_commands: Active commands. If neither Notes nor Default is present, only the
            agent's knowledge base is searched.
        location_data: User location, used to infer search queries.
        send_status_func: Optional async generator function used to emit status updates.
        query_images: Images attached to the user message.
        previous_inferred_queries: Queries already run, which are skipped.
        agent: Agent whose knowledge base to search, if any.
        query_files: Contents of files attached to the user message.
        tracer: Tracing context passed through to model calls.
    """
    # Initialize Variables
    compiled_references: List[dict[str, str]] = []
    inferred_queries: List[str] = []
    agent_has_entries = False
    if agent:
        agent_has_entries = await sync_to_async(EntryAdapters.agent_has_entries)(agent=agent)
    # If Notes or Default is not in the conversation command, then the search should be restricted to the agent's knowledge base
    should_limit_to_agent_knowledge = (
        ConversationCommand.Notes not in conversation_commands
        and ConversationCommand.Default not in conversation_commands
    )
    if should_limit_to_agent_knowledge and not agent_has_entries:
        yield compiled_references, inferred_queries, q
        return
    if not await sync_to_async(EntryAdapters.user_has_entries)(user=user):
        if not agent_has_entries:
            logger.debug("No documents in knowledge base. Use a Khoj client to sync and chat with your docs.")
            yield compiled_references, inferred_queries, q
            return
    # Extract filter terms from user message
    defiltered_query = defilter_query(q)
    filters_in_query = q.replace(defiltered_query, "").strip()
    conversation = await sync_to_async(ConversationAdapters.get_conversation_by_id)(conversation_id)
    if not conversation:
        logger.error(f"Conversation with id {conversation_id} not found when extracting references.")
        yield compiled_references, inferred_queries, defiltered_query
        return
    # Add conversation file filters. Join with a space so they don't fuse onto the query filters.
    file_filters = " ".join(f'file:"{file_filter}"' for file_filter in (conversation.file_filters or []))
    filters_in_query = " ".join(part for part in (filters_in_query, file_filters) if part)
    if not is_none_or_empty(filters_in_query):
        logger.debug(f"Filters in query: {filters_in_query}")
    personality_context = (
        prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
    )
    # Infer search queries from user message
    with timer("Extracting search queries took", logger):
        inferred_queries = await extract_questions(
            query=defiltered_query,
            user=user,
            personality_context=personality_context,
            chat_history=chat_history,
            location_data=location_data,
            query_images=query_images,
            query_files=query_files,
            tracer=tracer,
        )
    # Drop duplicate and already-run queries, keeping the order the model suggested them in
    inferred_queries = [
        inferred_query
        for inferred_query in dict.fromkeys(inferred_queries)
        if inferred_query not in previous_inferred_queries
    ]
    # Collate search results as context for the LLM
    with timer("Searching knowledge base took", logger):
        search_results = []
        logger.info(f"🔍 Searching knowledge base with queries: {inferred_queries}")
        if send_status_func and inferred_queries:
            inferred_queries_str = "\n- " + "\n- ".join(inferred_queries)
            async for event in send_status_func(f"**Searching Documents for:** {inferred_queries_str}"):
                yield {ChatEvent.STATUS: event}
        for query in inferred_queries:
            results = await execute_search(
                user if not should_limit_to_agent_knowledge else None,
                f"{query} {filters_in_query}",
                n=n,
                t=SearchType.All,
                r=True,
                max_distance=d,
                dedupe=False,
                agent=agent,
            )
            # Attach associated query to each search result
            for item in results:
                item.additional["query"] = query
                search_results.append(item)
        search_results = text_search.deduplicated_search_responses(search_results)
        compiled_references = [
            {
                "query": item.additional["query"],
                "compiled": item["entry"],
                "file": item.additional["file"],
                "uri": item.additional["uri"],
            }
            for item in search_results
        ]
    yield compiled_references, inferred_queries, defiltered_query
async def extract_questions(
    query: str,
    user: KhojUser,
    personality_context: str = "",
    chat_history: List[ChatMessageModel] = [],
    location_data: LocationData = None,
    query_images: Optional[List[str]] = None,
    query_files: str = None,
    max_queries: int = 5,
    tracer: dict = {},
):
    """
    Infer document search queries from user message and provided context

    Args:
        query: User message to infer search queries from.
        user: User making the request. Their full name, if set, is added to the prompt.
        personality_context: Agent personality prompt snippet, if any.
        chat_history: Prior conversation turns to consider.
        location_data: User location. Rendered as "N/A" when missing.
        query_images: Images attached to the user message.
        query_files: Contents of files attached to the user message.
        max_queries: Max number of search queries to return.
        tracer: Tracing context passed through to the model call.

    Returns:
        A non-empty list of at most `max_queries` stripped search queries. Falls back to
        `[query]` if the model response is not a valid list of queries.
    """
    # Shared context setup
    location = f"{location_data}" if location_data else "N/A"
    username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
    # Date variables for prompt formatting
    today = datetime.today()
    current_new_year = today.replace(month=1, day=1)
    last_new_year = current_new_year.replace(year=today.year - 1)
    yesterday = (today - timedelta(days=1)).strftime("%Y-%m-%d")
    # Common prompt setup for API-based models (using Anthropic prompts for consistency)
    chat_history_str = construct_question_history(chat_history, query_prefix="User", agent_name="Assistant")
    system_prompt = prompts.extract_questions_system_prompt.format(
        current_date=today.strftime("%Y-%m-%d"),
        day_of_week=today.strftime("%A"),
        current_month=today.strftime("%Y-%m"),
        last_new_year=last_new_year.strftime("%Y"),
        last_new_year_date=last_new_year.strftime("%Y-%m-%d"),
        current_new_year_date=current_new_year.strftime("%Y-%m-%d"),
        yesterday_date=yesterday,
        location=location,
        username=username,
        personality_context=personality_context,
        max_queries=max_queries,
    )
    prompt = prompts.extract_questions_user_message.format(text=query, chat_history=chat_history_str)

    class DocumentQueries(BaseModel):
        """Choose semantic search queries to run on user documents."""

        queries: List[str] = Field(
            ...,
            min_length=1,
            max_length=max_queries,
            description="List of semantic search queries to run on user documents.",
        )

    raw_response = await send_message_to_model_wrapper(
        system_message=system_prompt,
        query=prompt,
        query_images=query_images,
        query_files=query_files,
        response_type="json_object",
        response_schema=DocumentQueries,
        user=user,
        tracer=tracer,
    )
    # Extract questions from the response
    try:
        response = pyjson5.loads(clean_json(raw_response.text))
        raw_queries = response["queries"]
    except Exception:
        logger.warning("LLM returned invalid JSON. Falling back to using user message as search query.")
        return [query]
    # Validate the type before iterating. A bare string would otherwise be split into characters.
    if not isinstance(raw_queries, list):
        logger.error(f"Invalid response for constructing subqueries: {response}")
        return [query]
    queries = [subquery.strip() for subquery in raw_queries if isinstance(subquery, str) and subquery.strip()]
    if not queries:
        logger.error(f"Invalid response for constructing subqueries: {response}")
        return [query]
    # Enforce the limit here too, since not every model honors the response schema's max length
    return queries[:max_queries]
async def execute_search(
    user: KhojUser,
    q: str,
    n: Optional[int] = 5,
    t: Optional[SearchType] = None,
    r: Optional[bool] = False,
    max_distance: Optional[Union[float, None]] = None,
    dedupe: Optional[bool] = True,
    agent: Optional[Agent] = None,
):
    """
    Run a search over the indexed content of the user and/or agent.

    Args:
        user: User whose content to search. Results are cached per user only when a user is given.
        q: Search query. May contain date, word and file filters.
        n: Max number of results to return. Defaults to 5.
        t: Content type to search. Defaults to all types.
        r: Whether to rerank results.
        max_distance: Max distance of results to include.
        dedupe: Whether to deduplicate results when collating them.
        agent: Agent whose knowledge base to search, if any.

    Returns:
        A list of search responses, empty if the query is blank or the agent is not accessible.
    """
    # Run validation checks
    results: List[SearchResponse] = []
    start_time = time.time()
    # Ensure the agent, if present, is accessible by the user
    if user and agent and not await AgentAdapters.ais_agent_accessible(agent, user):
        logger.error(f"Agent {agent.slug} is not accessible by user {user}")
        return results
    if not q or not q.strip():
        logger.warning("No query param (q) passed in API call to initiate search")
        return results
    # initialize variables
    user_query = q.strip()
    results_count = n or 5
    t = t or state.SearchType.All
    search_tasks = []
    # return cached results, if available.
    # The key includes the agent, since the agent changes which content is searched.
    if user:
        agent_key = agent.slug if agent else None
        query_cache_key = f"{user_query}-{n}-{t}-{r}-{max_distance}-{dedupe}-{agent_key}"
        if query_cache_key in state.query_cache[user.uuid]:
            logger.debug("Return response from query cache")
            return state.query_cache[user.uuid][query_cache_key]
    # Encode query with filter terms removed
    defiltered_query = user_query
    for query_filter in [DateFilter(), WordFilter(), FileFilter()]:
        defiltered_query = query_filter.defilter(defiltered_query)
    encoded_asymmetric_query = None
    search_model = None
    if t.value != SearchType.Image.value:
        with timer("Encoding query took", logger=logger):
            search_model = await sync_to_async(get_default_search_model)()
            encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query)
    # Use asyncio to run searches in parallel
    if t.value in [
        SearchType.All.value,
        SearchType.Org.value,
        SearchType.Markdown.value,
        SearchType.Github.value,
        SearchType.Notion.value,
        SearchType.Plaintext.value,
        SearchType.Pdf.value,
    ]:
        # query markdown notes
        search_tasks.append(
            text_search.query(
                user_query,
                user,
                t,
                question_embedding=encoded_asymmetric_query,
                max_distance=max_distance,
                agent=agent,
            )
        )
    # Query across each requested content types in parallel
    with timer("Query took", logger):
        if search_tasks:
            hits_list = await asyncio.gather(*search_tasks)
            for hits in hits_list:
                # Collate results
                results += text_search.collate_results(hits, dedupe=dedupe)
            # Sort results across all content types and take top results.
            # Text search tasks are only added for non-image types, so search_model is set here.
            results = text_search.rerank_and_sort_results(
                results, query=defiltered_query, rank_results=r, search_model_name=search_model.name
            )[:results_count]
    # Cache results
    if user:
        state.query_cache[user.uuid][query_cache_key] = results
    end_time = time.time()
    logger.debug(f"🔍 Search took: {end_time - start_time:.3f} seconds")
    return results
async def send_message_to_model_wrapper(
    query: str,
    system_message: str = "",
    response_type: str = "text",
    response_schema: BaseModel = None,
    tools: List[ToolDefinition] = None,
    deepthought: bool = False,
    user: KhojUser = None,
    query_images: List[str] = None,
    context: str = "",
    query_files: str = None,
    chat_history: list[ChatMessageModel] = [],
    agent_chat_model: ChatModel = None,
    tracer: dict = {},
):
    """
    Send a message to the appropriate chat model and return its response.

    Uses `agent_chat_model` or the user's default chat model. If images are attached and that
    model lacks vision, it switches to the configured vision-enabled model when one exists.

    Raises:
        HTTPException: With status 500 if no chat model is configured or its model type is unsupported.
    """
    chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user, agent_chat_model)
    # Match the sync variant: fail with a clear error instead of an AttributeError on None
    if chat_model is None:
        raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
    vision_available = chat_model.vision_enabled
    if not vision_available and query_images:
        logger.warning(f"Vision is not enabled for default model: {chat_model.name}.")
        vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
        if vision_enabled_config:
            chat_model = vision_enabled_config
            vision_available = True
    if vision_available and query_images:
        logger.info(f"Using {chat_model.name} model to understand {len(query_images)} images.")
    max_tokens = await ConversationAdapters.aget_max_context_size(chat_model, user)
    chat_model_name = chat_model.name
    tokenizer = chat_model.tokenizer
    model_type = chat_model.model_type
    vision_available = chat_model.vision_enabled
    api_key = chat_model.ai_model_api.api_key
    api_base_url = chat_model.ai_model_api.api_base_url
    truncated_messages = generate_chatml_messages_with_context(
        user_message=query,
        context_message=context,
        system_message=system_message,
        chat_history=chat_history,
        model_name=chat_model_name,
        tokenizer_name=tokenizer,
        max_prompt_size=max_tokens,
        vision_enabled=vision_available,
        query_images=query_images,
        model_type=model_type,
        query_files=query_files,
    )
    if model_type == ChatModel.ModelType.OPENAI:
        return send_message_to_model(
            messages=truncated_messages,
            api_key=api_key,
            model=chat_model_name,
            response_type=response_type,
            response_schema=response_schema,
            tools=tools,
            deepthought=deepthought,
            api_base_url=api_base_url,
            tracer=tracer,
        )
    elif model_type == ChatModel.ModelType.ANTHROPIC:
        return anthropic_send_message_to_model(
            messages=truncated_messages,
            api_key=api_key,
            model=chat_model_name,
            response_type=response_type,
            response_schema=response_schema,
            tools=tools,
            deepthought=deepthought,
            api_base_url=api_base_url,
            tracer=tracer,
        )
    elif model_type == ChatModel.ModelType.GOOGLE:
        return gemini_send_message_to_model(
            messages=truncated_messages,
            api_key=api_key,
            model=chat_model_name,
            response_type=response_type,
            response_schema=response_schema,
            tools=tools,
            deepthought=deepthought,
            api_base_url=api_base_url,
            tracer=tracer,
        )
    else:
        raise HTTPException(status_code=500, detail="Invalid conversation config")
def send_message_to_model_wrapper_sync(
    message: str,
    system_message: str = "",
    response_type: str = "text",
    response_schema: BaseModel = None,
    user: KhojUser = None,
    query_images: List[str] = None,
    query_files: str = "",
    chat_history: List[ChatMessageModel] = [],
    tracer: dict = {},
):
    """
    Synchronously send a message to the user's default chat model and return its response.

    Raises:
        HTTPException: With status 500 if no default chat model is configured or its model type is unsupported.
    """
    chat_model: ChatModel = ConversationAdapters.get_default_chat_model(user)
    if chat_model is None:
        raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
    max_tokens = ConversationAdapters.get_max_context_size(chat_model, user)
    chat_model_name = chat_model.name
    tokenizer = chat_model.tokenizer
    model_type = chat_model.model_type
    vision_available = chat_model.vision_enabled
    api_key = chat_model.ai_model_api.api_key
    api_base_url = chat_model.ai_model_api.api_base_url
    truncated_messages = generate_chatml_messages_with_context(
        user_message=message,
        system_message=system_message,
        chat_history=chat_history,
        model_name=chat_model_name,
        # Pass the model's tokenizer, as the async wrapper does, so truncation counts tokens correctly
        tokenizer_name=tokenizer,
        max_prompt_size=max_tokens,
        vision_enabled=vision_available,
        model_type=model_type,
        query_images=query_images,
        query_files=query_files,
    )
    if model_type == ChatModel.ModelType.OPENAI:
        return send_message_to_model(
            messages=truncated_messages,
            api_key=api_key,
            api_base_url=api_base_url,
            model=chat_model_name,
            response_type=response_type,
            response_schema=response_schema,
            tracer=tracer,
        )
    elif model_type == ChatModel.ModelType.ANTHROPIC:
        return anthropic_send_message_to_model(
            messages=truncated_messages,
            api_key=api_key,
            api_base_url=api_base_url,
            model=chat_model_name,
            response_type=response_type,
            # Previously dropped for Anthropic only, so structured responses were not enforced
            response_schema=response_schema,
            tracer=tracer,
        )
    elif model_type == ChatModel.ModelType.GOOGLE:
        return gemini_send_message_to_model(
            messages=truncated_messages,
            api_key=api_key,
            api_base_url=api_base_url,
            model=chat_model_name,
            response_type=response_type,
            response_schema=response_schema,
            tracer=tracer,
        )
    else:
        raise HTTPException(status_code=500, detail="Invalid conversation config")
async def agenerate_chat_response(
    q: str,
    chat_history: List[ChatMessageModel],
    conversation: Conversation,
    compiled_references: List[Dict] = [],
    online_results: Dict[str, Dict] = {},
    code_results: Dict[str, Dict] = {},
    operator_results: List[OperatorRun] = [],
    research_results: List[ResearchIteration] = [],
    user: KhojUser = None,
    location_data: LocationData = None,
    user_name: Optional[str] = None,
    query_images: Optional[List[str]] = None,
    query_files: str = None,
    raw_generated_files: List[FileAttachment] = [],
    program_execution_context: List[str] = [],
    generated_asset_results: Dict[str, Dict] = {},
    is_subscribed: bool = False,
    tracer: dict = {},
) -> Tuple[AsyncGenerator[ResponseWithThought, None], Dict[str, str]]:
    """
    Create a streaming chat response generator using the conversation's chat model.

    If research results are given:
    - their summaries, if any, are wrapped with the query;
    - the separate references, online, code and operator results are dropped;
    - deep thought is enabled.

    Returns:
        A tuple of the chat response async generator and metadata with the chat model name.

    Raises:
        HTTPException: With status 500 if the generator could not be created, e.g. for an
            unsupported chat model type.
    """
    # Initialize Variables
    chat_response_generator: AsyncGenerator[ResponseWithThought, None] = None
    metadata = {}
    agent = await AgentAdapters.aget_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None
    try:
        query_to_run = q
        deepthought = False
        if research_results:
            compiled_research = "".join([r.summarizedResult for r in research_results if r.summarizedResult])
            if compiled_research:
                query_to_run = f"<query>{q}</query>\n<collected_research>\n{compiled_research}\n</collected_research>"
            compiled_references = []
            online_results = {}
            code_results = {}
            operator_results = []
            deepthought = True
        chat_model = await ConversationAdapters.aget_valid_chat_model(user, conversation, is_subscribed)
        vision_available = chat_model.vision_enabled
        if not vision_available and query_images:
            vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
            if vision_enabled_config:
                chat_model = vision_enabled_config
                vision_available = True
        # All providers share the same converse interface. Only the function differs.
        if chat_model.model_type == ChatModel.ModelType.OPENAI:
            converse_fn = converse_openai
        elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
            converse_fn = converse_anthropic
        elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
            converse_fn = converse_gemini
        else:
            # Previously this silently returned a None generator, which callers could not iterate
            raise ValueError(f"Unsupported chat model type: {chat_model.model_type}")
        chat_response_generator = converse_fn(
            # Query
            query_to_run,
            # Context
            references=compiled_references,
            online_results=online_results,
            code_results=code_results,
            operator_results=operator_results,
            query_images=query_images,
            query_files=query_files,
            generated_files=raw_generated_files,
            generated_asset_results=generated_asset_results,
            program_execution_context=program_execution_context,
            location_data=location_data,
            user_name=user_name,
            chat_history=chat_history,
            # Model
            model=chat_model.name,
            api_key=chat_model.ai_model_api.api_key,
            api_base_url=chat_model.ai_model_api.api_base_url,
            max_prompt_size=chat_model.max_prompt_size,
            tokenizer_name=chat_model.tokenizer,
            agent=agent,
            vision_available=vision_available,
            deepthought=deepthought,
            tracer=tracer,
        )
        metadata.update({"chat_model": chat_model.name})
    except Exception as e:
        logger.error(e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    # Return the generator directly
    return chat_response_generator, metadata
class DeleteMessageRequestBody(BaseModel):
    # Request body for deleting a chat message (per the class name). A docstring is avoided
    # here because pydantic would publish it as the JSON schema description.
    conversation_id: str  # id of the conversation holding the message
    turn_id: str  # presumably identifies the chat turn to delete — TODO confirm against the endpoint
class FeedbackData(BaseModel):
    # Feedback payload. Field meanings are inferred from names only — confirm against the feedback endpoint.
    uquery: str  # presumably the user's query
    kquery: str  # presumably Khoj's response to that query
    sentiment: str  # feedback sentiment; allowed values are not visible here
class MagicLinkForm(BaseModel):
    # Form carrying the email address for a magic link login. EmailAttemptRateLimiter
    # rate limits attempts per `email`.
    email: EmailStr  # validated by pydantic as an email address
class EmailAttemptRateLimiter:
    """Rate limiter for email attempts BEFORE get/create user with valid email address."""

    def __init__(self, requests: int, window: int, slug: str):
        self.requests = requests  # max attempts allowed per email address within the window
        self.window = window  # Window in seconds
        self.slug = slug  # groups this limiter's RateLimitRecord rows

    async def __call__(self, form: MagicLinkForm):
        """
        Record an attempt for `form.email`, raising HTTP 429 once the limit is reached.

        The address is lowercased before counting. Otherwise case variants of the same
        mailbox (e.g. "Foo@x.com" vs "foo@x.com") would each get their own limit.
        """
        # Disable login rate limiting in debug mode
        if in_debug_mode():
            return
        identifier = form.email.lower()
        # Calculate the time window cutoff
        cutoff = django_timezone.now() - timedelta(seconds=self.window)
        # Count recent attempts for this email and slug
        count = await RateLimitRecord.objects.filter(
            identifier=identifier, slug=self.slug, created_at__gte=cutoff
        ).acount()
        if count >= self.requests:
            logger.warning(f"Email attempt rate limit exceeded for {form.email} (slug: {self.slug})")
            raise HTTPException(
                status_code=429, detail="Too many requests for your email address. Please wait before trying again."
            )
        # Record the current attempt
        await RateLimitRecord.objects.acreate(identifier=identifier, slug=self.slug)
class EmailVerificationApiRateLimiter:
    """
    Rate limiter for actions AFTER user with valid email address is known to exist.

    Each allowed call is stored as a `UserRequests` row tagged with `slug`. Once `requests`
    rows exist within the last `window` seconds, further calls are rejected with HTTP 429.
    """

    def __init__(self, requests: int, window: int, slug: str):
        self.requests = requests  # max allowed calls per window
        self.window = window  # window length, in seconds
        self.slug = slug  # tag grouping this limiter's UserRequests rows

    async def __call__(self, email: str = None):
        # Login rate limiting is turned off in debug mode
        if in_debug_mode():
            return
        user: KhojUser = await aget_user_by_email(email)
        if not user:
            raise HTTPException(status_code=404, detail="User not found.")
        # Only requests made inside the current window count towards the limit
        window_start = django_timezone.now() - timedelta(seconds=self.window)
        recent_requests = UserRequests.objects.filter(user=user, created_at__gte=window_start, slug=self.slug)
        count_requests = await recent_requests.acount()
        if count_requests < self.requests:
            # Under the limit: record this request so it counts towards the window
            await UserRequests.objects.acreate(user=user, slug=self.slug)
            return
        logger.warning(
            f"Rate limit: {count_requests}/{self.requests} requests not allowed in {self.window} seconds for email: {email}."
        )
        raise HTTPException(status_code=429, detail="Ran out of login attempts. Please wait before trying again.")
class ApiUserRateLimiter:
    """
    Per-user request rate limiter for HTTP and WebSocket endpoints.

    Counts `UserRequests` rows tagged with `slug` created in the last `window` seconds. Once
    the count reaches the user's limit (`subscribed_requests` for premium users, `requests`
    otherwise), it rejects with HTTP 429. Otherwise it records the request.
    """

    def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str):
        self.requests = requests  # limit for non-premium users
        self.subscribed_requests = subscribed_requests  # limit for premium users
        self.window = window  # window length, in seconds
        self.slug = slug

    def _window_phrases(self) -> Tuple[str, str]:
        """Return the (current window, next window) wording used in rate limit messages."""
        if self.window == 60 * 60 * 24:
            return "today", "tomorrow"
        return "now", "in a bit"

    def __call__(self, request: Request):
        # Rate limiting disabled if billing is disabled
        if state.billing_enabled is False:
            return
        # Rate limiting is disabled if user unauthenticated.
        # Other systems handle authentication
        if not request.user.is_authenticated:
            return
        user: KhojUser = request.user.object
        subscribed = has_required_scope(request, ["premium"])
        # Describe the actual window instead of always saying "today"/"tomorrow".
        # Messages are unchanged for the 24 hour window.
        current_window, next_window = self._window_phrases()
        common_message_prefix = f"I'm glad you're enjoying interacting with me! You've unfortunately exceeded your usage limit for {current_window}."
        # Remove requests outside of the time window
        cutoff = django_timezone.now() - timedelta(seconds=self.window)
        count_requests = UserRequests.objects.filter(user=user, created_at__gte=cutoff, slug=self.slug).count()
        # Check if the user has exceeded the rate limit
        if subscribed and count_requests >= self.subscribed_requests:
            logger.info(
                f"Rate limit ({self.slug}): {count_requests}/{self.subscribed_requests} requests not allowed in {self.window} seconds for subscribed user: {user}."
            )
            raise HTTPException(
                status_code=429,
                detail=f"{common_message_prefix} But let's chat more {next_window}?",
            )
        if not subscribed and count_requests >= self.requests:
            # Subscribing would not raise the limit, so don't suggest it
            if self.requests >= self.subscribed_requests:
                logger.info(
                    f"Rate limit ({self.slug}): {count_requests}/{self.requests} requests not allowed in {self.window} seconds for user: {user}."
                )
                raise HTTPException(
                    status_code=429,
                    detail=f"{common_message_prefix} But let's chat more {next_window}?",
                )
            logger.info(
                f"Rate limit ({self.slug}): {count_requests}/{self.requests} requests not allowed in {self.window} seconds for user: {user}."
            )
            raise HTTPException(
                status_code=429,
                detail=f"{common_message_prefix} You can subscribe to increase your usage limit via [your settings](https://app.khoj.dev/settings) or we can continue our conversation {next_window}?",
            )
        # Add the current request to the cache
        UserRequests.objects.create(user=user, slug=self.slug)

    async def check_websocket(self, websocket: WebSocket):
        """WebSocket-specific rate limiting method"""
        # Rate limiting disabled if billing is disabled
        if state.billing_enabled is False:
            return
        # Rate limiting is disabled if user unauthenticated.
        if not websocket.scope.get("user") or not websocket.scope["user"].is_authenticated:
            return
        user: KhojUser = websocket.scope["user"].object
        subscribed = has_required_scope(websocket, ["premium"])
        current_window, next_window = self._window_phrases()
        common_message_prefix = f"I'm glad you're enjoying interacting with me! You've unfortunately exceeded your usage limit for {current_window}."
        # Remove requests outside of the time window
        cutoff = django_timezone.now() - timedelta(seconds=self.window)
        count_requests = await UserRequests.objects.filter(user=user, created_at__gte=cutoff, slug=self.slug).acount()
        # Check if the user has exceeded the rate limit
        if subscribed and count_requests >= self.subscribed_requests:
            logger.info(
                f"Rate limit ({self.slug}): {count_requests}/{self.subscribed_requests} requests not allowed in {self.window} seconds for subscribed user: {user}."
            )
            raise HTTPException(
                status_code=429,
                detail=f"{common_message_prefix} But let's chat more {next_window}?",
            )
        if not subscribed and count_requests >= self.requests:
            # Subscribing would not raise the limit, so don't suggest it
            if self.requests >= self.subscribed_requests:
                logger.info(
                    f"Rate limit ({self.slug}): {count_requests}/{self.requests} requests not allowed in {self.window} seconds for user: {user}."
                )
                raise HTTPException(
                    status_code=429,
                    detail=f"{common_message_prefix} But let's chat more {next_window}?",
                )
            logger.info(
                f"Rate limit ({self.slug}): {count_requests}/{self.requests} requests not allowed in {self.window} seconds for user: {user}."
            )
            raise HTTPException(
                status_code=429,
                detail=f"{common_message_prefix} You can subscribe to increase your usage limit via [your settings](https://app.khoj.dev/settings) or we can continue our conversation {next_window}.",
            )
        # Add the current request to the cache
        await UserRequests.objects.acreate(user=user, slug=self.slug)
class ApiImageRateLimiter:
    """Limit the number and combined size of images attached to a single chat message."""

    def __init__(self, max_images: int = 10, max_combined_size_mb: float = 10):
        self.max_images = max_images
        self.max_combined_size_mb = max_combined_size_mb

    def __call__(self, request: Request, body: ChatRequestBody):
        """HTTP image rate limiting method"""
        if state.billing_enabled is False:
            return
        # Rate limiting is disabled if user unauthenticated.
        # Other systems handle authentication
        if not request.user.is_authenticated:
            return
        self._check_images(body.images)

    def check_websocket(self, websocket: WebSocket, body: ChatRequestBody):
        """WebSocket-specific image rate limiting method"""
        if state.billing_enabled is False:
            return
        # Rate limiting is disabled if user unauthenticated.
        if not websocket.scope.get("user") or not websocket.scope["user"].is_authenticated:
            return
        self._check_images(body.images)

    def _check_images(self, images: Optional[List[str]]) -> None:
        """
        Validate the images attached to a message. Shared by the HTTP and WebSocket checks.

        Raises:
            HTTPException: 429 if there are more than `max_images` images or their combined
                decoded size exceeds `max_combined_size_mb`. 400 if an image is not valid base64.
        """
        if not images:
            return
        # Check number of images
        if len(images) > self.max_images:
            logger.info(f"Rate limit: {len(images)}/{self.max_images} images not allowed per message.")
            raise HTTPException(
                status_code=429,
                detail=f"Those are way too many images for me! I can handle up to {self.max_images} images per message.",
            )
        # Check total size of images
        total_size_mb = 0.0
        for image in images:
            # Unquote the image in case it's URL encoded
            image = unquote(image)
            # Assuming the image is a base64 encoded string
            # Remove the data:image/jpeg;base64, part if present
            if "," in image:
                image = image.split(",", 1)[1]
            # Decode base64 to get the actual size
            try:
                image_bytes = base64.b64decode(image)
            except ValueError as err:
                # binascii.Error (malformed base64) subclasses ValueError. Report it as a client
                # error instead of letting it escape as an unhandled server error.
                raise HTTPException(
                    status_code=400, detail="Could not read the attached image. Expected base64 encoded image data."
                ) from err
            total_size_mb += len(image_bytes) / (1024 * 1024)  # Convert bytes to MB
        if total_size_mb > self.max_combined_size_mb:
            logger.info(f"Data limit: {total_size_mb}MB/{self.max_combined_size_mb}MB size not allowed per message.")
            raise HTTPException(
                status_code=429,
                detail=f"Those images are way too large for me! I can handle up to {self.max_combined_size_mb}MB of images per message.",
            )
class WebSocketConnectionManager:
    """
    Limit max open websockets per user.

    Each open connection is tracked as a `UserRequests` row whose slug is
    `connection_slug_prefix` followed by the connection id.
    """

    def __init__(self, trial_user_max_connections: int = 10, subscribed_user_max_connections: int = 10):
        self.trial_user_max_connections = trial_user_max_connections
        self.subscribed_user_max_connections = subscribed_user_max_connections
        self.connection_slug_prefix = "ws_connection_"
        # Records older than this many seconds are treated as stale (e.g., left behind by server crashes)
        self.cleanup_window = 86400  # 24 hours

    def _connection_slug(self, connection_id: str) -> str:
        """Build the UserRequests slug that tracks one connection."""
        return f"{self.connection_slug_prefix}{connection_id}"

    async def can_connect(self, websocket: WebSocket) -> bool:
        """Check if user can establish a new WebSocket connection."""
        user: KhojUser = websocket.scope["user"].object
        if has_required_scope(websocket, ["premium"]):
            max_connections = self.subscribed_user_max_connections
        else:
            max_connections = self.trial_user_max_connections
        # Drop very old records first so they don't count against the user
        await self._cleanup_stale_connections(user)
        # Count every remaining connection record for this user, regardless of age
        open_connections = UserRequests.objects.filter(user=user, slug__startswith=self.connection_slug_prefix)
        open_connection_count = await open_connections.acount()
        if open_connection_count < max_connections:
            return True
        # Over the limit: only allowed in anonymous or debug mode
        return state.anonymous_mode or in_debug_mode()

    async def register_connection(self, user: KhojUser, connection_id: str) -> None:
        """Register a new WebSocket connection."""
        await UserRequests.objects.acreate(user=user, slug=self._connection_slug(connection_id))

    async def unregister_connection(self, user: KhojUser, connection_id: str) -> None:
        """Remove a WebSocket connection record."""
        connection_records = UserRequests.objects.filter(user=user, slug=self._connection_slug(connection_id))
        await connection_records.adelete()

    async def _cleanup_stale_connections(self, user: KhojUser) -> None:
        """Remove connection records older than cleanup window."""
        stale_before = django_timezone.now() - timedelta(seconds=self.cleanup_window)
        stale_records = UserRequests.objects.filter(
            user=user,
            slug__startswith=self.connection_slug_prefix,
            created_at__lt=stale_before,
        )
        await stale_records.adelete()
class ConversationCommandRateLimiter:
    """
    Limit how often a user may invoke restricted conversation commands per 24 hours.

    Each allowed invocation is recorded as a ``UserRequests`` row with the slug
    ``<slug>_<command value>``; the rows inside the trailing 24 hour window are counted.
    """

    def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str):
        self.slug = slug
        self.trial_rate_limit = trial_rate_limit
        self.subscribed_rate_limit = subscribed_rate_limit
        # Only these commands are rate limited
        self.restricted_commands = [ConversationCommand.Research]

    async def update_and_check_if_valid(self, request: Request | WebSocket, conversation_command: ConversationCommand):
        """
        Record this invocation of a restricted command, or raise if the daily limit is reached.

        No-op when billing is disabled, the request is unauthenticated, or the command
        is not restricted.

        Raises:
            HTTPException: 429 when the user already reached their limit for the command.
        """
        if state.billing_enabled is False or not request.user.is_authenticated:
            return
        if conversation_command not in self.restricted_commands:
            return

        user: KhojUser = request.user.object
        subscribed = has_required_scope(request, ["premium"])

        # Only invocations inside the trailing 24 hour window count against the limit
        window_start = django_timezone.now() - timedelta(hours=24)
        command_slug = f"{self.slug}_{conversation_command.value}"
        recent_requests = await UserRequests.objects.filter(
            user=user, created_at__gte=window_start, slug=command_slug
        ).acount()

        if subscribed:
            if recent_requests >= self.subscribed_rate_limit:
                logger.info(
                    f"Rate limit: {recent_requests}/{self.subscribed_rate_limit} requests not allowed in 24 hours for subscribed user: {user}."
                )
                raise HTTPException(
                    status_code=429,
                    detail=f"I'm glad you're enjoying interacting with me! You've unfortunately exceeded your `/{conversation_command.value}` command usage limit for today. Maybe we can talk about something else for today?",
                )
        elif recent_requests >= self.trial_rate_limit:
            logger.info(
                f"Rate limit: {recent_requests}/{self.trial_rate_limit} requests not allowed in 24 hours for user: {user}."
            )
            raise HTTPException(
                status_code=429,
                detail=f"I'm glad you're enjoying interacting with me! You've unfortunately exceeded your `/{conversation_command.value}` command usage limit for today. You can subscribe to increase your usage limit via [your settings](https://app.khoj.dev/settings) or we can talk about something else for today?",
            )

        await UserRequests.objects.acreate(user=user, slug=command_slug)
class ApiIndexedDataLimiter:
    """
    Reject content indexing requests that would exceed the user's data size limits.

    All limits are in MB and differ for subscribed (premium) and non-subscribed users.
    One limit bounds the size of the files in a single request. The other bounds the
    already indexed data plus the incoming files.
    """

    def __init__(
        self,
        incoming_entries_size_limit: float,
        subscribed_incoming_entries_size_limit: float,
        total_entries_size_limit: float,
        subscribed_total_entries_size_limit: float,
    ):
        # Max size (MB) of the files uploaded in one request
        self.num_entries_size = incoming_entries_size_limit
        self.subscribed_num_entries_size = subscribed_incoming_entries_size_limit
        # Max size (MB) of already indexed data plus the incoming files
        self.total_entries_size_limit = total_entries_size_limit
        self.subscribed_total_entries_size = subscribed_total_entries_size_limit

    def __call__(self, request: Request, files: List[UploadFile] = None):
        """
        Validate the incoming files against the user's data size limits.

        Zero-sized uploaded files are treated as deletion requests. Their indexed entries
        are deleted before the limits are checked.

        Raises:
            HTTPException: 429 when the incoming files alone, or together with the user's
                existing indexed data, meet or exceed the applicable limit.
        """
        if state.billing_enabled is False:
            return

        subscribed = has_required_scope(request, ["premium"])
        incoming_data_size_mb = 0.0
        deletion_file_names = set()

        if not request.user.is_authenticated or not files:
            return

        user: KhojUser = request.user.object

        for file in files:
            if file.size == 0:
                deletion_file_names.add(file.filename)
            incoming_data_size_mb += file.size / 1024 / 1024

        num_deleted_entries = 0
        for file_path in deletion_file_names:
            deleted_count = EntryAdapters.delete_entry_by_file(user, file_path)
            num_deleted_entries += deleted_count

        logger.info(f"Deleted {num_deleted_entries} entries for user: {user}.")

        if subscribed and incoming_data_size_mb >= self.subscribed_num_entries_size:
            logger.info(
                f"Data limit: {incoming_data_size_mb}MB incoming will exceed {self.subscribed_num_entries_size}MB allowed for subscribed user: {user}."
            )
            raise HTTPException(status_code=429, detail="Too much data indexed.")
        if not subscribed and incoming_data_size_mb >= self.num_entries_size:
            logger.info(
                f"Data limit: {incoming_data_size_mb}MB incoming will exceed {self.num_entries_size}MB allowed for user: {user}."
            )
            raise HTTPException(
                status_code=429, detail="Too much data indexed. Subscribe to increase your data index limit."
            )

        user_size_data = EntryAdapters.get_size_of_indexed_data_in_mb(user)
        if subscribed and user_size_data + incoming_data_size_mb >= self.subscribed_total_entries_size:
            logger.info(
                f"Data limit: {incoming_data_size_mb}MB incoming + {user_size_data}MB existing will exceed {self.subscribed_total_entries_size}MB allowed for subscribed user: {user}."
            )
            raise HTTPException(status_code=429, detail="Too much data indexed.")
        if not subscribed and user_size_data + incoming_data_size_mb >= self.total_entries_size_limit:
            # Report the non-subscribed limit that was actually exceeded
            logger.info(
                f"Data limit: {incoming_data_size_mb}MB incoming + {user_size_data}MB existing will exceed {self.total_entries_size_limit}MB allowed for non subscribed user: {user}."
            )
            raise HTTPException(
                status_code=429, detail="Too much data indexed. Subscribe to increase your data index limit."
            )
class CommonQueryParamsClass:
    """
    FastAPI dependency collecting the request parameters that many API routes share.

    Attributes:
        client: The ``client`` request parameter, if provided.
        user_agent: Value of the user agent request header, if present.
        referer: Value of the referer request header, if present.
        host: Value of the host request header, if present.
    """

    def __init__(
        self,
        client: Optional[str] = None,
        user_agent: Optional[str] = Header(None),
        referer: Optional[str] = Header(None),
        host: Optional[str] = Header(None),
    ):
        # Values read from request headers
        self.host = host
        self.referer = referer
        self.user_agent = user_agent
        # Identifier of the calling client application
        self.client = client


# Annotated alias so route handlers can declare a `CommonQueryParams` parameter to receive the dependency
CommonQueryParams = Annotated[CommonQueryParamsClass, Depends()]
def format_automation_response(
    scheduling_request: str, executed_query: str, ai_response: str, user: KhojUser
) -> Optional[str]:
    """
    Format the AI response to send in automation email to user.

    Args:
        scheduling_request: The user's original request that created the automation.
        executed_query: The query that was run for this automation invocation.
        ai_response: The chat model's response to ``executed_query``.
        user: The automation's owner; their name, if known, is included in the prompt.

    Returns:
        The formatted response text, or None if the model returned no response.
    """
    name = get_user_name(user)
    username = prompts.user_name.format(name=name) if name else ""

    automation_format_prompt = prompts.automation_format_prompt.format(
        original_query=scheduling_request,
        executed_query=executed_query,
        response=ai_response,
        username=username,
    )

    with timer("Chat actor: Format automation response", logger):
        raw_response = send_message_to_model_wrapper_sync(automation_format_prompt, user=user)
        return raw_response.text if raw_response else None
def should_notify(original_query: str, executed_query: str, ai_response: str, user: KhojUser) -> bool:
    """
    Ask the chat model whether the user should be notified of an automation response.

    Returns False when any of the queries or the response is empty. When the model's
    decision cannot be obtained or parsed, returns True so the user still gets notified.
    """
    exchange = (original_query, executed_query, ai_response)
    if any(is_none_or_empty(part) for part in exchange):
        return False

    decision_prompt = prompts.to_notify_or_not.format(
        original_query=original_query,
        executed_query=executed_query,
        response=ai_response,
    )

    with timer("Chat actor: Decide to notify user of automation response", logger):
        try:
            # TODO Replace with async call so we don't have to maintain a sync version
            raw_response: ResponseWithThought = send_message_to_model_wrapper_sync(
                decision_prompt, user=user, response_type="json_object"
            )
            verdict = json.loads(clean_json(raw_response.text))
            notify = verdict["decision"] == "Yes"
            reason = verdict.get("reason", "unknown")
            negation = "" if notify else "not "
            logger.info(f"Decided to {negation}notify user of automation response because of reason: {reason}.")
            return notify
        except Exception as e:
            # Prefer over-notifying to silently dropping a result
            logger.warning(
                f"Fallback to notify user of automation response as failed to infer should notify or not. {e}",
                exc_info=True,
            )
            return True
def scheduled_chat(
    query_to_run: str,
    scheduling_request: str,
    subject: str,
    user: KhojUser,
    calling_url: str | URL,
    job_id: str = None,
    conversation_id: str = None,
):
    """
    Run a scheduled automation query through the chat API and deliver the result.

    Args:
        query_to_run: The query to send to the chat API.
        scheduling_request: The user's original request that created the automation.
        subject: Subject used for the notification email.
        user: The automation's owner.
        calling_url: URL of the request that created the automation. Its scheme, host and
            query parameters are reused to call the chat API.
        job_id: Scheduler job id, used to skip runs that happened too recently.
        conversation_id: Conversation to run the query in. The automation is deleted if
            this conversation no longer exists.

    Returns:
        The formatted response when the user should be notified and email delivery
        (Resend) is not enabled. Otherwise None.
    """
    logger.info(f"Processing scheduled_chat: {query_to_run}")
    if job_id:
        # Get the job object and check whether the time is valid for it to run. This helps avoid race conditions that cause the same job to be run multiple times.
        job = AutomationAdapters.get_automation(user, job_id)
        last_run_time = AutomationAdapters.get_job_last_run(user, job)

        # Convert last_run_time from %Y-%m-%d %I:%M %p %Z to datetime object
        if last_run_time:
            # NOTE(review): the parsed time is treated as UTC whatever the %Z abbreviation says,
            # and strptime rejects most non-UTC abbreviations; confirm get_job_last_run formats in UTC
            last_run_time = datetime.strptime(last_run_time, "%Y-%m-%d %I:%M %p %Z").replace(tzinfo=timezone.utc)

            # If the last run time was within the last 6 hours, don't run it again. This helps avoid multithreading issues and rate limits.
            if (datetime.now(timezone.utc) - last_run_time).total_seconds() < 6 * 60 * 60:
                logger.info(f"Skipping scheduled chat {job_id} as it last ran less than 6 hours ago.")
                return

    # Extract relevant params from the original URL
    parsed_url = URL(calling_url) if isinstance(calling_url, str) else calling_url
    scheme = "http" if not parsed_url.is_secure else "https"
    query_dict = parse_qs(parsed_url.query)

    # Pop the stream value from query_dict if it exists
    query_dict.pop("stream", None)

    # Replace the original scheduling query with the scheduled query
    query_dict["q"] = [query_to_run]

    # Replace the original conversation_id with the conversation_id
    if conversation_id:
        # encode the conversation_id to avoid any issues with special characters
        query_dict["conversation_id"] = [quote(str(conversation_id))]
        # validate that the conversation id exists. If not, delete the automation and exit.
        if not ConversationAdapters.get_conversation_by_id(conversation_id):
            AutomationAdapters.delete_automation(user, job_id)
            return

    # Restructure the original query_dict into a valid JSON payload for the chat API
    json_payload = {key: values[0] for key, values in query_dict.items()}

    # Construct the URL to call the chat API with the scheduled query string
    url = f"{scheme}://{parsed_url.netloc}/api/chat?client=khoj"

    # Construct the Headers for the chat API
    headers = {"User-Agent": "Khoj", "Content-Type": "application/json"}
    if not state.anonymous_mode:
        # Add authorization request header in non-anonymous mode
        token = get_khoj_tokens(user)
        if is_none_or_empty(token):
            token = create_khoj_token(user).token
        else:
            token = token[0].token
        headers["Authorization"] = f"Bearer {token}"

    # Call the chat API endpoint with authenticated user token and query
    raw_response = requests.post(url, headers=headers, json=json_payload, allow_redirects=False)

    # Follow a redirect manually, re-sending the same POST body. 307/308 preserve the method by
    # spec; 301/302 are deliberately re-posted rather than downgraded to GET.
    if raw_response.status_code in (301, 302, 307, 308):
        redirect_url = raw_response.headers["Location"]
        logger.info(f"Redirecting to {redirect_url}")
        raw_response = requests.post(redirect_url, headers=headers, json=json_payload)

    # Stop if the chat API call was not successful
    if raw_response.status_code != 200:
        logger.error(f"Failed to run schedule chat: {raw_response.text}, user: {user}, query: {query_to_run}")
        return None

    # Extract the AI response from the chat API response
    cleaned_query = re.sub(r"^/automated_task\s*", "", query_to_run).strip()
    is_image = False
    # Compare only the media type so parameters like "; charset=utf-8" don't disable JSON parsing
    content_type = raw_response.headers.get("Content-Type", "")
    if content_type.split(";", 1)[0].strip().lower() == "application/json":
        response_map = raw_response.json()
        ai_response = response_map.get("response") or response_map.get("image")
        if isinstance(ai_response, dict):
            is_image = ai_response.get("image") is not None
    else:
        ai_response = raw_response.text

    # Notify user if the AI response is satisfactory
    if should_notify(
        original_query=scheduling_request, executed_query=cleaned_query, ai_response=ai_response, user=user
    ):
        formatted_response = format_automation_response(scheduling_request, cleaned_query, ai_response, user)

        if is_resend_enabled():
            send_task_email(user.get_short_name(), user.email, cleaned_query, formatted_response, subject, is_image)
        else:
            return formatted_response
async def create_automation(
    q: str,
    timezone: str,
    user: KhojUser,
    calling_url: URL,
    chat_history: Optional[List[ChatMessageModel]] = None,
    conversation_id: str = None,
    tracer: Optional[dict] = None,
):
    """
    Infer a schedule from the user's request and register the resulting automation.

    Args:
        q: The user's automation request.
        timezone: Name of the user's timezone, used to schedule the job.
        user: The automation's owner.
        calling_url: URL of the originating request, reused when the automation runs.
        chat_history: Prior conversation messages given as context for inferring the schedule.
        conversation_id: Conversation the automation runs in.
        tracer: Dict passed through to the schedule inference. The caller's object is used
            as-is, so any updates made to it are visible to the caller.

    Returns:
        A tuple of (scheduled job, crontime, query to run, email subject).
    """
    # Create fresh containers per call. Mutable default arguments would be shared
    # across calls, so state added to `tracer` could leak between users.
    chat_history = [] if chat_history is None else chat_history
    tracer = {} if tracer is None else tracer

    crontime, query_to_run, subject = await aschedule_query(q, chat_history, user, tracer=tracer)
    job = await aschedule_automation(query_to_run, subject, crontime, timezone, q, user, calling_url, conversation_id)
    return job, crontime, query_to_run, subject
def schedule_automation(
    query_to_run: str,
    subject: str,
    crontime: str,
    timezone: str,
    scheduling_request: str,
    user: KhojUser,
    calling_url: URL,
    conversation_id: str,
):
    """
    Register an automation job with the scheduler and return it.

    Each run calls ``scheduled_chat`` through ``run_with_process_lock``. The lock name and
    job id are derived from the user and a hash of the query and schedule. If the minute
    field of ``crontime`` is not a plain number (e.g. every X minutes), the job runs at a
    single random minute instead. An unknown ``timezone`` falls back to UTC.
    """
    # Disallow minute-level recurrence: pin it to a random minute to spread request load
    cron_fields = crontime.split(" ")
    if not cron_fields[0].isdigit():
        cron_fields[0] = str(math.floor(random() * 60))
        crontime = " ".join(cron_fields)

    try:
        user_timezone = pytz.timezone(timezone)
    except pytz.UnknownTimeZoneError:
        logger.warning(f"Invalid timezone: {timezone}. Fallback to use UTC to schedule automation.")
        user_timezone = pytz.utc

    trigger = CronTrigger.from_crontab(crontime, user_timezone)
    trigger.jitter = 60

    # Identifiers used by the task scheduler and by the process lock guarding each run
    query_id = hashlib.md5(f"{query_to_run}_{crontime}".encode("utf-8")).hexdigest()
    job_id = f"automation_{user.uuid}_{query_id}"
    lock_name = f"{ProcessLock.Operation.SCHEDULED_JOB}_{user.uuid}_{query_id}"
    job_metadata = json.dumps(
        {
            "query_to_run": query_to_run,
            "scheduling_request": scheduling_request,
            "subject": subject,
            "crontime": crontime,
            "conversation_id": str(conversation_id),
        }
    )
    run_kwargs = {
        "query_to_run": query_to_run,
        "scheduling_request": scheduling_request,
        "subject": subject,
        "user": user,
        "calling_url": calling_url,
        "job_id": job_id,
        "conversation_id": conversation_id,
    }

    return state.scheduler.add_job(
        run_with_process_lock,
        trigger=trigger,
        args=(scheduled_chat, lock_name),
        kwargs=run_kwargs,
        id=job_id,
        name=job_metadata,
        max_instances=2,  # Allow second instance to kill any previous instance with stale lock
    )
async def aschedule_automation(
    query_to_run: str,
    subject: str,
    crontime: str,
    timezone: str,
    scheduling_request: str,
    user: KhojUser,
    calling_url: URL,
    conversation_id: str,
):
    """
    Register an automation job with the scheduler from async code and return it.

    Each run calls ``scheduled_chat`` through ``run_with_process_lock``. If the minute field
    of ``crontime`` is not a plain number (e.g. every X minutes), the job runs at a single
    random minute instead. An unknown ``timezone`` falls back to UTC, matching
    ``schedule_automation``.
    """
    # Disable minute level automation recurrence
    minute_value = crontime.split(" ")[0]
    if not minute_value.isdigit():
        # Run automation at some random minute (to distribute request load) instead of running every X minutes
        crontime = " ".join([str(math.floor(random() * 60))] + crontime.split(" ")[1:])

    # Convert timezone string to timezone object. Fall back to UTC rather than failing on an
    # unrecognized timezone name.
    try:
        user_timezone = pytz.timezone(timezone)
    except pytz.UnknownTimeZoneError:
        logger.warning(f"Invalid timezone: {timezone}. Fallback to use UTC to schedule automation.")
        user_timezone = pytz.utc

    trigger = CronTrigger.from_crontab(crontime, user_timezone)
    trigger.jitter = 60
    # Generate id and metadata used by task scheduler and process locks for the task runs
    job_metadata = json.dumps(
        {
            "query_to_run": query_to_run,
            "scheduling_request": scheduling_request,
            "subject": subject,
            "crontime": crontime,
            "conversation_id": str(conversation_id),
        }
    )
    query_id = hashlib.md5(f"{query_to_run}_{crontime}".encode("utf-8")).hexdigest()
    job_id = f"automation_{user.uuid}_{query_id}"
    job = await sync_to_async(state.scheduler.add_job)(
        run_with_process_lock,
        trigger=trigger,
        args=(
            scheduled_chat,
            f"{ProcessLock.Operation.SCHEDULED_JOB}_{user.uuid}_{query_id}",
        ),
        kwargs={
            "query_to_run": query_to_run,
            "scheduling_request": scheduling_request,
            "subject": subject,
            "user": user,
            "calling_url": calling_url,
            "job_id": job_id,
            "conversation_id": conversation_id,
        },
        id=job_id,
        name=job_metadata,
        max_instances=2,  # Allow second instance to kill any previous instance with stale lock
    )
    return job
def construct_automation_created_message(automation: Job, crontime: str, query_to_run: str, subject: str):
    """
    Render a markdown summary of a newly created automation for the chat UI.

    Shows the subject, the query without its ``/automated_task`` prefix, a human
    readable schedule and the next run time. Times are formatted in the timezone of
    the job's ``next_run_time``.
    """
    run_at = automation.next_run_time
    # Describe the cron schedule in words, suffixed with the next run's timezone abbreviation
    schedule = f"{cron_descriptor.get_description(crontime)} {run_at.strftime('%Z')}"
    next_run_time = run_at.strftime("%Y-%m-%d %I:%M %p %Z")

    # Hide the internal command prefix from the user
    unprefixed_query_to_run = re.sub(r"^\/automated_task\s*", "", query_to_run)

    automation_icon_url = "/static/assets/icons/automation.svg"
    message = f"""
### ![]({automation_icon_url}) Created Automation
- Subject: **{subject}**
- Query to Run: "{unprefixed_query_to_run}"
- Schedule: `{schedule}`
- Next Run At: {next_run_time}
Manage your automations [here](/automations).
""".strip()
    return message
class MessageProcessor:
def __init__(self):
self.references = {}
self.usage = {}
self.raw_response = ""
self.generated_images = []
self.generated_files = []
self.generated_mermaidjs_diagram = []
def convert_message_chunk_to_json(self, raw_chunk: str) -> Dict[str, Any]:
if raw_chunk.startswith("{") and raw_chunk.endswith("}"):
try:
json_chunk = json.loads(raw_chunk)
if "type" not in json_chunk:
json_chunk = {"type": "message", "data": json_chunk}
return json_chunk
except json.JSONDecodeError:
return {"type": "message", "data": raw_chunk}
elif raw_chunk:
return {"type": "message", "data": raw_chunk}
return {"type": "", "data": ""}
def process_message_chunk(self, raw_chunk: str) -> None:
chunk = self.convert_message_chunk_to_json(raw_chunk)
if not chunk or not chunk["type"]:
return
chunk_type = ChatEvent(chunk["type"])
if chunk_type == ChatEvent.REFERENCES:
self.references = chunk["data"]
elif chunk_type == ChatEvent.USAGE:
self.usage = chunk["data"]
elif chunk_type == ChatEvent.MESSAGE:
chunk_data = chunk["data"]
if isinstance(chunk_data, dict):
self.raw_response = self.handle_json_response(chunk_data)
elif (
isinstance(chunk_data, str) and chunk_data.strip().startswith("{") and chunk_data.strip().endswith("}")
):
try:
json_data = json.loads(chunk_data.strip())
self.raw_response = self.handle_json_response(json_data)
except json.JSONDecodeError:
self.raw_response += chunk_data
else:
self.raw_response += chunk_data
elif chunk_type == ChatEvent.GENERATED_ASSETS:
chunk_data = chunk["data"]
if isinstance(chunk_data, dict):
for key in chunk_data:
if key == "images":
self.generated_images = chunk_data[key]
elif key == "files":
self.generated_files = chunk_data[key]
elif key == "mermaidjsDiagram":
self.generated_mermaidjs_diagram = chunk_data[key]
def handle_json_response(self, json_data: Dict[str, str]) -> str | Dict[str, str]:
if "image" in json_data or "details" in json_data:
return json_data
if "response" in json_data:
return json_data["response"]
return json_data
async def read_chat_stream(response_iterator: AsyncGenerator[str, None]) -> Dict[str, Any]:
    """
    Consume a streamed chat response and collect its final parts.

    Chunks are buffered and split on the end-of-event delimiter. Each non-empty event
    is fed to a ``MessageProcessor``. Any trailing data left when the stream ends is
    processed as a final event.

    Returns:
        A dict with the response, references, usage and generated images, files and
        mermaid diagram.
    """
    processor = MessageProcessor()
    pending = ""

    async for piece in response_iterator:
        pending += piece
        # Drain every complete event currently held in the buffer
        while ChatEvent.END_EVENT.value in pending:
            event, _, pending = pending.partition(ChatEvent.END_EVENT.value)
            if event:
                processor.process_message_chunk(event)

    # The stream may end without a trailing delimiter
    if pending:
        processor.process_message_chunk(pending)

    return {
        "response": processor.raw_response,
        "references": processor.references,
        "usage": processor.usage,
        "images": processor.generated_images,
        "files": processor.generated_files,
        "mermaidjsDiagram": processor.generated_mermaidjs_diagram,
    }
def get_message_from_queue(queue: asyncio.Queue) -> Optional[str]:
    """
    Return the next message waiting in ``queue`` without blocking.

    Returns None when no queue is given or the queue is currently empty.
    """
    if not queue:
        return None

    try:
        message = queue.get_nowait()
    except asyncio.QueueEmpty:
        message = None
    return message
def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False):
    """
    Build the template context describing the user and the server configuration.

    Args:
        user: The user to describe.
        request: The incoming request. Its session supplies the user's picture, and its
            scopes determine whether the user is active (premium).
        is_detailed: When False, return only basic user info. When True, also include
            content sources, model options, billing info and server settings.

    Returns:
        A dict of template variables.
    """
    session_user = request.session.get("user", {})
    user_picture = session_user.get("picture")
    is_active = has_required_scope(request, ["premium"])
    has_documents = EntryAdapters.user_has_entries(user=user)
    username = user.username if user else None

    if not is_detailed:
        return {
            "request": request,
            "username": username,
            "user_photo": user_picture,
            "is_active": is_active,
            "has_documents": has_documents,
            "khoj_version": state.khoj_version,
        }

    # Billing info
    user_subscription_state = get_user_subscription_state(user.email)
    user_subscription = adapters.get_user_subscription(user.email)

    def _as_display_date(value):
        # Render an optional date like "05 Aug 2025"
        return value.strftime("%d %b %Y") if value else None

    subscription_renewal_date = _as_display_date(user_subscription.renewal_date if user_subscription else None)
    subscription_enabled_trial_at = _as_display_date(
        user_subscription.enabled_trial_at if user_subscription else None
    )

    given_name = get_user_name(user)

    # Content sources the user has indexed data from
    indexed_sources = set(EntryAdapters.get_unique_file_sources(user))
    enabled_content_sources = {source: source in indexed_sources for source in ("computer", "github", "notion")}

    notion_oauth_url = get_notion_auth_url(user)
    current_notion_config = get_user_notion_config(user)
    notion_token = current_notion_config.token if current_notion_config else ""

    # Chat model settings
    selected_chat_model_config = ConversationAdapters.get_chat_model(
        user
    ) or ConversationAdapters.get_default_chat_model(user)
    chat_model_options = [
        {
            "name": chat_model.friendly_name,
            "id": chat_model.id,
            "strengths": chat_model.strengths,
            "description": chat_model.description,
            "tier": chat_model.price_tier,
        }
        for chat_model in ConversationAdapters.get_conversation_processor_options().all()
    ]

    # Image generation model settings
    selected_paint_model_config = ConversationAdapters.get_user_text_to_image_model_config(user)
    all_paint_model_options = [
        {
            "name": paint_model.friendly_name,
            "id": paint_model.id,
            "tier": paint_model.price_tier,
        }
        for paint_model in ConversationAdapters.get_text_to_image_model_options().all()
    ]

    # Voice model settings
    voice_model_options = [
        {
            "name": voice_model.name,
            "id": voice_model.model_id,
            "tier": voice_model.price_tier,
        }
        for voice_model in ConversationAdapters.get_voice_model_options()
    ]
    # Voice features are only reported as enabled when a voice model is configured
    eleven_labs_enabled = is_eleven_labs_enabled() if voice_model_options else False
    selected_voice_model_config = ConversationAdapters.get_voice_model_config(user)

    return {
        "request": request,
        # user info
        "username": username,
        "user_photo": user_picture,
        "is_active": is_active,
        "given_name": given_name,
        "phone_number": str(user.phone_number) if user.phone_number else "",
        "is_phone_number_verified": user.verified_phone_number,
        # user content settings
        "enabled_content_source": enabled_content_sources,
        "has_documents": has_documents,
        "notion_token": notion_token,
        # user model settings
        "chat_model_options": chat_model_options,
        "selected_chat_model_config": selected_chat_model_config.id if selected_chat_model_config else None,
        "paint_model_options": all_paint_model_options,
        "selected_paint_model_config": selected_paint_model_config.id if selected_paint_model_config else None,
        "voice_model_options": voice_model_options,
        "selected_voice_model_config": selected_voice_model_config.model_id if selected_voice_model_config else None,
        # user billing info
        "subscription_state": user_subscription_state,
        "subscription_renewal_date": subscription_renewal_date,
        "subscription_enabled_trial_at": subscription_enabled_trial_at,
        # server settings
        "khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"),
        "billing_enabled": state.billing_enabled,
        "is_eleven_labs_enabled": eleven_labs_enabled,
        "is_twilio_enabled": is_twilio_enabled(),
        "khoj_version": state.khoj_version,
        "anonymous_mode": state.anonymous_mode,
        "notion_oauth_url": notion_oauth_url,
        "length_of_free_trial": LENGTH_OF_FREE_TRIAL,
    }
def configure_content(
    user: KhojUser,
    files: Optional[dict[str, dict[str, str]]],
    regenerate: bool = False,
    t: Optional[state.SearchType] = state.SearchType.All,
) -> bool:
    """
    Index the user's content for search.

    Client-sent ``files`` are grouped by content type key ("org", "markdown", "pdf",
    "plaintext", "image", "docx"). Server-side sources (Github, Notion) are only indexed
    when the client sent no documents at all.

    Args:
        user: The user whose content is indexed.
        files: Mapping of content type key to that type's files (presumably file name ->
            content; confirm against callers). ``None`` means there is nothing to index.
        regenerate: Forwarded to ``text_search.setup`` for each content type.
        t: Search type to restrict indexing to, as a ``SearchType`` member or its value.
            ``None`` means all types.

    Returns:
        False if ``t`` is not a valid search type or setting up any content type failed.
        Otherwise True, including when there was nothing to process.
    """
    # Normalize the search type, accepting a SearchType member or its string value.
    # Invalid values are rejected instead of crashing on `.value`.
    if t is None:
        t = state.SearchType.All
    try:
        t = state.SearchType(t)
    except ValueError:
        logger.warning(f"🚨 Invalid search type: {t}")
        return False
    search_type = t.value

    # Guard before inspecting `files`: iterating None would raise
    if files is None:
        logger.warning(f"🚨 No files to process for {search_type} search.")
        return True

    # Check if client sent any documents of the supported types
    no_client_sent_documents = all(not files.get(file_type) for file_type in files)

    def _wants(option) -> bool:
        """Return whether the requested search type covers `option`."""
        return search_type in (state.SearchType.All.value, option.value)

    def _setup_client_files(option, file_key: str, processor, start_message: str, failure_label: str) -> bool:
        """Index the client-sent files of one content type if requested; return False on failure."""
        try:
            if _wants(option) and files.get(file_key):
                logger.info(start_message)
                text_search.setup(
                    processor,
                    files.get(file_key),
                    regenerate=regenerate,
                    user=user,
                )
        except Exception as e:
            logger.error(f"🚨 Failed to setup {failure_label}: {e}", exc_info=True)
            return False
        return True

    success = True
    success &= _setup_client_files(
        state.SearchType.Org, "org", OrgToEntries, "🦄 Setting up search for orgmode notes", "org"
    )
    success &= _setup_client_files(
        state.SearchType.Markdown, "markdown", MarkdownToEntries, "💎 Setting up search for markdown notes", "markdown"
    )
    success &= _setup_client_files(state.SearchType.Pdf, "pdf", PdfToEntries, "🖨️ Setting up search for pdf", "PDF")
    success &= _setup_client_files(
        state.SearchType.Plaintext, "plaintext", PlaintextToEntries, "📄 Setting up search for plaintext", "plaintext"
    )

    try:
        # Run server side indexing of user Github docs if no client sent documents
        if no_client_sent_documents:
            github_config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
            if _wants(state.SearchType.Github) and github_config is not None:
                logger.info("🐙 Setting up search for github")
                # Extract Entries, Generate Github Embeddings
                text_search.setup(
                    GithubToEntries,
                    None,
                    regenerate=regenerate,
                    user=user,
                    config=github_config,
                )
    except Exception as e:
        logger.error(f"🚨 Failed to setup GitHub: {e}", exc_info=True)
        success = False

    try:
        # Run server side indexing of user Notion docs if no client sent documents
        if no_client_sent_documents:
            notion_config = NotionConfig.objects.filter(user=user).first()
            if _wants(state.SearchType.Notion) and notion_config:
                logger.info("🔌 Setting up search for notion")
                text_search.setup(
                    NotionToEntries,
                    None,
                    regenerate=regenerate,
                    user=user,
                    config=notion_config,
                )
    except Exception as e:
        logger.error(f"🚨 Failed to setup Notion: {e}", exc_info=True)
        success = False

    success &= _setup_client_files(
        state.SearchType.Image, "image", ImageToEntries, "🖼️ Setting up search for images", "images"
    )
    success &= _setup_client_files(state.SearchType.Docx, "docx", DocxToEntries, "📄 Setting up search for docx", "docx")

    # Invalidate Query Cache
    if user:
        state.query_cache[user.uuid] = LRU()

    return success
def get_notion_auth_url(user: KhojUser):
    """
    Build the Notion OAuth authorization URL for the user.

    The user's uuid is sent as the OAuth ``state`` parameter. Returns None unless the
    Notion OAuth client id, client secret and redirect URI are all configured.
    """
    # NOTE(review): redirect_uri is interpolated without percent-encoding; confirm the
    # configured URI contains no characters that are reserved in a query string
    notion_oauth_configured = all([NOTION_OAUTH_CLIENT_ID, NOTION_OAUTH_CLIENT_SECRET, NOTION_REDIRECT_URI])
    if not notion_oauth_configured:
        return None
    return f"https://api.notion.com/v1/oauth/authorize?client_id={NOTION_OAUTH_CLIENT_ID}&redirect_uri={NOTION_REDIRECT_URI}&response_type=code&state={user.uuid}"
async def view_file_content(
    path: str,
    start_line: Optional[int] = None,
    end_line: Optional[int] = None,
    user: KhojUser = None,
):
    """
    View the contents of a file from the user's document database with optional line range specification.

    Line numbers are 1-based and inclusive. When only one bound is given, the other
    defaults to the start or end of the file. Output longer than 10000 characters is
    truncated.

    Yields:
        A single-element list holding a result dict with ``query``, ``file``, ``uri`` and
        ``compiled`` keys. ``compiled`` holds the file text, or an error message when the
        file is missing, the line range is invalid, or reading the file failed.
    """
    query = f"View file: {path}"
    if start_line and end_line:
        query += f" (lines {start_line}-{end_line})"

    def _result(compiled: str) -> list[dict]:
        # Every outcome uses the same result shape, including `uri`
        return [{"query": query, "file": path, "uri": path, "compiled": compiled}]

    try:
        # Get the file object from the database by name
        file_objects = await FileObjectAdapters.aget_file_objects_by_name(user, path)

        if not file_objects:
            error_msg = f"File '{path}' not found in user documents"
            logger.warning(error_msg)
            yield _result(error_msg)
            return

        # Use the first file object if multiple exist
        file_object = file_objects[0]
        raw_text = file_object.raw_text

        # Apply line range filtering if specified
        if start_line is None and end_line is None:
            filtered_text = raw_text
        else:
            lines = raw_text.split("\n")
            start_line = start_line or 1
            end_line = end_line or len(lines)

            # Validate line range
            if start_line < 1 or end_line < 1 or start_line > end_line:
                error_msg = f"Invalid line range: {start_line}-{end_line}"
                logger.warning(error_msg)
                yield _result(error_msg)
                return
            if start_line > len(lines):
                error_msg = f"Start line {start_line} exceeds total number of lines {len(lines)}"
                logger.warning(error_msg)
                yield _result(error_msg)
                return

            # Convert from 1-based to 0-based indexing and ensure bounds
            start_idx = max(0, start_line - 1)
            end_idx = min(len(lines), end_line)

            selected_lines = lines[start_idx:end_idx]
            filtered_text = "\n".join(selected_lines)

        # Truncate the text if it's too long
        if len(filtered_text) > 10000:
            filtered_text = filtered_text[:10000] + "\n\n[Truncated. Use line numbers to view specific sections.]"

        # Format the result as a document reference
        yield _result(filtered_text)

    except Exception as e:
        error_msg = f"Error viewing file {path}: {str(e)}"
        logger.error(error_msg, exc_info=True)
        # Return an error result in the expected format
        yield _result(error_msg)
async def grep_files(
    regex_pattern: str,
    path_prefix: Optional[str] = None,
    lines_before: Optional[int] = None,
    lines_after: Optional[int] = None,
    user: KhojUser = None,
):
    """
    Search for a regex pattern in files with an optional path prefix and context lines.

    Lines are matched case-insensitively with Python's ``re``. Each matching line is
    rendered as ``<file>:<line>:> <text>`` and each context line as
    ``<file>:<line>: <text>``. When context is requested, matches are separated by
    ``--``. At most 1000 output lines are returned.

    Yields:
        A single result dict with ``query``, ``file``, ``uri`` and ``compiled`` keys.
    """

    # Construct the query string based on provided parameters
    def _generate_query(line_count, doc_count, path, pattern, lines_before, lines_after, max_results=1000):
        query = f"**Found {line_count} matches for '{pattern}' in {doc_count} documents**"
        if path:
            query += f" in {path}"
        if lines_before or lines_after or line_count > max_results:
            query += " Showing"
        if lines_before or lines_after:
            context_info = []
            if lines_before:
                context_info.append(f"{lines_before} lines before")
            if lines_after:
                context_info.append(f"{lines_after} lines after")
            query += f" {' and '.join(context_info)}"
        if line_count > max_results:
            if lines_before or lines_after:
                query += " for"
            query += f" first {max_results} results"
        return query

    path_prefix = path_prefix or ""
    lines_before = lines_before or 0
    lines_after = lines_after or 0

    # Validate regex pattern
    try:
        regex = re.compile(regex_pattern, re.IGNORECASE)
    except re.error as e:
        yield {
            "query": _generate_query(0, 0, path_prefix, regex_pattern, lines_before, lines_after),
            "file": path_prefix,
            "uri": path_prefix,
            "compiled": f"Invalid regex pattern: {e}",
        }
        return

    max_results = 1000
    try:
        file_matches = await FileObjectAdapters.aget_file_objects_by_regex(user, regex_pattern, path_prefix)

        line_matches = []
        # Count matches directly: inferring them from ":>" in the output would also
        # count context lines whose text happens to contain ":>"
        match_count = 0
        for file_object in file_matches:
            lines = file_object.raw_text.split("\n")

            # Find all matching (1-based) line numbers first
            matched_line_numbers = [i for i, line in enumerate(lines, 1) if regex.search(line)]
            match_count += len(matched_line_numbers)

            # Build context for each match
            for line_num in matched_line_numbers:
                context_lines = []

                # Calculate start and end indices for context (0-based)
                start_idx = max(0, line_num - 1 - lines_before)
                end_idx = min(len(lines), line_num + lines_after)

                # Add context lines with line numbers
                for idx in range(start_idx, end_idx):
                    current_line_num = idx + 1
                    line_content = lines[idx]

                    if current_line_num == line_num:
                        # This is the matching line, mark it
                        context_lines.append(f"{file_object.file_name}:{current_line_num}:> {line_content}")
                    else:
                        # This is a context line
                        context_lines.append(f"{file_object.file_name}:{current_line_num}: {line_content}")

                # Add separator between matches if showing context
                if lines_before > 0 or lines_after > 0:
                    context_lines.append("--")

                line_matches.extend(context_lines)

        # Remove the last separator if it exists
        if line_matches and line_matches[-1] == "--":
            line_matches.pop()

        query = _generate_query(
            match_count,
            len(file_matches),
            path_prefix,
            regex_pattern,
            lines_before,
            lines_after,
            max_results,
        )

        # Check if no results found
        if not line_matches:
            yield {"query": query, "file": path_prefix, "uri": path_prefix, "compiled": "No matches found."}
            return

        # Truncate matched lines list if too long
        if len(line_matches) > max_results:
            line_matches = line_matches[:max_results] + [
                f"... {len(line_matches) - max_results} more results found. Use stricter regex or path to narrow down results."
            ]

        yield {"query": query, "file": path_prefix, "uri": path_prefix, "compiled": "\n".join(line_matches)}

    except Exception as e:
        error_msg = f"Error using grep files tool: {str(e)}"
        logger.error(error_msg, exc_info=True)
        # Yield a dict like every other branch (previously a single-element list)
        yield {
            "query": _generate_query(0, 0, path_prefix, regex_pattern, lines_before, lines_after),
            "file": path_prefix,
            "uri": path_prefix,
            "compiled": error_msg,
        }
async def list_files(
    path: Optional[str] = None,
    pattern: Optional[str] = None,
    user: KhojUser = None,
):
    """
    List files under a given path or glob pattern from the user's document database.

    ``path`` is used as a prefix filter on stored file names. Root-like values ("", "/",
    ".", "./", "~", "~/") list all files (up to 10000). When a prefix filter was applied,
    the names are shown relative to ``path``. ``pattern`` is a glob (``fnmatch``) matched
    against those names. At most 100 names are listed.

    Yields:
        A single result dict with ``query``, ``file``, ``uri`` and ``compiled`` keys.
    """

    # Construct the query string based on provided parameters
    def _generate_query(doc_count, path, pattern):
        query = f"**Found {doc_count} files**"
        if path:
            query += f" in {path}"
        if pattern:
            query += f" filtered by {pattern}"
        return query

    path = path or ""
    # These values denote the root of the user's documents, not an actual name prefix
    is_root = path in ["", "/", ".", "./", "~", "~/"]
    try:
        # Get user files by path prefix when specified
        if is_root:
            file_objects = await FileObjectAdapters.aget_all_file_objects(user, limit=10000)
        else:
            file_objects = await FileObjectAdapters.aget_file_objects_by_path_prefix(user, path)

        if not file_objects:
            yield {"query": _generate_query(0, path, pattern), "file": path, "uri": path, "compiled": "No files found."}
            return

        # Extract file names from file objects
        files = [f.file_name for f in file_objects]
        # Convert to relative file path (similar to ls). Only strip a prefix that was used to
        # filter; root aliases are not prefixes of the names and would mangle them.
        if not is_root:
            files = [f[len(path) :] for f in files]

        # Apply glob pattern filtering if specified
        if pattern:
            files = [f for f in files if fnmatch.fnmatch(f, pattern)]

        query = _generate_query(len(files), path, pattern)
        if not files:
            yield {"query": query, "file": path, "uri": path, "compiled": "No files found."}
            return

        # Truncate the list if it's too long
        max_files = 100
        if len(files) > max_files:
            files = files[:max_files] + [
                f"... {len(files) - max_files} more files found. Use glob pattern to narrow down results."
            ]

        yield {"query": query, "file": path, "uri": path, "compiled": "\n- ".join(files)}

    except Exception as e:
        error_msg = f"Error listing files in {path}: {str(e)}"
        logger.error(error_msg, exc_info=True)
        # Build the query here: `query` is unbound if the failure happened before it was set
        yield {"query": _generate_query(0, path, pattern), "file": path, "uri": path, "compiled": error_msg}