mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-05 05:39:11 +00:00
2548 lines
95 KiB
Python
2548 lines
95 KiB
Python
import base64
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
import math
|
|
import os
|
|
import re
|
|
from datetime import datetime, timedelta, timezone
|
|
from functools import partial
|
|
from random import random
|
|
from typing import (
|
|
Annotated,
|
|
Any,
|
|
AsyncGenerator,
|
|
Callable,
|
|
Dict,
|
|
List,
|
|
Optional,
|
|
Set,
|
|
Tuple,
|
|
Union,
|
|
)
|
|
from urllib.parse import parse_qs, quote, unquote, urljoin, urlparse
|
|
|
|
import cron_descriptor
|
|
import pyjson5
|
|
import pytz
|
|
import requests
|
|
from apscheduler.job import Job
|
|
from apscheduler.triggers.cron import CronTrigger
|
|
from asgiref.sync import sync_to_async
|
|
from django.utils import timezone as django_timezone
|
|
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
|
from pydantic import BaseModel, EmailStr, Field
|
|
from starlette.authentication import has_required_scope
|
|
from starlette.requests import URL
|
|
|
|
from khoj.database import adapters
|
|
from khoj.database.adapters import (
|
|
LENGTH_OF_FREE_TRIAL,
|
|
AgentAdapters,
|
|
AutomationAdapters,
|
|
ConversationAdapters,
|
|
EntryAdapters,
|
|
FileObjectAdapters,
|
|
aget_user_by_email,
|
|
ais_user_subscribed,
|
|
create_khoj_token,
|
|
get_khoj_tokens,
|
|
get_user_name,
|
|
get_user_notion_config,
|
|
get_user_subscription_state,
|
|
is_user_subscribed,
|
|
run_with_process_lock,
|
|
)
|
|
from khoj.database.models import (
|
|
Agent,
|
|
ChatModel,
|
|
ClientApplication,
|
|
Conversation,
|
|
GithubConfig,
|
|
KhojUser,
|
|
NotionConfig,
|
|
ProcessLock,
|
|
RateLimitRecord,
|
|
Subscription,
|
|
TextToImageModelConfig,
|
|
UserRequests,
|
|
)
|
|
from khoj.processor.content.docx.docx_to_entries import DocxToEntries
|
|
from khoj.processor.content.github.github_to_entries import GithubToEntries
|
|
from khoj.processor.content.images.image_to_entries import ImageToEntries
|
|
from khoj.processor.content.markdown.markdown_to_entries import MarkdownToEntries
|
|
from khoj.processor.content.notion.notion_to_entries import NotionToEntries
|
|
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
|
|
from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries
|
|
from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries
|
|
from khoj.processor.conversation import prompts
|
|
from khoj.processor.conversation.anthropic.anthropic_chat import (
|
|
anthropic_send_message_to_model,
|
|
converse_anthropic,
|
|
)
|
|
from khoj.processor.conversation.google.gemini_chat import (
|
|
converse_gemini,
|
|
gemini_send_message_to_model,
|
|
)
|
|
from khoj.processor.conversation.offline.chat_model import (
|
|
converse_offline,
|
|
send_message_to_model_offline,
|
|
)
|
|
from khoj.processor.conversation.openai.gpt import (
|
|
converse_openai,
|
|
send_message_to_model,
|
|
)
|
|
from khoj.processor.conversation.utils import (
|
|
ChatEvent,
|
|
InformationCollectionIteration,
|
|
ResponseWithThought,
|
|
clean_json,
|
|
clean_mermaidjs,
|
|
construct_chat_history,
|
|
generate_chatml_messages_with_context,
|
|
save_to_conversation_log,
|
|
)
|
|
from khoj.processor.speech.text_to_speech import is_eleven_labs_enabled
|
|
from khoj.routers.email import is_resend_enabled, send_task_email
|
|
from khoj.routers.twilio import is_twilio_enabled
|
|
from khoj.search_type import text_search
|
|
from khoj.utils import state
|
|
from khoj.utils.config import OfflineChatProcessorModel
|
|
from khoj.utils.helpers import (
|
|
LRU,
|
|
ConversationCommand,
|
|
get_file_type,
|
|
in_debug_mode,
|
|
is_none_or_empty,
|
|
is_operator_enabled,
|
|
is_valid_url,
|
|
log_telemetry,
|
|
mode_descriptions_for_llm,
|
|
timer,
|
|
tool_descriptions_for_llm,
|
|
)
|
|
from khoj.utils.rawconfig import ChatRequestBody, FileAttachment, FileData, LocationData
|
|
|
|
# Module-level logger, scoped to this file's import path.
logger = logging.getLogger(__name__)


# Notion OAuth app credentials and callback URL, read once at import time.
# Unset environment variables yield None, which disables the Notion OAuth
# flows that depend on them.
NOTION_OAUTH_CLIENT_ID = os.getenv("NOTION_OAUTH_CLIENT_ID")
NOTION_OAUTH_CLIENT_SECRET = os.getenv("NOTION_OAUTH_CLIENT_SECRET")
NOTION_REDIRECT_URI = os.getenv("NOTION_REDIRECT_URI")
|
|
|
|
|
|
def is_query_empty(query: str) -> bool:
    """Return True when the query contains no non-whitespace characters."""
    stripped_query = query.strip()
    return is_none_or_empty(stripped_query)
|
|
|
|
|
|
def validate_chat_model(user: KhojUser):
    """Ensure a usable default chat model is configured for this user.

    Raises:
        HTTPException: 500 when no default chat model exists, or when an
            OpenAI-type model lacks its API configuration.
    """
    chat_model = ConversationAdapters.get_default_chat_model(user)

    # No model configured at all — the administrator must add one.
    if chat_model is None:
        raise HTTPException(status_code=500, detail="Contact the server administrator to add a chat model.")

    # An OpenAI-type model is unusable without its API configuration.
    missing_openai_api = chat_model.model_type == "openai" and not chat_model.ai_model_api
    if missing_openai_api:
        raise HTTPException(status_code=500, detail="Contact the server administrator to add a chat model.")
|
|
|
|
|
|
async def is_ready_to_chat(user: KhojUser):
    """Check that a usable chat model is configured for the user.

    Falls back to the server default model when the user has not selected one,
    and lazily loads the offline model into process state when needed.

    Returns:
        True when a chat model is ready for use.

    Raises:
        HTTPException: 500 when no usable chat model is configured.
    """
    user_chat_model = await ConversationAdapters.aget_user_chat_model(user)
    # Fix: compare to None with `is`, not `==` (identity check, avoids __eq__).
    if user_chat_model is None:
        user_chat_model = await ConversationAdapters.aget_default_chat_model(user)

    if user_chat_model and user_chat_model.model_type == ChatModel.ModelType.OFFLINE:
        chat_model_name = user_chat_model.name
        max_tokens = user_chat_model.max_prompt_size
        # Load the offline model into process state once; later calls reuse it.
        if state.offline_chat_processor_config is None:
            logger.info("Loading Offline Chat Model...")
            state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
        return True

    # Hosted models are ready only when their API configuration is present.
    if (
        user_chat_model
        and user_chat_model.model_type
        in [
            ChatModel.ModelType.OPENAI,
            ChatModel.ModelType.ANTHROPIC,
            ChatModel.ModelType.GOOGLE,
        ]
        and user_chat_model.ai_model_api
    ):
        return True

    raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.")
|
|
|
|
|
|
def get_file_content(file: UploadFile):
    """Read an uploaded file and wrap its bytes with inferred type metadata."""
    raw_bytes = file.file.read()
    inferred_type, inferred_encoding = get_file_type(file.content_type, raw_bytes)
    return FileData(
        name=file.filename,
        content=raw_bytes,
        file_type=inferred_type,
        encoding=inferred_encoding,
    )
|
|
|
|
|
|
def update_telemetry_state(
    request: Request,
    telemetry_type: str,
    api: str,
    client: Optional[str] = None,
    user_agent: Optional[str] = None,
    referer: Optional[str] = None,
    host: Optional[str] = None,
    metadata: Optional[dict] = None,
):
    """Queue a telemetry event describing this request onto the in-memory buffer.

    Unauthenticated requests are recorded with null user fields; extra
    key/values in `metadata` are merged into the event properties.
    """
    authenticated = request.user.is_authenticated
    user: KhojUser = request.user.object if authenticated else None
    client_app: ClientApplication = request.user.client_app if authenticated else None
    subscription: Subscription = user.subscription if user and hasattr(user, "subscription") else None

    user_state = {
        "client_host": request.client.host if request.client else None,
        "user_agent": user_agent or "unknown",
        "referer": referer or "unknown",
        "host": host or "unknown",
        "server_id": str(user.uuid) if user else None,
        "subscription_type": subscription.type if subscription else None,
        "is_recurring": subscription.is_recurring if subscription else None,
        "client_id": str(client_app.name) if client_app else "default",
    }
    if metadata:
        user_state.update(metadata)

    telemetry_event = log_telemetry(
        telemetry_type=telemetry_type,
        api=api,
        client=client,
        app_config=state.config.app,
        disable_telemetry_env=state.telemetry_disabled,
        properties=user_state,
    )
    state.telemetry += [telemetry_event]
|
|
|
|
|
|
def get_next_url(request: Request) -> str:
    "Construct next url relative to current domain from request"
    parsed_next = urlparse(request.query_params.get("next", "/"))

    # Only honor the `next` param when it is a relative path or an absolute
    # URL on this domain; otherwise fall back to the site root.
    points_at_this_domain = parsed_next.netloc == request.base_url.netloc
    if is_none_or_empty(parsed_next.scheme) or points_at_this_domain:
        next_path = parsed_next.path
    else:
        next_path = "/"

    # Rebuild an absolute URL rooted at the current domain.
    return urljoin(str(request.base_url).rstrip("/"), next_path)
|
|
|
|
|
|
def get_conversation_command(query: str) -> ConversationCommand:
    """Map a leading slash-command in the query to its ConversationCommand.

    Queries without a recognized prefix map to ConversationCommand.Default.
    """
    prefix_to_command = [
        ("/notes", ConversationCommand.Notes),
        ("/help", ConversationCommand.Help),
        ("/general", ConversationCommand.General),
        ("/online", ConversationCommand.Online),
        ("/webpage", ConversationCommand.Webpage),
        ("/image", ConversationCommand.Image),
        ("/automated_task", ConversationCommand.AutomatedTask),
        ("/summarize", ConversationCommand.Summarize),
        ("/diagram", ConversationCommand.Diagram),
        ("/code", ConversationCommand.Code),
        ("/research", ConversationCommand.Research),
    ]
    for prefix, command in prefix_to_command:
        if query.startswith(prefix):
            return command

    # The operator command is honored only while the operator feature is enabled.
    if query.startswith("/operator") and is_operator_enabled():
        return ConversationCommand.Operator

    return ConversationCommand.Default
|
|
|
|
|
|
def gather_raw_query_files(
    query_files: Dict[str, str],
):
    """
    Gather contextual data from the given (raw) files

    Formats each file as a "File: <name>" section and prefixes the combined
    text with an attachment notice. Returns "" when no files are given.
    """
    if len(query_files) == 0:
        return ""

    file_sections = [f"File: {name}\n\n{content}\n\n" for name, content in query_files.items()]
    contextual_data = " ".join(file_sections)
    return f"I have attached the following files:\n\n{contextual_data}"
|
|
|
|
|
|
async def acreate_title_from_history(
    user: KhojUser,
    conversation: Conversation,
):
    """
    Create a title from the given conversation history
    """
    history = construct_chat_history(conversation.conversation_log)
    title_prompt = prompts.conversation_title_generation.format(chat_history=history)

    with timer("Chat actor: Generate title from conversation history", logger):
        generated_title = await send_message_to_model_wrapper(title_prompt, user=user)

    return generated_title.strip()
|
|
|
|
|
|
async def acreate_title_from_query(query: str, user: KhojUser = None) -> str:
    """
    Create a title from the given query
    """
    title_prompt = prompts.subject_generation.format(query=query)

    with timer("Chat actor: Generate title from query", logger):
        generated_title = await send_message_to_model_wrapper(title_prompt, user=user)

    return generated_title.strip()
|
|
|
|
|
|
async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None, lax: bool = False) -> Tuple[bool, str]:
    """
    Check if the system prompt is safe to use

    Returns (is_safe, reason). On an unparsable model response the prompt is
    treated as safe with an empty reason.
    """
    # Pick the strict or lax safety-review prompt.
    if lax:
        safe_prompt_check = prompts.personality_prompt_safety_expert_lax.format(prompt=system_prompt)
    else:
        safe_prompt_check = prompts.personality_prompt_safety_expert.format(prompt=system_prompt)

    is_safe = True
    reason = ""

    class SafetyCheck(BaseModel):
        safe: bool
        reason: str

    with timer("Chat actor: Check if safe prompt", logger):
        response = await send_message_to_model_wrapper(
            safe_prompt_check, user=user, response_type="json_object", response_schema=SafetyCheck
        )

    response = response.strip()
    try:
        parsed = json.loads(clean_json(response))
        # Treat anything other than an explicit "safe": true as unsafe.
        is_safe = str(parsed.get("safe", "true")).lower() == "true"
        if not is_safe:
            reason = parsed.get("reason", "")
    except Exception:
        logger.error(f"Invalid response for checking safe prompt: {response}")

    if not is_safe:
        logger.error(f"Unsafe prompt: {system_prompt}. Reason: {reason}")

    return is_safe, reason
|
|
|
|
|
|
async def aget_data_sources_and_output_format(
    query: str,
    conversation_history: dict,
    is_task: bool,
    user: KhojUser,
    query_images: List[str] = None,
    agent: Agent = None,
    query_files: str = None,
    tracer: dict = {},
) -> Dict[str, Any]:
    """
    Given a query, determine which of the available data sources and output modes the agent should use to answer appropriately.

    Returns a dict {"sources": List[ConversationCommand], "output": ConversationCommand}.
    On a malformed model response, falls back to the agent's configured tools
    (or the global defaults) instead of raising.
    """

    # Build the candidate data-source options shown to the LLM.
    source_options = dict()
    source_options_str = ""

    agent_sources = agent.input_tools if agent else []
    user_has_entries = await EntryAdapters.auser_has_entries(user)

    for source, description in tool_descriptions_for_llm.items():
        # Skip showing Notes tool as an option if user has no entries
        if source == ConversationCommand.Notes and not user_has_entries:
            continue
        if source == ConversationCommand.Operator and not is_operator_enabled():
            continue
        source_options[source.value] = description
        # Only list sources the agent is configured for (empty config = all).
        if len(agent_sources) == 0 or source.value in agent_sources:
            source_options_str += f'- "{source.value}": "{description}"\n'

    # Build the candidate output-mode options shown to the LLM.
    output_options = dict()
    output_options_str = ""

    agent_outputs = agent.output_modes if agent else []

    for output, description in mode_descriptions_for_llm.items():
        # Do not allow tasks to schedule another task
        if is_task and output == ConversationCommand.Automation:
            continue
        output_options[output.value] = description
        if len(agent_outputs) == 0 or output.value in agent_outputs:
            output_options_str += f'- "{output.value}": "{description}"\n'

    chat_history = construct_chat_history(conversation_history)

    # Note attached images in the prompt text; images themselves are not sent here.
    if query_images:
        query = f"[placeholder for {len(query_images)} user attached images]\n{query}"

    personality_context = (
        prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
    )

    relevant_tools_prompt = prompts.pick_relevant_tools.format(
        query=query,
        sources=source_options_str,
        outputs=output_options_str,
        chat_history=chat_history,
        personality_context=personality_context,
    )

    agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None

    # Response schema the model is asked to conform to.
    class PickTools(BaseModel):
        source: List[str] = Field(..., min_items=1)
        output: str

    with timer("Chat actor: Infer information sources to refer", logger):
        response = await send_message_to_model_wrapper(
            relevant_tools_prompt,
            response_type="json_object",
            response_schema=PickTools,
            user=user,
            query_files=query_files,
            agent_chat_model=agent_chat_model,
            tracer=tracer,
        )

    try:
        response = clean_json(response)
        response = json.loads(response)

        chosen_sources = [s.strip() for s in response.get("source", []) if s.strip()]
        chosen_output = response.get("output", "text").strip()  # Default to text output

        if is_none_or_empty(chosen_sources) or not isinstance(chosen_sources, list):
            raise ValueError(
                f"Invalid response for determining relevant tools: {chosen_sources}. Raw Response: {response}"
            )

        output_mode = ConversationCommand.Text
        # Verify selected output mode is enabled for the agent, as the LLM can sometimes get confused by the tool options.
        if chosen_output in output_options.keys() and (len(agent_outputs) == 0 or chosen_output in agent_outputs):
            # Ensure that the chosen output mode exists as a valid ConversationCommand
            output_mode = ConversationCommand(chosen_output)

        data_sources = []
        # Verify selected data sources are enabled for the agent, as the LLM can sometimes get confused by the tool options.
        for chosen_source in chosen_sources:
            # Ensure that the chosen data source exists as a valid ConversationCommand
            if chosen_source in source_options.keys() and (len(agent_sources) == 0 or chosen_source in agent_sources):
                data_sources.append(ConversationCommand(chosen_source))

        # Fallback to default sources if the inferred data sources are unset or invalid
        if is_none_or_empty(data_sources):
            if len(agent_sources) == 0:
                data_sources = [ConversationCommand.Default]
            else:
                data_sources = [ConversationCommand.General]
    except Exception as e:
        logger.error(f"Invalid response for determining relevant tools: {response}. Error: {e}", exc_info=True)
        # Fall back to the agent's configured tools, or global defaults.
        data_sources = agent_sources if len(agent_sources) > 0 else [ConversationCommand.Default]
        output_mode = agent_outputs[0] if len(agent_outputs) > 0 else ConversationCommand.Text

    return {"sources": data_sources, "output": output_mode}
|
|
|
|
|
|
async def infer_webpage_urls(
    q: str,
    max_webpages: int,
    conversation_history: dict,
    location_data: LocationData,
    user: KhojUser,
    query_images: List[str] = None,
    agent: Agent = None,
    query_files: str = None,
    tracer: dict = {},
) -> List[str]:
    """
    Infer webpage links from the given query

    Returns up to max_webpages valid, deduplicated URLs.

    Raises:
        ValueError: when the model response contains no valid URLs.
    """
    location = f"{location_data}" if location_data else "Unknown"
    username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
    chat_history = construct_chat_history(conversation_history)

    utc_date = datetime.now(timezone.utc).strftime("%Y-%m-%d")
    personality_context = (
        prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
    )

    online_queries_prompt = prompts.infer_webpages_to_read.format(
        query=q,
        max_webpages=max_webpages,
        chat_history=chat_history,
        current_date=utc_date,
        location=location,
        username=username,
        personality_context=personality_context,
    )

    agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None

    # Response schema the model is asked to conform to.
    class WebpageUrls(BaseModel):
        links: List[str] = Field(..., min_items=1, max_items=max_webpages)

    with timer("Chat actor: Infer webpage urls to read", logger):
        response = await send_message_to_model_wrapper(
            online_queries_prompt,
            query_images=query_images,
            response_type="json_object",
            response_schema=WebpageUrls,
            user=user,
            query_files=query_files,
            agent_chat_model=agent_chat_model,
            tracer=tracer,
        )

    # Validate that the response is a non-empty, JSON-serializable list of URLs
    try:
        response = clean_json(response)
        urls = json.loads(response)
        valid_unique_urls = {str(url).strip() for url in urls["links"] if is_valid_url(url)}
        # Fix: the old `if len(valid_unique_urls) == 0` branch below this check
        # was unreachable (an empty set already raised here); removed.
        if is_none_or_empty(valid_unique_urls):
            raise ValueError(f"Invalid list of urls: {response}")
        return list(valid_unique_urls)[:max_webpages]
    except Exception:
        raise ValueError(f"Invalid list of urls: {response}")
|
|
|
|
|
|
async def generate_online_subqueries(
    q: str,
    conversation_history: dict,
    location_data: LocationData,
    user: KhojUser,
    query_images: List[str] = None,
    query_files: str = None,
    max_queries: int = 3,
    agent: Agent = None,
    tracer: dict = {},
) -> Set[str]:
    """
    Generate subqueries from the given query

    Returns a set of up to max_queries search subqueries. Falls back to the
    original query when the model response is empty or unparsable.
    """
    location = f"{location_data}" if location_data else "Unknown"
    username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
    chat_history = construct_chat_history(conversation_history)

    utc_date = datetime.now(timezone.utc).strftime("%Y-%m-%d")
    personality_context = (
        prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
    )

    online_queries_prompt = prompts.online_search_conversation_subqueries.format(
        query=q,
        chat_history=chat_history,
        max_queries=max_queries,
        current_date=utc_date,
        location=location,
        username=username,
        personality_context=personality_context,
    )

    agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None

    # Response schema the model is asked to conform to.
    class OnlineQueries(BaseModel):
        queries: List[str] = Field(..., min_items=1, max_items=max_queries)

    with timer("Chat actor: Generate online search subqueries", logger):
        response = await send_message_to_model_wrapper(
            online_queries_prompt,
            query_images=query_images,
            response_type="json_object",
            response_schema=OnlineQueries,
            user=user,
            query_files=query_files,
            agent_chat_model=agent_chat_model,
            tracer=tracer,
        )

    # Validate that the response is a non-empty, JSON-serializable list
    try:
        response = clean_json(response)
        response = pyjson5.loads(response)
        # Fix: do not shadow the parameter `q` as the comprehension variable.
        # The always-true isinstance(set) and redundant len()==0 checks were dropped.
        subqueries = {sq.strip() for sq in response["queries"] if sq.strip()}
        if not subqueries:
            logger.error(
                f"Invalid response for constructing online subqueries: {subqueries}. Returning original query: {q}"
            )
            return {q}
        return subqueries
    except Exception:
        # Fix: include the traceback; the old handler bound `e` without using it.
        logger.error(
            f"Invalid response for constructing online subqueries: {response}. Returning original query: {q}",
            exc_info=True,
        )
        return {q}
|
|
|
|
|
|
def schedule_query(
    q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None, tracer: dict = {}
) -> Tuple[str, str, str]:
    """
    Schedule the date, time to run the query. Assume the server timezone is UTC.

    Returns a (crontime, query, subject) tuple parsed from the model response.
    Raises AssertionError when the response is not a 3-key JSON object.
    """
    crontime_prompt = prompts.crontime_prompt.format(
        query=q,
        chat_history=construct_chat_history(conversation_history),
    )

    raw_response = send_message_to_model_wrapper_sync(
        crontime_prompt, query_images=query_images, response_type="json_object", user=user, tracer=tracer
    )

    # Expect a JSON object with exactly the crontime, query and subject keys.
    try:
        raw_response = raw_response.strip()
        parsed: Dict[str, str] = json.loads(clean_json(raw_response))
        valid = parsed and isinstance(parsed, Dict) and len(parsed) == 3
        if not valid:
            raise AssertionError(f"Invalid response for scheduling query : {parsed}")
        return parsed.get("crontime"), parsed.get("query"), parsed.get("subject")
    except Exception:
        raise AssertionError(f"Invalid response for scheduling query: {raw_response}")
|
|
|
|
|
|
async def aschedule_query(
    q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None, tracer: dict = {}
) -> Tuple[str, str, str]:
    """
    Schedule the date, time to run the query. Assume the server timezone is UTC.

    Async twin of schedule_query. Returns a (crontime, query, subject) tuple.
    Raises AssertionError when the response is not a 3-key JSON object.
    """
    crontime_prompt = prompts.crontime_prompt.format(
        query=q,
        chat_history=construct_chat_history(conversation_history),
    )

    raw_response = await send_message_to_model_wrapper(
        crontime_prompt, query_images=query_images, response_type="json_object", user=user, tracer=tracer
    )

    # Expect a JSON object with exactly the crontime, query and subject keys.
    try:
        raw_response = raw_response.strip()
        parsed: Dict[str, str] = json.loads(clean_json(raw_response))
        valid = parsed and isinstance(parsed, Dict) and len(parsed) == 3
        if not valid:
            raise AssertionError(f"Invalid response for scheduling query : {parsed}")
        return parsed.get("crontime"), parsed.get("query"), parsed.get("subject")
    except Exception:
        raise AssertionError(f"Invalid response for scheduling query: {raw_response}")
|
|
|
|
|
|
async def extract_relevant_info(
    qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None, tracer: dict = {}
) -> Union[str, None]:
    """
    Extract relevant information for a given query from the target corpus

    Returns the stripped model response, or None when there is no corpus or
    no queries to work with.
    """

    if is_none_or_empty(corpus) or is_none_or_empty(qs):
        return None

    personality_context = (
        prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
    )

    extract_relevant_information = prompts.extract_relevant_information.format(
        query=", ".join(qs),
        corpus=corpus.strip(),
        personality_context=personality_context,
    )

    agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None

    # Time this LLM call for observability, consistent with the sibling
    # extract_relevant_summary chat actor.
    with timer("Chat actor: Extract relevant information from data", logger):
        response = await send_message_to_model_wrapper(
            extract_relevant_information,
            prompts.system_prompt_extract_relevant_information,
            user=user,
            agent_chat_model=agent_chat_model,
            tracer=tracer,
        )
    return response.strip()
|
|
|
|
|
|
async def extract_relevant_summary(
    q: str,
    corpus: str,
    conversation_history: dict,
    query_images: List[str] = None,
    user: KhojUser = None,
    agent: Agent = None,
    tracer: dict = {},
) -> Union[str, None]:
    """
    Extract relevant information for a given query from the target corpus

    Returns None when either the corpus or the query is empty.
    """
    # Nothing to summarize without both a query and a corpus.
    if is_none_or_empty(corpus) or is_none_or_empty(q):
        return None

    personality_context = ""
    if agent and agent.personality:
        personality_context = prompts.personality_context.format(personality=agent.personality)

    extract_relevant_information = prompts.extract_relevant_summary.format(
        query=q,
        chat_history=construct_chat_history(conversation_history),
        corpus=corpus.strip(),
        personality_context=personality_context,
    )

    agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None

    with timer("Chat actor: Extract relevant information from data", logger):
        summary = await send_message_to_model_wrapper(
            extract_relevant_information,
            prompts.system_prompt_extract_relevant_summary,
            user=user,
            query_images=query_images,
            agent_chat_model=agent_chat_model,
            tracer=tracer,
        )
    return summary.strip()
|
|
|
|
|
|
async def generate_summary_from_files(
    q: str,
    user: KhojUser,
    file_filters: List[str],
    meta_log: dict,
    query_images: List[str] = None,
    agent: Agent = None,
    send_status_func: Optional[Callable] = None,
    query_files: str = None,
    tracer: dict = {},
):
    """
    Summarize the agent's indexed files and any attached query files for the given query.

    Yields status events (via send_status_func) followed by the summary text.
    On failure, yields a user-facing error message instead of raising.
    """
    try:
        file_objects = None
        if await EntryAdapters.aagent_has_entries(agent):
            file_names = await EntryAdapters.aget_agent_entry_filepaths(agent)
            if len(file_names) > 0:
                file_objects = await FileObjectAdapters.aget_file_objects_by_name(None, file_names.pop(), agent)

        # Nothing to summarize when there are neither agent files nor attached
        # query files. (The old first disjunct — a truthy list with len 0 — was
        # unreachable and has been removed.)
        if not file_objects and not query_files:
            response_log = "Sorry, I couldn't find anything to summarize."
            yield response_log
            return

        # Fix: file_objects may still be None when only attached query files
        # exist; iterating None raised TypeError before.
        file_objects = file_objects or []
        contextual_data = " ".join([f"File: {file.file_name}\n\n{file.raw_text}" for file in file_objects])

        if query_files:
            contextual_data += f"\n\n{query_files}"

        if not q:
            q = "Create a general summary of the file"

        file_names = [file.file_name for file in file_objects]
        file_names.extend(file_filters)

        all_file_names = ""
        for file_name in file_names:
            all_file_names += f"- {file_name}\n"

        async for result in send_status_func(f"**Constructing Summary Using:**\n{all_file_names}"):
            yield {ChatEvent.STATUS: result}

        response = await extract_relevant_summary(
            q,
            contextual_data,
            conversation_history=meta_log,
            query_images=query_images,
            user=user,
            agent=agent,
            tracer=tracer,
        )

        yield str(response)
    except Exception as e:
        response_log = "Error summarizing file. Please try again, or contact support."
        logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
        # Fix: yield the error message. The old `yield result` referenced the
        # status-loop variable, which is unbound when the failure happens
        # earlier — that raised NameError instead of reporting the error.
        yield response_log
|
|
|
|
|
|
async def generate_excalidraw_diagram(
    q: str,
    conversation_history: Dict[str, Any],
    location_data: LocationData,
    note_references: List[Dict[str, Any]],
    online_results: Optional[dict] = None,
    query_images: List[str] = None,
    user: KhojUser = None,
    agent: Agent = None,
    send_status_func: Optional[Callable] = None,
    query_files: str = None,
    tracer: dict = {},
):
    """Yield status updates, then an (instructions, elements) pair for the requested diagram.

    When diagram generation fails, yields (refined prompt, None) so callers
    can degrade gracefully.
    """
    if send_status_func:
        async for event in send_status_func("**Enhancing the Diagramming Prompt**"):
            yield {ChatEvent.STATUS: event}

    # First refine the user's request into a detailed diagram description.
    refined_prompt = await generate_better_diagram_description(
        q=q,
        conversation_history=conversation_history,
        location_data=location_data,
        note_references=note_references,
        online_results=online_results,
        query_images=query_images,
        user=user,
        agent=agent,
        query_files=query_files,
        tracer=tracer,
    )

    if send_status_func:
        async for event in send_status_func(f"**Diagram to Create:**:\n{refined_prompt}"):
            yield {ChatEvent.STATUS: event}

    # Then render the refined description into an Excalidraw diagram spec.
    try:
        diagram_spec = await generate_excalidraw_diagram_from_description(
            q=refined_prompt,
            user=user,
            agent=agent,
            tracer=tracer,
        )
    except Exception as e:
        logger.error(f"Error generating Excalidraw diagram for {user.email}: {e}", exc_info=True)
        yield refined_prompt, None
        return

    scratchpad = diagram_spec.get("scratchpad")
    inferred_queries = f"Instruction: {refined_prompt}\n\nScratchpad: {scratchpad}"
    yield inferred_queries, diagram_spec.get("elements")
|
|
|
|
|
|
async def generate_better_diagram_description(
    q: str,
    conversation_history: Dict[str, Any],
    location_data: LocationData,
    note_references: List[Dict[str, Any]],
    online_results: Optional[dict] = None,
    query_images: List[str] = None,
    user: KhojUser = None,
    agent: Agent = None,
    query_files: str = None,
    tracer: dict = {},
) -> str:
    """
    Generate a diagram description from the given query and context
    """
    current_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d, %A")
    personality_context = ""
    if agent and agent.personality:
        personality_context = prompts.personality_context.format(personality=agent.personality)

    user_location = f"{location_data}" if location_data else "Unknown"
    user_references = "\n\n".join([f"# {item['compiled']}" for item in note_references])
    chat_history = construct_chat_history(conversation_history)

    # Keep only the most useful part of each online result: the answer box
    # when present, otherwise the webpages.
    simplified_online_results = {}
    for key in online_results or {}:
        search_result = online_results[key]
        if search_result.get("answerBox"):
            simplified_online_results[key] = search_result["answerBox"]
        elif search_result.get("webpages"):
            simplified_online_results[key] = search_result["webpages"]

    improve_diagram_description_prompt = prompts.improve_excalidraw_diagram_description_prompt.format(
        query=q,
        chat_history=chat_history,
        location=user_location,
        current_date=current_date,
        references=user_references,
        online_results=simplified_online_results,
        personality_context=personality_context,
    )

    agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None

    with timer("Chat actor: Generate better diagram description", logger):
        response = await send_message_to_model_wrapper(
            improve_diagram_description_prompt,
            query_images=query_images,
            user=user,
            query_files=query_files,
            agent_chat_model=agent_chat_model,
            tracer=tracer,
        )
        response = response.strip()
        # Strip a single pair of wrapping quotes, if the model added them.
        if response.startswith(('"', "'")) and response.endswith(('"', "'")):
            response = response[1:-1]

    return response
|
|
|
|
|
|
async def generate_excalidraw_diagram_from_description(
    q: str,
    user: KhojUser = None,
    agent: Agent = None,
    tracer: dict = {},
) -> Dict[str, Any]:
    """
    Convert a diagram description into an Excalidraw diagram spec.

    Returns the parsed model response: a dict expected to contain `scratchpad`
    (the model's working notes) and `elements` (a non-empty list of element
    dicts). Raises AssertionError when the response cannot be parsed or lacks
    the expected structure.
    """
    personality_context = (
        prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
    )

    excalidraw_diagram_generation = prompts.excalidraw_diagram_generation_prompt.format(
        personality_context=personality_context,
        query=q,
    )

    agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None

    with timer("Chat actor: Generate excalidraw diagram", logger):
        raw_response = await send_message_to_model_wrapper(
            query=excalidraw_diagram_generation, user=user, agent_chat_model=agent_chat_model, tracer=tracer
        )
        raw_response = clean_json(raw_response)
        try:
            # Expect response to have `elements` and `scratchpad` keys
            response: Dict[str, str] = json.loads(raw_response)
            if (
                not response
                or not isinstance(response, Dict)
                or not response.get("elements")
                or not response.get("scratchpad")
            ):
                raise AssertionError(f"Invalid response for generating Excalidraw diagram: {response}")
        except Exception:
            # Re-raise with the raw (uncleaned-of-JSON-errors) text for debugging.
            raise AssertionError(f"Invalid response for generating Excalidraw diagram: {raw_response}")
        # `elements` must be a list of dicts; indexing [0] is safe because the
        # truthiness check above guarantees it is non-empty.
        if not response or not isinstance(response["elements"], List) or not isinstance(response["elements"][0], Dict):
            # TODO Some additional validation here that it's a valid Excalidraw diagram
            raise AssertionError(f"Invalid response for improving diagram description: {response}")

    return response
|
|
|
|
|
|
async def generate_mermaidjs_diagram(
    q: str,
    conversation_history: Dict[str, Any],
    location_data: LocationData,
    note_references: List[Dict[str, Any]],
    online_results: Optional[dict] = None,
    query_images: List[str] = None,
    user: KhojUser = None,
    agent: Agent = None,
    send_status_func: Optional[Callable] = None,
    query_files: str = None,
    tracer: dict = {},
):
    """
    Two-step Mermaid.js diagram generation.

    First refines the user's query plus context into a detailed diagram
    description, then renders that description into Mermaid.js syntax.
    Yields status events (when send_status_func is given) followed by a final
    (inferred_queries, diagram) tuple.
    """
    if send_status_func:
        async for event in send_status_func("**Enhancing the Diagramming Prompt**"):
            yield {ChatEvent.STATUS: event}

    better_diagram_description_prompt = await generate_better_mermaidjs_diagram_description(
        q=q,
        conversation_history=conversation_history,
        location_data=location_data,
        note_references=note_references,
        online_results=online_results,
        query_images=query_images,
        user=user,
        agent=agent,
        query_files=query_files,
        tracer=tracer,
    )

    if send_status_func:
        # Status message had a stray double colon ("...Create:**:") before
        async for event in send_status_func(f"**Diagram to Create**:\n{better_diagram_description_prompt}"):
            yield {ChatEvent.STATUS: event}

    mermaidjs_diagram_description = await generate_mermaidjs_diagram_from_description(
        q=better_diagram_description_prompt,
        user=user,
        agent=agent,
        tracer=tracer,
    )

    inferred_queries = f"Instruction: {better_diagram_description_prompt}"

    yield inferred_queries, mermaidjs_diagram_description
|
|
|
|
|
|
async def generate_better_mermaidjs_diagram_description(
    q: str,
    conversation_history: Dict[str, Any],
    location_data: LocationData,
    note_references: List[Dict[str, Any]],
    online_results: Optional[dict] = None,
    query_images: List[str] = None,
    user: KhojUser = None,
    agent: Agent = None,
    query_files: str = None,
    tracer: dict = {},
) -> str:
    """
    Generate a diagram description from the given query and context
    """
    current_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d, %A")

    personality_context = ""
    if agent and agent.personality:
        personality_context = prompts.personality_context.format(personality=agent.personality)

    location = f"{location_data}" if location_data else "Unknown"
    user_references = "\n\n".join(f"# {note['compiled']}" for note in note_references)
    chat_history = construct_chat_history(conversation_history)

    # Keep only the most informative slice of each online result
    simplified_online_results = {}
    for name, result in (online_results or {}).items():
        if result.get("answerBox"):
            simplified_online_results[name] = result["answerBox"]
        elif result.get("webpages"):
            simplified_online_results[name] = result["webpages"]

    improve_diagram_description_prompt = prompts.improve_mermaid_js_diagram_description_prompt.format(
        query=q,
        chat_history=chat_history,
        location=location,
        current_date=current_date,
        references=user_references,
        online_results=simplified_online_results,
        personality_context=personality_context,
    )

    agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None

    with timer("Chat actor: Generate better Mermaid.js diagram description", logger):
        response = await send_message_to_model_wrapper(
            improve_diagram_description_prompt,
            query_images=query_images,
            user=user,
            query_files=query_files,
            agent_chat_model=agent_chat_model,
            tracer=tracer,
        )
        response = response.strip()
        # Drop a single wrapping quote pair if the model quoted its answer
        if response.startswith(('"', "'")) and response.endswith(('"', "'")):
            response = response[1:-1]

    return response
|
|
|
|
|
|
async def generate_mermaidjs_diagram_from_description(
    q: str,
    user: KhojUser = None,
    agent: Agent = None,
    tracer: dict = {},
) -> str:
    """Render a Mermaid.js diagram definition for the given diagram description."""
    if agent and agent.personality:
        personality_context = prompts.personality_context.format(personality=agent.personality)
    else:
        personality_context = ""

    diagram_generation_prompt = prompts.mermaid_js_diagram_generation_prompt.format(
        personality_context=personality_context,
        query=q,
    )

    agent_chat_model = None
    if agent:
        agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user)

    with timer("Chat actor: Generate Mermaid.js diagram", logger):
        raw_response = await send_message_to_model_wrapper(
            query=diagram_generation_prompt, user=user, agent_chat_model=agent_chat_model, tracer=tracer
        )
        return clean_mermaidjs(raw_response.strip())
|
|
|
|
|
|
async def generate_better_image_prompt(
    q: str,
    conversation_history: str,
    location_data: LocationData,
    note_references: List[Dict[str, Any]],
    online_results: Optional[dict] = None,
    model_type: Optional[str] = None,
    query_images: Optional[List[str]] = None,
    user: KhojUser = None,
    agent: Agent = None,
    query_files: str = "",
    tracer: dict = {},
) -> str:
    """
    Generate a better image prompt from the given query.

    Uses the DALL-E style improvement prompt for OpenAI models and the Stable
    Diffusion style prompt for every other model type, so an unrecognized
    model type can no longer leave the prompt variable unbound (previously
    this raised UnboundLocalError for types outside the enumerated list).
    """

    today_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d, %A")
    personality_context = (
        prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
    )
    model_type = model_type or TextToImageModelConfig.ModelType.OPENAI

    location = f"{location_data}" if location_data else "Unknown"

    user_references = "\n\n".join([f"# {item['compiled']}" for item in note_references])

    # Keep only the most informative slice of each online result
    simplified_online_results = {}
    if online_results:
        for result in online_results:
            if online_results[result].get("answerBox"):
                simplified_online_results[result] = online_results[result]["answerBox"]
            elif online_results[result].get("webpages"):
                simplified_online_results[result] = online_results[result]["webpages"]

    if model_type == TextToImageModelConfig.ModelType.OPENAI:
        improve_prompt_template = prompts.image_generation_improve_prompt_dalle
    else:
        # STABILITYAI, REPLICATE, GOOGLE and any future model types share the SD-style prompt
        improve_prompt_template = prompts.image_generation_improve_prompt_sd
    image_prompt = improve_prompt_template.format(
        query=q,
        chat_history=conversation_history,
        location=location,
        current_date=today_date,
        references=user_references,
        online_results=simplified_online_results,
        personality_context=personality_context,
    )

    agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None

    with timer("Chat actor: Generate contextual image prompt", logger):
        response = await send_message_to_model_wrapper(
            image_prompt,
            query_images=query_images,
            user=user,
            query_files=query_files,
            agent_chat_model=agent_chat_model,
            tracer=tracer,
        )
        response = response.strip()
        # Strip a single wrapping quote pair the model sometimes adds
        if response.startswith(('"', "'")) and response.endswith(('"', "'")):
            response = response[1:-1]

    return response
|
|
|
|
|
|
async def send_message_to_model_wrapper(
    query: str,
    system_message: str = "",
    response_type: str = "text",
    response_schema: BaseModel = None,
    deepthought: bool = False,
    user: KhojUser = None,
    query_images: List[str] = None,
    context: str = "",
    query_files: str = None,
    conversation_log: dict = {},
    agent_chat_model: ChatModel = None,
    tracer: dict = {},
):
    """
    Send a one-off message to the user's (or agent's) configured chat model.

    Resolves the chat model (falling back to a vision-capable model when
    images are attached), truncates the prompt to the model's token budget,
    and dispatches to the provider-specific send function.
    Raises HTTPException(500) for an unrecognized model type.
    """
    chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user, agent_chat_model)
    vision_available = chat_model.vision_enabled
    if not vision_available and query_images:
        logger.warning(f"Vision is not enabled for default model: {chat_model.name}.")
        vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
        if vision_enabled_config:
            chat_model = vision_enabled_config
            vision_available = True
    if vision_available and query_images:
        logger.info(f"Using {chat_model.name} model to understand {len(query_images)} images.")

    # Subscribed users may get a larger prompt budget
    subscribed = await ais_user_subscribed(user) if user else False
    max_tokens = (
        chat_model.subscribed_max_prompt_size
        if subscribed and chat_model.subscribed_max_prompt_size
        else chat_model.max_prompt_size
    )
    chat_model_name = chat_model.name
    tokenizer = chat_model.tokenizer
    model_type = chat_model.model_type
    # Guard against chat models (e.g. offline/local) without an associated API config
    api_key = chat_model.ai_model_api.api_key if chat_model.ai_model_api else None
    api_base_url = chat_model.ai_model_api.api_base_url if chat_model.ai_model_api else None
    loaded_model = None

    if model_type == ChatModel.ModelType.OFFLINE:
        # Lazily load the local model once and reuse it across requests
        if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
            state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
        loaded_model = state.offline_chat_processor_config.loaded_model

    truncated_messages = generate_chatml_messages_with_context(
        user_message=query,
        context_message=context,
        system_message=system_message,
        conversation_log=conversation_log,
        model_name=chat_model_name,
        loaded_model=loaded_model,
        tokenizer_name=tokenizer,
        max_prompt_size=max_tokens,
        vision_enabled=vision_available,
        query_images=query_images,
        model_type=model_type,
        query_files=query_files,
    )

    if model_type == ChatModel.ModelType.OFFLINE:
        return send_message_to_model_offline(
            messages=truncated_messages,
            loaded_model=loaded_model,
            model_name=chat_model_name,
            max_prompt_size=max_tokens,
            streaming=False,
            response_type=response_type,
            tracer=tracer,
        )

    elif model_type == ChatModel.ModelType.OPENAI:
        return send_message_to_model(
            messages=truncated_messages,
            api_key=api_key,
            model=chat_model_name,
            response_type=response_type,
            response_schema=response_schema,
            deepthought=deepthought,
            api_base_url=api_base_url,
            tracer=tracer,
        )
    elif model_type == ChatModel.ModelType.ANTHROPIC:
        return anthropic_send_message_to_model(
            messages=truncated_messages,
            api_key=api_key,
            model=chat_model_name,
            response_type=response_type,
            deepthought=deepthought,
            api_base_url=api_base_url,
            tracer=tracer,
        )
    elif model_type == ChatModel.ModelType.GOOGLE:
        return gemini_send_message_to_model(
            messages=truncated_messages,
            api_key=api_key,
            model=chat_model_name,
            response_type=response_type,
            response_schema=response_schema,
            deepthought=deepthought,
            api_base_url=api_base_url,
            tracer=tracer,
        )
    else:
        raise HTTPException(status_code=500, detail="Invalid conversation config")
|
|
|
|
|
|
def send_message_to_model_wrapper_sync(
    message: str,
    system_message: str = "",
    response_type: str = "text",
    response_schema: BaseModel = None,
    user: KhojUser = None,
    query_images: List[str] = None,
    query_files: str = "",
    conversation_log: dict = {},
    tracer: dict = {},
):
    """
    Synchronous variant of send_message_to_model_wrapper.

    Sends a one-off message to the user's default chat model and returns the
    model response. Raises HTTPException(500) when no default chat model is
    configured or the model type is unrecognized.
    """
    chat_model: ChatModel = ConversationAdapters.get_default_chat_model(user)

    if chat_model is None:
        raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")

    # Subscribed users may get a larger prompt budget
    subscribed = is_user_subscribed(user) if user else False
    max_tokens = (
        chat_model.subscribed_max_prompt_size
        if subscribed and chat_model.subscribed_max_prompt_size
        else chat_model.max_prompt_size
    )
    chat_model_name = chat_model.name
    model_type = chat_model.model_type
    vision_available = chat_model.vision_enabled
    # Guard against chat models (e.g. offline/local) without an associated API config
    api_key = chat_model.ai_model_api.api_key if chat_model.ai_model_api else None
    api_base_url = chat_model.ai_model_api.api_base_url if chat_model.ai_model_api else None
    loaded_model = None

    if model_type == ChatModel.ModelType.OFFLINE:
        # Lazily load the local model once and reuse it across requests
        if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
            state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
        loaded_model = state.offline_chat_processor_config.loaded_model

    truncated_messages = generate_chatml_messages_with_context(
        user_message=message,
        system_message=system_message,
        conversation_log=conversation_log,
        model_name=chat_model_name,
        loaded_model=loaded_model,
        max_prompt_size=max_tokens,
        vision_enabled=vision_available,
        model_type=model_type,
        query_images=query_images,
        query_files=query_files,
    )

    if model_type == ChatModel.ModelType.OFFLINE:
        return send_message_to_model_offline(
            messages=truncated_messages,
            loaded_model=loaded_model,
            model_name=chat_model_name,
            max_prompt_size=max_tokens,
            streaming=False,
            response_type=response_type,
            tracer=tracer,
        )

    elif model_type == ChatModel.ModelType.OPENAI:
        return send_message_to_model(
            messages=truncated_messages,
            api_key=api_key,
            api_base_url=api_base_url,
            model=chat_model_name,
            response_type=response_type,
            response_schema=response_schema,
            tracer=tracer,
        )

    elif model_type == ChatModel.ModelType.ANTHROPIC:
        return anthropic_send_message_to_model(
            messages=truncated_messages,
            api_key=api_key,
            api_base_url=api_base_url,
            model=chat_model_name,
            response_type=response_type,
            tracer=tracer,
        )

    elif model_type == ChatModel.ModelType.GOOGLE:
        return gemini_send_message_to_model(
            messages=truncated_messages,
            api_key=api_key,
            api_base_url=api_base_url,
            model=chat_model_name,
            response_type=response_type,
            response_schema=response_schema,
            tracer=tracer,
        )
    else:
        raise HTTPException(status_code=500, detail="Invalid conversation config")
|
|
|
|
|
|
async def agenerate_chat_response(
    q: str,
    meta_log: dict,
    conversation: Conversation,
    compiled_references: List[Dict] = [],
    online_results: Dict[str, Dict] = {},
    code_results: Dict[str, Dict] = {},
    operator_results: Dict[str, str] = {},
    research_results: List[InformationCollectionIteration] = [],
    inferred_queries: List[str] = [],
    conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
    user: KhojUser = None,
    client_application: ClientApplication = None,
    location_data: LocationData = None,
    user_name: Optional[str] = None,
    query_images: Optional[List[str]] = None,
    train_of_thought: List[Any] = [],
    query_files: str = None,
    raw_query_files: List[FileAttachment] = None,
    generated_images: List[str] = None,
    raw_generated_files: List[FileAttachment] = [],
    generated_mermaidjs_diagram: str = None,
    program_execution_context: List[str] = [],
    generated_asset_results: Dict[str, Dict] = {},
    is_subscribed: bool = False,
    tracer: dict = {},
) -> Tuple[AsyncGenerator[str | ResponseWithThought, None], Dict[str, str]]:
    """
    Set up the streaming chat response for a user query.

    Resolves a valid chat model for the user/conversation (falling back to a
    vision-capable model when images are attached), dispatches to the matching
    provider-specific converse_* function, and returns the (unstarted)
    response generator together with metadata naming the chosen chat model.
    Each converse_* call is given a `partial_completion` callback that saves
    the completed exchange to the conversation log.

    Raises HTTPException(500) if setting up the response generator fails.
    """
    # Initialize Variables
    chat_response_generator: AsyncGenerator[str | ResponseWithThought, None] = None
    logger.debug(f"Conversation Types: {conversation_commands}")

    metadata = {}
    agent = await AgentAdapters.aget_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None

    try:
        # Callback the converse_* functions invoke to persist the exchange
        partial_completion = partial(
            save_to_conversation_log,
            q,
            user=user,
            meta_log=meta_log,
            compiled_references=compiled_references,
            online_results=online_results,
            code_results=code_results,
            operator_results=operator_results,
            inferred_queries=inferred_queries,
            client_application=client_application,
            conversation_id=str(conversation.id),
            query_images=query_images,
            train_of_thought=train_of_thought,
            raw_query_files=raw_query_files,
            generated_images=generated_images,
            raw_generated_files=raw_generated_files,
            generated_mermaidjs_diagram=generated_mermaidjs_diagram,
            tracer=tracer,
        )

        query_to_run = q
        deepthought = False
        if research_results:
            # When research was collected, inline it into the query and drop
            # the separate context sources — they are superseded by the research
            compiled_research = "".join([r.summarizedResult for r in research_results if r.summarizedResult])
            if compiled_research:
                query_to_run = f"<query>{q}</query>\n<collected_research>\n{compiled_research}\n</collected_research>"
                compiled_references = []
                online_results = {}
                code_results = {}
                operator_results = {}
                deepthought = True

        chat_model = await ConversationAdapters.aget_valid_chat_model(user, conversation, is_subscribed)
        vision_available = chat_model.vision_enabled
        if not vision_available and query_images:
            # Fall back to a vision-capable model when images are attached
            vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
            if vision_enabled_config:
                chat_model = vision_enabled_config
                vision_available = True

        # Dispatch to the provider-specific converse function
        if chat_model.model_type == "offline":
            loaded_model = state.offline_chat_processor_config.loaded_model
            chat_response_generator = converse_offline(
                user_query=query_to_run,
                references=compiled_references,
                online_results=online_results,
                loaded_model=loaded_model,
                conversation_log=meta_log,
                completion_func=partial_completion,
                conversation_commands=conversation_commands,
                model_name=chat_model.name,
                max_prompt_size=chat_model.max_prompt_size,
                tokenizer_name=chat_model.tokenizer,
                location_data=location_data,
                user_name=user_name,
                agent=agent,
                query_files=query_files,
                generated_files=raw_generated_files,
                generated_asset_results=generated_asset_results,
                tracer=tracer,
            )

        elif chat_model.model_type == ChatModel.ModelType.OPENAI:
            openai_chat_config = chat_model.ai_model_api
            api_key = openai_chat_config.api_key
            chat_model_name = chat_model.name
            chat_response_generator = converse_openai(
                query_to_run,
                compiled_references,
                query_images=query_images,
                online_results=online_results,
                code_results=code_results,
                operator_results=operator_results,
                conversation_log=meta_log,
                model=chat_model_name,
                api_key=api_key,
                api_base_url=openai_chat_config.api_base_url,
                completion_func=partial_completion,
                conversation_commands=conversation_commands,
                max_prompt_size=chat_model.max_prompt_size,
                tokenizer_name=chat_model.tokenizer,
                location_data=location_data,
                user_name=user_name,
                agent=agent,
                vision_available=vision_available,
                query_files=query_files,
                generated_files=raw_generated_files,
                generated_asset_results=generated_asset_results,
                program_execution_context=program_execution_context,
                deepthought=deepthought,
                tracer=tracer,
            )

        elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
            api_key = chat_model.ai_model_api.api_key
            api_base_url = chat_model.ai_model_api.api_base_url
            chat_response_generator = converse_anthropic(
                query_to_run,
                compiled_references,
                query_images=query_images,
                online_results=online_results,
                code_results=code_results,
                operator_results=operator_results,
                conversation_log=meta_log,
                model=chat_model.name,
                api_key=api_key,
                api_base_url=api_base_url,
                completion_func=partial_completion,
                conversation_commands=conversation_commands,
                max_prompt_size=chat_model.max_prompt_size,
                tokenizer_name=chat_model.tokenizer,
                location_data=location_data,
                user_name=user_name,
                agent=agent,
                vision_available=vision_available,
                query_files=query_files,
                generated_files=raw_generated_files,
                generated_asset_results=generated_asset_results,
                program_execution_context=program_execution_context,
                deepthought=deepthought,
                tracer=tracer,
            )
        elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
            api_key = chat_model.ai_model_api.api_key
            api_base_url = chat_model.ai_model_api.api_base_url
            chat_response_generator = converse_gemini(
                query_to_run,
                compiled_references,
                online_results=online_results,
                code_results=code_results,
                operator_results=operator_results,
                conversation_log=meta_log,
                model=chat_model.name,
                api_key=api_key,
                api_base_url=api_base_url,
                completion_func=partial_completion,
                conversation_commands=conversation_commands,
                max_prompt_size=chat_model.max_prompt_size,
                tokenizer_name=chat_model.tokenizer,
                location_data=location_data,
                user_name=user_name,
                agent=agent,
                query_images=query_images,
                vision_available=vision_available,
                query_files=query_files,
                generated_files=raw_generated_files,
                generated_asset_results=generated_asset_results,
                program_execution_context=program_execution_context,
                deepthought=deepthought,
                tracer=tracer,
            )

        # Record which chat model produced the response
        metadata.update({"chat_model": chat_model.name})

    except Exception as e:
        logger.error(e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))

    # Return the generator directly
    return chat_response_generator, metadata
|
|
|
|
|
|
class DeleteMessageRequestBody(BaseModel):
    """Request body identifying a single message turn to delete from a conversation."""

    # Conversation containing the message turn
    conversation_id: str
    # The specific message turn to delete
    turn_id: str
|
|
|
|
|
|
class FeedbackData(BaseModel):
    """User feedback payload on a chat exchange."""

    # User's query text
    uquery: str
    # Presumably Khoj's response text — confirm with the submitting client
    kquery: str
    # Feedback sentiment (e.g. positive/negative) as sent by the client
    sentiment: str
|
|
|
|
|
|
class MagicLinkForm(BaseModel):
    """Form payload for requesting a sign-in magic link; email is validated by pydantic."""

    email: EmailStr
|
|
|
|
|
|
class EmailAttemptRateLimiter:
    """Rate limiter for email attempts BEFORE get/create user with valid email address."""

    def __init__(self, requests: int, window: int, slug: str):
        self.requests = requests
        self.window = window  # Window in seconds
        self.slug = slug

    async def __call__(self, form: MagicLinkForm):
        # Rate limiting is not enforced when running in debug mode
        if in_debug_mode():
            return

        # Only attempts within the sliding window count against the limit
        window_start = django_timezone.now() - timedelta(seconds=self.window)
        recent_attempts = await RateLimitRecord.objects.filter(
            identifier=form.email, slug=self.slug, created_at__gte=window_start
        ).acount()

        # Reject once the email has used up its allowed attempts
        if recent_attempts >= self.requests:
            logger.warning(f"Email attempt rate limit exceeded for {form.email} (slug: {self.slug})")
            raise HTTPException(
                status_code=429, detail="Too many requests for your email address. Please wait before trying again."
            )

        # Record this attempt against the email's quota
        await RateLimitRecord.objects.acreate(identifier=form.email, slug=self.slug)
|
|
|
|
|
|
class EmailVerificationApiRateLimiter:
    """Rate limiter for actions AFTER user with valid email address is known to exist"""

    def __init__(self, requests: int, window: int, slug: str):
        self.requests = requests
        self.window = window  # Window in seconds
        self.slug = slug

    async def __call__(self, email: str = None):
        # Rate limiting is not enforced when running in debug mode
        if in_debug_mode():
            return

        user: KhojUser = await aget_user_by_email(email)
        if not user:
            raise HTTPException(status_code=404, detail="User not found.")

        # Only this user's requests inside the sliding window count
        window_start = django_timezone.now() - timedelta(seconds=self.window)
        count_requests = await UserRequests.objects.filter(
            user=user, created_at__gte=window_start, slug=self.slug
        ).acount()

        # Reject once the user has exhausted their allowed attempts
        if count_requests >= self.requests:
            logger.warning(
                f"Rate limit: {count_requests}/{self.requests} requests not allowed in {self.window} seconds for email: {email}."
            )
            raise HTTPException(status_code=429, detail="Ran out of login attempts. Please wait before trying again.")

        # Record this attempt against the user's quota
        await UserRequests.objects.acreate(user=user, slug=self.slug)
|
|
|
|
|
|
class ApiUserRateLimiter:
    """Per-user sliding-window rate limiter for API requests.

    Applies a lower limit for free users and a higher one for subscribed
    (premium scope) users.
    """

    def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str):
        self.requests = requests  # Limit for free users
        self.subscribed_requests = subscribed_requests  # Limit for subscribed users
        self.window = window  # Window in seconds
        self.slug = slug  # Identifies which endpoint/action this limiter guards

    def __call__(self, request: Request):
        # Rate limiting disabled if billing is disabled
        if state.billing_enabled is False:
            return

        # Rate limiting is disabled if user unauthenticated.
        # Other systems handle authentication
        if not request.user.is_authenticated:
            return

        user: KhojUser = request.user.object
        subscribed = has_required_scope(request, ["premium"])

        # Remove requests outside of the time window
        cutoff = django_timezone.now() - timedelta(seconds=self.window)
        count_requests = UserRequests.objects.filter(user=user, created_at__gte=cutoff, slug=self.slug).count()

        # Check if the user has exceeded the rate limit
        if subscribed and count_requests >= self.subscribed_requests:
            logger.info(
                f"Rate limit: {count_requests}/{self.subscribed_requests} requests not allowed in {self.window} seconds for subscribed user: {user}."
            )
            raise HTTPException(
                status_code=429,
                detail="I'm glad you're enjoying interacting with me! You've unfortunately exceeded your usage limit for today. But let's chat more tomorrow?",
            )
        if not subscribed and count_requests >= self.requests:
            if self.requests >= self.subscribed_requests:
                # Subscribing would not raise the limit, so skip the upsell message.
                # Log the free-tier limit actually exceeded (was a copy-paste of the subscribed limit).
                logger.info(
                    f"Rate limit: {count_requests}/{self.requests} requests not allowed in {self.window} seconds for user: {user}."
                )
                raise HTTPException(
                    status_code=429,
                    detail="I'm glad you're enjoying interacting with me! You've unfortunately exceeded your usage limit for today. But let's chat more tomorrow?",
                )

            logger.info(
                f"Rate limit: {count_requests}/{self.requests} requests not allowed in {self.window} seconds for user: {user}."
            )
            raise HTTPException(
                status_code=429,
                detail="I'm glad you're enjoying interacting with me! You've unfortunately exceeded your usage limit for today. You can subscribe to increase your usage limit via [your settings](https://app.khoj.dev/settings) or we can continue our conversation tomorrow?",
            )

        # Add the current request to the cache
        UserRequests.objects.create(user=user, slug=self.slug)
|
|
|
|
|
|
class ApiImageRateLimiter:
    """Limits the number and combined decoded size of images attached to a chat message."""

    def __init__(self, max_images: int = 10, max_combined_size_mb: float = 10):
        self.max_images = max_images
        self.max_combined_size_mb = max_combined_size_mb

    def __call__(self, request: Request, body: ChatRequestBody):
        if state.billing_enabled is False:
            return

        # Rate limiting is disabled if user unauthenticated.
        # Other systems handle authentication
        if not request.user.is_authenticated:
            return

        if not body.images:
            return

        # Check number of images
        if len(body.images) > self.max_images:
            logger.info(f"Rate limit: {len(body.images)}/{self.max_images} images not allowed per message.")
            raise HTTPException(
                status_code=429,
                detail=f"Those are way too many images for me! I can handle up to {self.max_images} images per message.",
            )

        # Check total size of images
        total_size_mb = 0.0
        for image in body.images:
            # Unquote the image in case it's URL encoded
            image = unquote(image)
            # Assuming the image is a base64 encoded string
            # Remove the data:image/jpeg;base64, part if present
            if "," in image:
                image = image.split(",", 1)[1]

            # Decode base64 to get the actual size. Reject malformed payloads
            # with a client error instead of surfacing an unhandled 500.
            try:
                image_bytes = base64.b64decode(image)
            except (ValueError, TypeError):
                raise HTTPException(status_code=400, detail="One of the attached images is not valid base64 data.")
            total_size_mb += len(image_bytes) / (1024 * 1024)  # Convert bytes to MB

        if total_size_mb > self.max_combined_size_mb:
            logger.info(f"Data limit: {total_size_mb}MB/{self.max_combined_size_mb}MB size not allowed per message.")
            raise HTTPException(
                status_code=429,
                detail=f"Those images are way too large for me! I can handle up to {self.max_combined_size_mb}MB of images per message.",
            )
|
|
|
|
|
|
class ConversationCommandRateLimiter:
    """Daily per-user rate limiter for restricted conversation commands (e.g. /research)."""

    def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str):
        self.slug = slug
        self.trial_rate_limit = trial_rate_limit
        self.subscribed_rate_limit = subscribed_rate_limit
        self.restricted_commands = [ConversationCommand.Research]

    async def update_and_check_if_valid(self, request: Request, conversation_command: ConversationCommand):
        # Skip limiting when billing is off
        if state.billing_enabled is False:
            return

        # Skip limiting for unauthenticated users; auth is handled elsewhere
        if not request.user.is_authenticated:
            return

        # Only the restricted commands are rate limited
        if conversation_command not in self.restricted_commands:
            return

        user: KhojUser = request.user.object
        subscribed = has_required_scope(request, ["premium"])

        # Count this command's usage over the trailing 24 hours
        day_ago = django_timezone.now() - timedelta(seconds=60 * 60 * 24)
        command_slug = f"{self.slug}_{conversation_command.value}"
        usage_count = await UserRequests.objects.filter(
            user=user, created_at__gte=day_ago, slug=command_slug
        ).acount()

        if subscribed and usage_count >= self.subscribed_rate_limit:
            logger.info(
                f"Rate limit: {usage_count}/{self.subscribed_rate_limit} requests not allowed in 24 hours for subscribed user: {user}."
            )
            raise HTTPException(
                status_code=429,
                detail=f"I'm glad you're enjoying interacting with me! You've unfortunately exceeded your `/{conversation_command.value}` command usage limit for today. Maybe we can talk about something else for today?",
            )
        if not subscribed and usage_count >= self.trial_rate_limit:
            logger.info(
                f"Rate limit: {usage_count}/{self.trial_rate_limit} requests not allowed in 24 hours for user: {user}."
            )
            raise HTTPException(
                status_code=429,
                detail=f"I'm glad you're enjoying interacting with me! You've unfortunately exceeded your `/{conversation_command.value}` command usage limit for today. You can subscribe to increase your usage limit via [your settings](https://app.khoj.dev/settings) or we can talk about something else for today?",
            )

        # Record this command invocation against the user's daily quota
        await UserRequests.objects.acreate(user=user, slug=command_slug)
        return
|
|
|
|
|
|
class ApiIndexedDataLimiter:
    """Enforces per-upload and total indexed-data size limits for content indexing.

    Files uploaded with size 0 are treated as deletion requests: their
    previously indexed entries are removed before the limits are checked.
    """

    def __init__(
        self,
        incoming_entries_size_limit: float,
        subscribed_incoming_entries_size_limit: float,
        total_entries_size_limit: float,
        subscribed_total_entries_size_limit: float,
    ):
        self.num_entries_size = incoming_entries_size_limit
        self.subscribed_num_entries_size = subscribed_incoming_entries_size_limit
        self.total_entries_size_limit = total_entries_size_limit
        self.subscribed_total_entries_size = subscribed_total_entries_size_limit

    def __call__(self, request: Request, files: List[UploadFile] = None):
        if state.billing_enabled is False:
            return

        subscribed = has_required_scope(request, ["premium"])
        incoming_data_size_mb = 0.0
        deletion_file_names = set()

        if not request.user.is_authenticated or not files:
            return

        user: KhojUser = request.user.object

        for file in files:
            if file.size == 0:
                # Empty files signal that their previously indexed entries should be deleted
                deletion_file_names.add(file.filename)

            incoming_data_size_mb += file.size / 1024 / 1024

        num_deleted_entries = 0
        for file_path in deletion_file_names:
            deleted_count = EntryAdapters.delete_entry_by_file(user, file_path)
            num_deleted_entries += deleted_count

        logger.info(f"Deleted {num_deleted_entries} entries for user: {user}.")

        # Check the size of this upload against the per-request limits
        if subscribed and incoming_data_size_mb >= self.subscribed_num_entries_size:
            logger.info(
                f"Data limit: {incoming_data_size_mb}MB incoming will exceed {self.subscribed_num_entries_size}MB allowed for subscribed user: {user}."
            )
            raise HTTPException(status_code=429, detail="Too much data indexed.")
        if not subscribed and incoming_data_size_mb >= self.num_entries_size:
            logger.info(
                f"Data limit: {incoming_data_size_mb}MB incoming will exceed {self.num_entries_size}MB allowed for user: {user}."
            )
            raise HTTPException(
                status_code=429, detail="Too much data indexed. Subscribe to increase your data index limit."
            )

        # Check the upload plus existing indexed data against the total limits
        user_size_data = EntryAdapters.get_size_of_indexed_data_in_mb(user)
        if subscribed and user_size_data + incoming_data_size_mb >= self.subscribed_total_entries_size:
            logger.info(
                f"Data limit: {incoming_data_size_mb}MB incoming + {user_size_data}MB existing will exceed {self.subscribed_total_entries_size}MB allowed for subscribed user: {user}."
            )
            raise HTTPException(status_code=429, detail="Too much data indexed.")
        if not subscribed and user_size_data + incoming_data_size_mb >= self.total_entries_size_limit:
            # Log the free-tier total limit that applies here (previously logged the subscribed limit)
            logger.info(
                f"Data limit: {incoming_data_size_mb}MB incoming + {user_size_data}MB existing will exceed {self.total_entries_size_limit}MB allowed for non subscribed user: {user}."
            )
            raise HTTPException(
                status_code=429, detail="Too much data indexed. Subscribe to increase your data index limit."
            )
|
|
|
|
|
|
class CommonQueryParamsClass:
    """Bundle of client-identification query/header params shared by API endpoints."""

    def __init__(
        self,
        client: Optional[str] = None,
        user_agent: Optional[str] = Header(None),
        referer: Optional[str] = Header(None),
        host: Optional[str] = Header(None),
    ):
        # Capture the caller-supplied client id and standard request headers verbatim
        self.client, self.user_agent = client, user_agent
        self.referer, self.host = referer, host


CommonQueryParams = Annotated[CommonQueryParamsClass, Depends()]
|
|
|
|
|
|
def format_automation_response(scheduling_request: str, executed_query: str, ai_response: str, user: KhojUser) -> str:
    """
    Format the AI response to send in automation email to user.

    Args:
        scheduling_request: The user's original scheduling request.
        executed_query: The query that was actually run for the automation.
        ai_response: The raw AI response to reformat.
        user: The user the automation belongs to.

    Returns:
        The model-formatted response text.
        (Fixed: annotation previously claimed ``-> bool`` though the model's
        string response is returned and used as the email body by callers.)
    """
    name = get_user_name(user)
    username = prompts.user_name.format(name=name) if name else ""

    automation_format_prompt = prompts.automation_format_prompt.format(
        original_query=scheduling_request,
        executed_query=executed_query,
        response=ai_response,
        username=username,
    )

    with timer("Chat actor: Format automation response", logger):
        return send_message_to_model_wrapper_sync(automation_format_prompt, user=user)
|
|
|
|
|
|
def should_notify(original_query: str, executed_query: str, ai_response: str, user: KhojUser) -> bool:
    """
    Decide whether to notify the user of the AI response.
    Default to notifying the user for now.
    """
    # Nothing worth notifying about if any piece of the exchange is missing
    if any(is_none_or_empty(part) for part in [original_query, executed_query, ai_response]):
        return False

    to_notify_or_not = prompts.to_notify_or_not.format(
        original_query=original_query,
        executed_query=executed_query,
        response=ai_response,
    )

    with timer("Chat actor: Decide to notify user of automation response", logger):
        try:
            # TODO Replace with async call so we don't have to maintain a sync version
            raw_response = send_message_to_model_wrapper_sync(to_notify_or_not, user=user, response_type="json_object")
            decision = json.loads(clean_json(raw_response))
            notify = decision["decision"] == "Yes"
            reason = decision.get("reason", "unknown")
            logger.info(
                f'Decided to {"not " if not notify else ""}notify user of automation response because of reason: {reason}.'
            )
            return notify
        except Exception as e:
            # On any model/parsing failure, err on the side of notifying the user
            logger.warning(
                f"Fallback to notify user of automation response as failed to infer should notify or not. {e}",
                exc_info=True,
            )
            return True
|
|
|
|
|
|
def scheduled_chat(
    query_to_run: str,
    scheduling_request: str,
    subject: str,
    user: KhojUser,
    calling_url: URL,
    job_id: Optional[str] = None,
    conversation_id: Optional[str] = None,
):
    """
    Execute a scheduled automation by POSTing its query to this server's chat API.

    Skips execution if the job already ran within the last 6 hours, deletes the
    automation when its target conversation no longer exists, and emails the
    formatted AI response to the user (or returns it when email is disabled).
    """
    logger.info(f"Processing scheduled_chat: {query_to_run}")
    if job_id:
        # Get the job object and check whether the time is valid for it to run.
        # This helps avoid race conditions that cause the same job to be run multiple times.
        job = AutomationAdapters.get_automation(user, job_id)
        last_run_time = AutomationAdapters.get_job_last_run(user, job)

        # Convert last_run_time from %Y-%m-%d %I:%M %p %Z to datetime object
        if last_run_time:
            last_run_time = datetime.strptime(last_run_time, "%Y-%m-%d %I:%M %p %Z").replace(tzinfo=timezone.utc)

            # If the last run time was within the last 6 hours, don't run it again.
            # This helps avoid multithreading issues and rate limits.
            if (datetime.now(timezone.utc) - last_run_time).total_seconds() < 6 * 60 * 60:
                logger.info(f"Skipping scheduled chat {job_id} as it was already run within the last 6 hours.")
                return

    # Extract relevant params from the original URL
    scheme = "http" if not calling_url.is_secure else "https"
    query_dict = parse_qs(calling_url.query)

    # Pop the stream value from query_dict if it exists
    query_dict.pop("stream", None)

    # Replace the original scheduling query with the scheduled query
    query_dict["q"] = [query_to_run]

    # Replace the original conversation_id with the conversation_id
    if conversation_id:
        # encode the conversation_id to avoid any issues with special characters
        query_dict["conversation_id"] = [quote(str(conversation_id))]

        # validate that the conversation id exists. If not, delete the automation and exit.
        if not ConversationAdapters.get_conversation_by_id(conversation_id):
            AutomationAdapters.delete_automation(user, job_id)
            return

    # Restructure the original query_dict into a valid JSON payload for the chat API
    json_payload = {key: values[0] for key, values in query_dict.items()}

    # Construct the URL to call the chat API with the scheduled query string
    url = f"{scheme}://{calling_url.netloc}/api/chat?client=khoj"

    # Construct the Headers for the chat API
    headers = {"User-Agent": "Khoj", "Content-Type": "application/json"}
    if not state.anonymous_mode:
        # Add authorization request header in non-anonymous mode
        token = get_khoj_tokens(user)
        if is_none_or_empty(token):
            token = create_khoj_token(user).token
        else:
            token = token[0].token
        headers["Authorization"] = f"Bearer {token}"

    # Call the chat API endpoint with authenticated user token and query
    raw_response = requests.post(url, headers=headers, json=json_payload, allow_redirects=False)

    # Handle redirect manually if necessary
    if raw_response.status_code in [301, 302, 308]:
        redirect_url = raw_response.headers["Location"]
        logger.info(f"Redirecting to {redirect_url}")
        raw_response = requests.post(redirect_url, headers=headers, json=json_payload)

    # Stop if the chat API call was not successful
    if raw_response.status_code != 200:
        logger.error(f"Failed to run schedule chat: {raw_response.text}, user: {user}, query: {query_to_run}")
        return None

    # Extract the AI response from the chat API response
    cleaned_query = re.sub(r"^/automated_task\s*", "", query_to_run).strip()
    is_image = False
    if raw_response.headers.get("Content-Type") == "application/json":
        response_map = raw_response.json()
        ai_response = response_map.get("response") or response_map.get("image")
        is_image = False
        # isinstance instead of exact type comparison to also accept dict subclasses
        if isinstance(ai_response, dict):
            is_image = ai_response.get("image") is not None
    else:
        ai_response = raw_response.text

    # Notify user if the AI response is satisfactory
    if should_notify(
        original_query=scheduling_request, executed_query=cleaned_query, ai_response=ai_response, user=user
    ):
        formatted_response = format_automation_response(scheduling_request, cleaned_query, ai_response, user)

        if is_resend_enabled():
            send_task_email(user.get_short_name(), user.email, cleaned_query, formatted_response, subject, is_image)
        else:
            return formatted_response
|
|
|
|
|
|
async def create_automation(
    q: str,
    timezone: str,
    user: KhojUser,
    calling_url: URL,
    meta_log: Optional[dict] = None,
    conversation_id: Optional[str] = None,
    tracer: Optional[dict] = None,
):
    """
    Infer a schedule from the user's natural language request and register the automation job.

    Args:
        q: The user's scheduling request.
        timezone: IANA timezone name used to schedule the automation.
        user: Owner of the automation.
        calling_url: URL of the originating request, reused to call the chat API on each run.
        meta_log: Optional conversation context passed to the scheduling query inference.
        conversation_id: Optional conversation to attach automation runs to.
        tracer: Optional telemetry dict threaded through model calls.

    Returns:
        Tuple of (scheduler job, crontime string, inferred query, email subject).
    """
    # Use None sentinels: mutable dict defaults would be shared across calls
    meta_log = {} if meta_log is None else meta_log
    tracer = {} if tracer is None else tracer
    crontime, query_to_run, subject = await aschedule_query(q, meta_log, user, tracer=tracer)
    job = await aschedule_automation(query_to_run, subject, crontime, timezone, q, user, calling_url, conversation_id)
    return job, crontime, query_to_run, subject
|
|
|
|
|
|
def schedule_automation(
    query_to_run: str,
    subject: str,
    crontime: str,
    timezone: str,
    scheduling_request: str,
    user: KhojUser,
    calling_url: URL,
    conversation_id: str,
):
    """Register a recurring scheduler job that runs the given query on the crontab schedule and return it."""
    # Disable minute level automation recurrence: when the minute field isn't a
    # fixed number, pin it to a random minute to distribute request load
    cron_fields = crontime.split(" ")
    if not cron_fields[0].isdigit():
        cron_fields[0] = str(math.floor(random() * 60))
        crontime = " ".join(cron_fields)

    # Resolve the user's timezone, falling back to UTC for unknown zone names
    try:
        user_timezone = pytz.timezone(timezone)
    except pytz.UnknownTimeZoneError:
        logger.warning(f"Invalid timezone: {timezone}. Fallback to use UTC to schedule automation.")
        user_timezone = pytz.utc

    trigger = CronTrigger.from_crontab(crontime, user_timezone)
    trigger.jitter = 60

    # Id and metadata used by the task scheduler and process locks for the task runs
    query_id = hashlib.md5(f"{query_to_run}_{crontime}".encode("utf-8")).hexdigest()
    job_id = f"automation_{user.uuid}_{query_id}"
    job_metadata = json.dumps(
        {
            "query_to_run": query_to_run,
            "scheduling_request": scheduling_request,
            "subject": subject,
            "crontime": crontime,
            "conversation_id": str(conversation_id),
        }
    )

    return state.scheduler.add_job(
        run_with_process_lock,
        trigger=trigger,
        args=(
            scheduled_chat,
            f"{ProcessLock.Operation.SCHEDULED_JOB}_{user.uuid}_{query_id}",
        ),
        kwargs={
            "query_to_run": query_to_run,
            "scheduling_request": scheduling_request,
            "subject": subject,
            "user": user,
            "calling_url": calling_url,
            "job_id": job_id,
            "conversation_id": conversation_id,
        },
        id=job_id,
        name=job_metadata,
        max_instances=2,  # Allow second instance to kill any previous instance with stale lock
    )
|
|
|
|
|
|
async def aschedule_automation(
    query_to_run: str,
    subject: str,
    crontime: str,
    timezone: str,
    scheduling_request: str,
    user: KhojUser,
    calling_url: URL,
    conversation_id: str,
):
    """
    Async variant of schedule_automation: register a recurring scheduler job
    that runs the given query on the crontab schedule and return the job.
    """
    # Disable minute level automation recurrence
    minute_value = crontime.split(" ")[0]
    if not minute_value.isdigit():
        # Run automation at some random minute (to distribute request load) instead of running every X minutes
        crontime = " ".join([str(math.floor(random() * 60))] + crontime.split(" ")[1:])

    # Convert timezone string to timezone object. Fall back to UTC on unknown
    # zones for consistency with the sync schedule_automation path, instead of raising.
    try:
        user_timezone = pytz.timezone(timezone)
    except pytz.UnknownTimeZoneError:
        logger.warning(f"Invalid timezone: {timezone}. Fallback to use UTC to schedule automation.")
        user_timezone = pytz.utc

    trigger = CronTrigger.from_crontab(crontime, user_timezone)
    trigger.jitter = 60
    # Generate id and metadata used by task scheduler and process locks for the task runs
    job_metadata = json.dumps(
        {
            "query_to_run": query_to_run,
            "scheduling_request": scheduling_request,
            "subject": subject,
            "crontime": crontime,
            "conversation_id": str(conversation_id),
        }
    )
    query_id = hashlib.md5(f"{query_to_run}_{crontime}".encode("utf-8")).hexdigest()
    job_id = f"automation_{user.uuid}_{query_id}"
    job = await sync_to_async(state.scheduler.add_job)(
        run_with_process_lock,
        trigger=trigger,
        args=(
            scheduled_chat,
            f"{ProcessLock.Operation.SCHEDULED_JOB}_{user.uuid}_{query_id}",
        ),
        kwargs={
            "query_to_run": query_to_run,
            "scheduling_request": scheduling_request,
            "subject": subject,
            "user": user,
            "calling_url": calling_url,
            "job_id": job_id,
            "conversation_id": conversation_id,
        },
        id=job_id,
        name=job_metadata,
        max_instances=2,  # Allow second instance to kill any previous instance with stale lock
    )
    return job
|
|
|
|
|
|
def construct_automation_created_message(automation: Job, crontime: str, query_to_run: str, subject: str):
    """
    Build the markdown confirmation message shown to the user after an automation is created.
    (Removed the unused `automation_icon_url` local — it was never referenced.)
    """
    # Display next run time in user timezone instead of UTC
    schedule = f'{cron_descriptor.get_description(crontime)} {automation.next_run_time.strftime("%Z")}'
    next_run_time = automation.next_run_time.strftime("%Y-%m-%d %I:%M %p %Z")
    # Remove /automated_task prefix from inferred_query
    unprefixed_query_to_run = re.sub(r"^\/automated_task\s*", "", query_to_run)
    # Create the automation response
    return f"""
###  Created Automation
- Subject: **{subject}**
- Query to Run: "{unprefixed_query_to_run}"
- Schedule: `{schedule}`
- Next Run At: {next_run_time}

Manage your automations [here](/automations).
""".strip()
|
|
|
|
|
|
class MessageProcessor:
|
|
def __init__(self):
|
|
self.references = {}
|
|
self.usage = {}
|
|
self.raw_response = ""
|
|
self.generated_images = []
|
|
self.generated_files = []
|
|
self.generated_mermaidjs_diagram = []
|
|
|
|
def convert_message_chunk_to_json(self, raw_chunk: str) -> Dict[str, Any]:
|
|
if raw_chunk.startswith("{") and raw_chunk.endswith("}"):
|
|
try:
|
|
json_chunk = json.loads(raw_chunk)
|
|
if "type" not in json_chunk:
|
|
json_chunk = {"type": "message", "data": json_chunk}
|
|
return json_chunk
|
|
except json.JSONDecodeError:
|
|
return {"type": "message", "data": raw_chunk}
|
|
elif raw_chunk:
|
|
return {"type": "message", "data": raw_chunk}
|
|
return {"type": "", "data": ""}
|
|
|
|
def process_message_chunk(self, raw_chunk: str) -> None:
|
|
chunk = self.convert_message_chunk_to_json(raw_chunk)
|
|
if not chunk or not chunk["type"]:
|
|
return
|
|
|
|
chunk_type = ChatEvent(chunk["type"])
|
|
if chunk_type == ChatEvent.REFERENCES:
|
|
self.references = chunk["data"]
|
|
elif chunk_type == ChatEvent.USAGE:
|
|
self.usage = chunk["data"]
|
|
elif chunk_type == ChatEvent.MESSAGE:
|
|
chunk_data = chunk["data"]
|
|
if isinstance(chunk_data, dict):
|
|
self.raw_response = self.handle_json_response(chunk_data)
|
|
elif (
|
|
isinstance(chunk_data, str) and chunk_data.strip().startswith("{") and chunk_data.strip().endswith("}")
|
|
):
|
|
try:
|
|
json_data = json.loads(chunk_data.strip())
|
|
self.raw_response = self.handle_json_response(json_data)
|
|
except json.JSONDecodeError:
|
|
self.raw_response += chunk_data
|
|
else:
|
|
self.raw_response += chunk_data
|
|
elif chunk_type == ChatEvent.GENERATED_ASSETS:
|
|
chunk_data = chunk["data"]
|
|
if isinstance(chunk_data, dict):
|
|
for key in chunk_data:
|
|
if key == "images":
|
|
self.generated_images = chunk_data[key]
|
|
elif key == "files":
|
|
self.generated_files = chunk_data[key]
|
|
elif key == "mermaidjsDiagram":
|
|
self.generated_mermaidjs_diagram = chunk_data[key]
|
|
|
|
def handle_json_response(self, json_data: Dict[str, str]) -> str | Dict[str, str]:
|
|
if "image" in json_data or "details" in json_data:
|
|
return json_data
|
|
if "response" in json_data:
|
|
return json_data["response"]
|
|
return json_data
|
|
|
|
|
|
async def read_chat_stream(response_iterator: AsyncGenerator[str, None]) -> Dict[str, Any]:
    """Consume a delimiter-separated chat event stream and collapse it into one response dict."""
    processor = MessageProcessor()
    event_delimiter = "␃🔚␗"
    pending = ""

    async for fragment in response_iterator:
        # Buffer fragments until at least one complete event has arrived
        pending += fragment
        while event_delimiter in pending:
            event, pending = pending.split(event_delimiter, 1)
            if event:
                processor.process_message_chunk(event)

    # Whatever is left once the stream closes is treated as a final event
    if pending:
        processor.process_message_chunk(pending)

    return {
        "response": processor.raw_response,
        "references": processor.references,
        "usage": processor.usage,
        "images": processor.generated_images,
        "files": processor.generated_files,
        "mermaidjsDiagram": processor.generated_mermaidjs_diagram,
    }
|
|
|
|
|
|
def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False):
    """Assemble the user's configuration payload; a minimal dict unless is_detailed is set."""
    photo = request.session.get("user", {}).get("picture")
    premium_active = has_required_scope(request, ["premium"])
    user_has_docs = EntryAdapters.user_has_entries(user=user)

    if not is_detailed:
        return {
            "request": request,
            "username": user.username if user else None,
            "user_photo": photo,
            "is_active": premium_active,
            "has_documents": user_has_docs,
            "khoj_version": state.khoj_version,
        }

    subscription_state = get_user_subscription_state(user.email)
    subscription = adapters.get_user_subscription(user.email)

    renewal_date = None
    if subscription and subscription.renewal_date:
        renewal_date = subscription.renewal_date.strftime("%d %b %Y")
    trial_enabled_at = None
    if subscription and subscription.enabled_trial_at:
        trial_enabled_at = subscription.enabled_trial_at.strftime("%d %b %Y")
    given_name = get_user_name(user)

    # Which content sources the user has indexed data from
    sources = set(EntryAdapters.get_unique_file_sources(user))
    enabled_content_sources = {name: name in sources for name in ("computer", "github", "notion")}

    notion_oauth_url = get_notion_auth_url(user)
    notion_config = get_user_notion_config(user)
    notion_token = notion_config.token if notion_config else ""

    # Chat, paint and voice model options available to the user
    chat_model_selected = ConversationAdapters.get_chat_model(user) or ConversationAdapters.get_default_chat_model(
        user
    )
    chat_model_options = [
        {
            "name": model.name,
            "id": model.id,
            "strengths": model.strengths,
            "description": model.description,
            "tier": model.price_tier,
        }
        for model in ConversationAdapters.get_conversation_processor_options().all()
    ]

    paint_model_selected = ConversationAdapters.get_user_text_to_image_model_config(user)
    all_paint_model_options = [
        {"name": model.model_name, "id": model.id, "tier": model.price_tier}
        for model in ConversationAdapters.get_text_to_image_model_options().all()
    ]

    voice_model_options = [
        {"name": model.name, "id": model.model_id, "tier": model.price_tier}
        for model in ConversationAdapters.get_voice_model_options()
    ]
    # Voice features only make sense when voice models are configured
    eleven_labs_enabled = bool(voice_model_options) and is_eleven_labs_enabled()

    voice_model_selected = ConversationAdapters.get_voice_model_config(user)

    return {
        "request": request,
        # user info
        "username": user.username if user else None,
        "user_photo": photo,
        "is_active": premium_active,
        "given_name": given_name,
        "phone_number": str(user.phone_number) if user.phone_number else "",
        "is_phone_number_verified": user.verified_phone_number,
        # user content settings
        "enabled_content_source": enabled_content_sources,
        "has_documents": user_has_docs,
        "notion_token": notion_token,
        # user model settings
        "chat_model_options": chat_model_options,
        "selected_chat_model_config": chat_model_selected.id if chat_model_selected else None,
        "paint_model_options": all_paint_model_options,
        "selected_paint_model_config": paint_model_selected.id if paint_model_selected else None,
        "voice_model_options": voice_model_options,
        "selected_voice_model_config": voice_model_selected.model_id if voice_model_selected else None,
        # user billing info
        "subscription_state": subscription_state,
        "subscription_renewal_date": renewal_date,
        "subscription_enabled_trial_at": trial_enabled_at,
        # server settings
        "khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"),
        "billing_enabled": state.billing_enabled,
        "is_eleven_labs_enabled": eleven_labs_enabled,
        "is_twilio_enabled": is_twilio_enabled(),
        "khoj_version": state.khoj_version,
        "anonymous_mode": state.anonymous_mode,
        "notion_oauth_url": notion_oauth_url,
        "length_of_free_trial": LENGTH_OF_FREE_TRIAL,
    }
|
|
|
|
|
|
def configure_content(
    user: KhojUser,
    files: Optional[dict[str, dict[str, str]]],
    regenerate: bool = False,
    t: Optional[state.SearchType] = state.SearchType.All,
) -> bool:
    """
    Extract entries and generate embeddings for the user's files per content type.

    Args:
        user: Owner of the content being indexed.
        files: Mapping of content type (org, markdown, pdf, ...) to {path: content}.
        regenerate: Force re-indexing from scratch when True.
        t: Search type to restrict indexing to; defaults to all types.

    Returns:
        True if all requested content types indexed without error (also True when
        files is None — nothing to do); False on any indexer failure or invalid type.
    """
    success = True
    if t is None:
        t = state.SearchType.All

    if t is not None and t in [type.value for type in state.SearchType]:
        t = state.SearchType(t)

    if t is not None and t.value not in [type.value for type in state.SearchType]:
        logger.warning(f"🚨 Invalid search type: {t}")
        return False

    search_type = t.value if t else None

    # Guard against missing files before dereferencing them.
    # (Previously files.get was called before this None check, raising AttributeError.)
    if files is None:
        logger.warning(f"🚨 No files to process for {search_type} search.")
        return True

    no_documents = all([not files.get(file_type) for file_type in files])

    try:
        # Initialize Org Notes Search
        if (search_type == state.SearchType.All.value or search_type == state.SearchType.Org.value) and files.get(
            "org"
        ):
            logger.info("🦄 Setting up search for orgmode notes")
            # Extract Entries, Generate Notes Embeddings
            text_search.setup(
                OrgToEntries,
                files.get("org"),
                regenerate=regenerate,
                user=user,
            )
    except Exception as e:
        logger.error(f"🚨 Failed to setup org: {e}", exc_info=True)
        success = False

    try:
        # Initialize Markdown Search
        if (search_type == state.SearchType.All.value or search_type == state.SearchType.Markdown.value) and files.get(
            "markdown"
        ):
            logger.info("💎 Setting up search for markdown notes")
            # Extract Entries, Generate Markdown Embeddings
            text_search.setup(
                MarkdownToEntries,
                files.get("markdown"),
                regenerate=regenerate,
                user=user,
            )

    except Exception as e:
        logger.error(f"🚨 Failed to setup markdown: {e}", exc_info=True)
        success = False

    try:
        # Initialize PDF Search
        if (search_type == state.SearchType.All.value or search_type == state.SearchType.Pdf.value) and files.get(
            "pdf"
        ):
            logger.info("🖨️ Setting up search for pdf")
            # Extract Entries, Generate PDF Embeddings
            text_search.setup(
                PdfToEntries,
                files.get("pdf"),
                regenerate=regenerate,
                user=user,
            )

    except Exception as e:
        logger.error(f"🚨 Failed to setup PDF: {e}", exc_info=True)
        success = False

    try:
        # Initialize Plaintext Search
        if (search_type == state.SearchType.All.value or search_type == state.SearchType.Plaintext.value) and files.get(
            "plaintext"
        ):
            logger.info("📄 Setting up search for plaintext")
            # Extract Entries, Generate Plaintext Embeddings
            text_search.setup(
                PlaintextToEntries,
                files.get("plaintext"),
                regenerate=regenerate,
                user=user,
            )

    except Exception as e:
        logger.error(f"🚨 Failed to setup plaintext: {e}", exc_info=True)
        success = False

    try:
        # Only sync remote (Github) content when no local documents were pushed
        if no_documents:
            github_config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
            if (
                search_type == state.SearchType.All.value or search_type == state.SearchType.Github.value
            ) and github_config is not None:
                logger.info("🐙 Setting up search for github")
                # Extract Entries, Generate Github Embeddings
                text_search.setup(
                    GithubToEntries,
                    None,
                    regenerate=regenerate,
                    user=user,
                    config=github_config,
                )

    except Exception as e:
        logger.error(f"🚨 Failed to setup GitHub: {e}", exc_info=True)
        success = False

    try:
        if no_documents:
            # Initialize Notion Search
            notion_config = NotionConfig.objects.filter(user=user).first()
            if (
                search_type == state.SearchType.All.value or search_type == state.SearchType.Notion.value
            ) and notion_config:
                logger.info("🔌 Setting up search for notion")
                text_search.setup(
                    NotionToEntries,
                    None,
                    regenerate=regenerate,
                    user=user,
                    config=notion_config,
                )

    except Exception as e:
        logger.error(f"🚨 Failed to setup Notion: {e}", exc_info=True)
        success = False

    try:
        # Initialize Image Search
        if (search_type == state.SearchType.All.value or search_type == state.SearchType.Image.value) and files.get(
            "image"
        ):
            logger.info("🖼️ Setting up search for images")
            # Extract Entries, Generate Image Embeddings
            text_search.setup(
                ImageToEntries,
                files.get("image"),
                regenerate=regenerate,
                user=user,
            )
    except Exception as e:
        logger.error(f"🚨 Failed to setup images: {e}", exc_info=True)
        success = False
    try:
        # Initialize Docx Search
        if (search_type == state.SearchType.All.value or search_type == state.SearchType.Docx.value) and files.get(
            "docx"
        ):
            logger.info("📄 Setting up search for docx")
            text_search.setup(
                DocxToEntries,
                files.get("docx"),
                regenerate=regenerate,
                user=user,
            )
    except Exception as e:
        logger.error(f"🚨 Failed to setup docx: {e}", exc_info=True)
        success = False

    # Invalidate Query Cache
    if user:
        state.query_cache[user.uuid] = LRU()

    return success
|
|
|
|
|
|
def get_notion_auth_url(user: KhojUser):
    """Return the Notion OAuth authorization URL for this user, or None when OAuth is not configured."""
    oauth_configured = all([NOTION_OAUTH_CLIENT_ID, NOTION_OAUTH_CLIENT_SECRET, NOTION_REDIRECT_URI])
    if not oauth_configured:
        return None
    # The user's uuid rides along as the OAuth state parameter
    return f"https://api.notion.com/v1/oauth/authorize?client_id={NOTION_OAUTH_CLIENT_ID}&redirect_uri={NOTION_REDIRECT_URI}&response_type=code&state={user.uuid}"
|