mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
1368 lines
52 KiB
Python
1368 lines
52 KiB
Python
import asyncio
|
|
import base64
|
|
import json
|
|
import logging
|
|
import time
|
|
import uuid
|
|
from datetime import datetime
|
|
from functools import partial
|
|
from typing import Any, Dict, List, Optional
|
|
from urllib.parse import unquote
|
|
|
|
from asgiref.sync import sync_to_async
|
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
|
from fastapi.responses import RedirectResponse, Response, StreamingResponse
|
|
from starlette.authentication import has_required_scope, requires
|
|
|
|
from khoj.app.settings import ALLOWED_HOSTS
|
|
from khoj.database.adapters import (
|
|
AgentAdapters,
|
|
ConversationAdapters,
|
|
EntryAdapters,
|
|
PublicConversationAdapters,
|
|
aget_user_name,
|
|
)
|
|
from khoj.database.models import Agent, KhojUser
|
|
from khoj.processor.conversation import prompts
|
|
from khoj.processor.conversation.prompts import help_message, no_entries_found
|
|
from khoj.processor.conversation.utils import (
|
|
ResponseWithThought,
|
|
defilter_query,
|
|
save_to_conversation_log,
|
|
)
|
|
from khoj.processor.image.generate import text_to_image
|
|
from khoj.processor.speech.text_to_speech import generate_text_to_speech
|
|
from khoj.processor.tools.online_search import (
|
|
deduplicate_organic_results,
|
|
read_webpages,
|
|
search_online,
|
|
)
|
|
from khoj.processor.tools.run_code import run_code
|
|
from khoj.routers.api import extract_references_and_questions
|
|
from khoj.routers.email import send_query_feedback
|
|
from khoj.routers.helpers import (
|
|
ApiImageRateLimiter,
|
|
ApiUserRateLimiter,
|
|
ChatEvent,
|
|
ChatRequestBody,
|
|
CommonQueryParams,
|
|
ConversationCommandRateLimiter,
|
|
DeleteMessageRequestBody,
|
|
FeedbackData,
|
|
acreate_title_from_history,
|
|
agenerate_chat_response,
|
|
aget_data_sources_and_output_format,
|
|
construct_automation_created_message,
|
|
create_automation,
|
|
gather_raw_query_files,
|
|
generate_mermaidjs_diagram,
|
|
generate_summary_from_files,
|
|
get_conversation_command,
|
|
is_query_empty,
|
|
is_ready_to_chat,
|
|
read_chat_stream,
|
|
update_telemetry_state,
|
|
validate_chat_model,
|
|
)
|
|
from khoj.routers.research import (
|
|
InformationCollectionIteration,
|
|
execute_information_collection,
|
|
)
|
|
from khoj.routers.storage import upload_user_image_to_bucket
|
|
from khoj.utils import state
|
|
from khoj.utils.helpers import (
|
|
ConversationCommand,
|
|
command_descriptions,
|
|
convert_image_to_webp,
|
|
get_country_code_from_timezone,
|
|
get_country_name_from_timezone,
|
|
get_device,
|
|
is_none_or_empty,
|
|
)
|
|
from khoj.utils.rawconfig import (
|
|
ChatRequestBody,
|
|
FileAttachment,
|
|
FileFilterRequest,
|
|
FilesFilterRequest,
|
|
LocationData,
|
|
)
|
|
|
|
# Initialize Router
|
|
logger = logging.getLogger(__name__)
|
|
conversation_command_rate_limiter = ConversationCommandRateLimiter(
|
|
trial_rate_limit=20, subscribed_rate_limit=75, slug="command"
|
|
)
|
|
|
|
|
|
api_chat = APIRouter()
|
|
|
|
|
|
@api_chat.get("/stats", response_class=Response)
|
|
@requires(["authenticated"])
|
|
def chat_stats(request: Request, common: CommonQueryParams) -> Response:
|
|
num_conversations = ConversationAdapters.get_num_conversations(request.user.object)
|
|
return Response(
|
|
content=json.dumps({"num_conversations": num_conversations}), media_type="application/json", status_code=200
|
|
)
|
|
|
|
|
|
@api_chat.get("/export", response_class=Response)
|
|
@requires(["authenticated"])
|
|
def export_conversation(request: Request, common: CommonQueryParams, page: Optional[int] = 1) -> Response:
|
|
all_conversations = ConversationAdapters.get_all_conversations_for_export(request.user.object, page=page)
|
|
return Response(content=json.dumps(all_conversations), media_type="application/json", status_code=200)
|
|
|
|
|
|
@api_chat.get("/conversation/file-filters/{conversation_id}", response_class=Response)
|
|
@requires(["authenticated"])
|
|
def get_file_filter(request: Request, conversation_id: str) -> Response:
|
|
conversation = ConversationAdapters.get_conversation_by_user(request.user.object, conversation_id=conversation_id)
|
|
if not conversation:
|
|
return Response(content=json.dumps({"status": "error", "message": "Conversation not found"}), status_code=404)
|
|
|
|
# get all files from "computer"
|
|
file_list = EntryAdapters.get_all_filenames_by_source(request.user.object, "computer")
|
|
file_filters = []
|
|
for file in conversation.file_filters:
|
|
if file in file_list:
|
|
file_filters.append(file)
|
|
return Response(content=json.dumps(file_filters), media_type="application/json", status_code=200)
|
|
|
|
|
|
@api_chat.delete("/conversation/file-filters/bulk", response_class=Response)
|
|
@requires(["authenticated"])
|
|
def remove_files_filter(request: Request, filter: FilesFilterRequest) -> Response:
|
|
conversation_id = filter.conversation_id
|
|
files_filter = filter.filenames
|
|
file_filters = ConversationAdapters.remove_files_from_filter(request.user.object, conversation_id, files_filter)
|
|
return Response(content=json.dumps(file_filters), media_type="application/json", status_code=200)
|
|
|
|
|
|
@api_chat.post("/conversation/file-filters/bulk", response_class=Response)
|
|
@requires(["authenticated"])
|
|
def add_files_filter(request: Request, filter: FilesFilterRequest):
|
|
try:
|
|
conversation_id = filter.conversation_id
|
|
files_filter = filter.filenames
|
|
file_filters = ConversationAdapters.add_files_to_filter(request.user.object, conversation_id, files_filter)
|
|
return Response(content=json.dumps(file_filters), media_type="application/json", status_code=200)
|
|
except Exception as e:
|
|
logger.error(f"Error adding file filter {filter.filenames}: {e}", exc_info=True)
|
|
raise HTTPException(status_code=422, detail=str(e))
|
|
|
|
|
|
@api_chat.post("/conversation/file-filters", response_class=Response)
|
|
@requires(["authenticated"])
|
|
def add_file_filter(request: Request, filter: FileFilterRequest):
|
|
try:
|
|
conversation_id = filter.conversation_id
|
|
files_filter = [filter.filename]
|
|
file_filters = ConversationAdapters.add_files_to_filter(request.user.object, conversation_id, files_filter)
|
|
return Response(content=json.dumps(file_filters), media_type="application/json", status_code=200)
|
|
except Exception as e:
|
|
logger.error(f"Error adding file filter {filter.filename}: {e}", exc_info=True)
|
|
raise HTTPException(status_code=422, detail=str(e))
|
|
|
|
|
|
@api_chat.delete("/conversation/file-filters", response_class=Response)
|
|
@requires(["authenticated"])
|
|
def remove_file_filter(request: Request, filter: FileFilterRequest) -> Response:
|
|
conversation_id = filter.conversation_id
|
|
files_filter = [filter.filename]
|
|
file_filters = ConversationAdapters.remove_files_from_filter(request.user.object, conversation_id, files_filter)
|
|
return Response(content=json.dumps(file_filters), media_type="application/json", status_code=200)
|
|
|
|
|
|
@api_chat.post("/feedback")
|
|
@requires(["authenticated"])
|
|
async def sendfeedback(request: Request, data: FeedbackData):
|
|
user: KhojUser = request.user.object
|
|
await send_query_feedback(data.uquery, data.kquery, data.sentiment, user.email)
|
|
|
|
|
|
@api_chat.post("/speech")
|
|
@requires(["authenticated"])
|
|
async def text_to_speech(
|
|
request: Request,
|
|
common: CommonQueryParams,
|
|
text: str,
|
|
rate_limiter_per_minute=Depends(
|
|
ApiUserRateLimiter(requests=30, subscribed_requests=30, window=60, slug="chat_minute")
|
|
),
|
|
rate_limiter_per_day=Depends(
|
|
ApiUserRateLimiter(requests=100, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
|
|
),
|
|
) -> Response:
|
|
voice_model = await ConversationAdapters.aget_voice_model_config(request.user.object)
|
|
|
|
params = {"text_to_speak": text}
|
|
|
|
if voice_model:
|
|
params["voice_id"] = voice_model.model_id
|
|
|
|
speech_stream = generate_text_to_speech(**params)
|
|
return StreamingResponse(speech_stream.iter_content(chunk_size=1024), media_type="audio/mpeg")
|
|
|
|
|
|
@api_chat.get("/starters", response_class=Response)
|
|
@requires(["authenticated"])
|
|
async def chat_starters(
|
|
request: Request,
|
|
common: CommonQueryParams,
|
|
) -> Response:
|
|
user: KhojUser = request.user.object
|
|
starter_questions = await ConversationAdapters.aget_conversation_starters(user)
|
|
return Response(content=json.dumps(starter_questions), media_type="application/json", status_code=200)
|
|
|
|
|
|
@api_chat.get("/history")
|
|
@requires(["authenticated"])
|
|
def chat_history(
|
|
request: Request,
|
|
common: CommonQueryParams,
|
|
conversation_id: Optional[str] = None,
|
|
n: Optional[int] = None,
|
|
):
|
|
user = request.user.object
|
|
validate_chat_model(user)
|
|
|
|
# Load Conversation History
|
|
conversation = ConversationAdapters.get_conversation_by_user(
|
|
user=user, client_application=request.user.client_app, conversation_id=conversation_id
|
|
)
|
|
|
|
if conversation is None:
|
|
return Response(
|
|
content=json.dumps({"status": "error", "message": f"Conversation: {conversation_id} not found"}),
|
|
status_code=404,
|
|
)
|
|
|
|
agent_metadata = None
|
|
if conversation.agent:
|
|
if conversation.agent.privacy_level == Agent.PrivacyLevel.PRIVATE and conversation.agent.creator != user:
|
|
conversation.agent = None
|
|
else:
|
|
agent_metadata = {
|
|
"slug": conversation.agent.slug,
|
|
"name": conversation.agent.name,
|
|
"is_creator": conversation.agent.creator == user,
|
|
"color": conversation.agent.style_color,
|
|
"icon": conversation.agent.style_icon,
|
|
"persona": conversation.agent.personality,
|
|
"is_hidden": conversation.agent.is_hidden,
|
|
}
|
|
|
|
meta_log = conversation.conversation_log
|
|
meta_log.update(
|
|
{
|
|
"conversation_id": conversation.id,
|
|
"slug": conversation.title if conversation.title else conversation.slug,
|
|
"agent": agent_metadata,
|
|
"is_owner": conversation.user == user,
|
|
}
|
|
)
|
|
|
|
if n:
|
|
# Get latest N messages if N > 0
|
|
if n > 0 and meta_log.get("chat"):
|
|
meta_log["chat"] = meta_log["chat"][-n:]
|
|
# Else return all messages except latest N
|
|
elif n < 0 and meta_log.get("chat"):
|
|
meta_log["chat"] = meta_log["chat"][:n]
|
|
|
|
update_telemetry_state(
|
|
request=request,
|
|
telemetry_type="api",
|
|
api="chat_history",
|
|
**common.__dict__,
|
|
)
|
|
|
|
return {"status": "ok", "response": meta_log}
|
|
|
|
|
|
@api_chat.get("/share/history")
|
|
def get_shared_chat(
|
|
request: Request,
|
|
common: CommonQueryParams,
|
|
public_conversation_slug: str,
|
|
n: Optional[int] = None,
|
|
):
|
|
user = request.user.object if request.user.is_authenticated else None
|
|
|
|
# Load Conversation History
|
|
conversation = PublicConversationAdapters.get_public_conversation_by_slug(public_conversation_slug)
|
|
|
|
if conversation is None:
|
|
return Response(
|
|
content=json.dumps({"status": "error", "message": f"Conversation: {public_conversation_slug} not found"}),
|
|
status_code=404,
|
|
)
|
|
|
|
agent_metadata = None
|
|
if conversation.agent:
|
|
if conversation.agent.privacy_level == Agent.PrivacyLevel.PRIVATE and conversation.agent.creator != user:
|
|
if conversation.agent.is_hidden:
|
|
default_agent = AgentAdapters.get_default_agent()
|
|
agent_metadata = {
|
|
"slug": default_agent.slug,
|
|
"name": default_agent.name,
|
|
"is_creator": False,
|
|
"color": default_agent.style_color,
|
|
"icon": default_agent.style_icon,
|
|
"persona": default_agent.personality,
|
|
"is_hidden": default_agent.is_hidden,
|
|
}
|
|
else:
|
|
conversation.agent = None
|
|
else:
|
|
agent_metadata = {
|
|
"slug": conversation.agent.slug,
|
|
"name": conversation.agent.name,
|
|
"is_creator": conversation.agent.creator == user,
|
|
"color": conversation.agent.style_color,
|
|
"icon": conversation.agent.style_icon,
|
|
"persona": conversation.agent.personality,
|
|
"is_hidden": conversation.agent.is_hidden,
|
|
}
|
|
|
|
meta_log = conversation.conversation_log
|
|
scrubbed_title = conversation.title if conversation.title else conversation.slug
|
|
|
|
if scrubbed_title:
|
|
scrubbed_title = scrubbed_title.replace("-", " ")
|
|
|
|
meta_log.update(
|
|
{
|
|
"conversation_id": conversation.id,
|
|
"slug": scrubbed_title,
|
|
"agent": agent_metadata,
|
|
"is_owner": conversation.source_owner == user,
|
|
}
|
|
)
|
|
|
|
if n:
|
|
# Get latest N messages if N > 0
|
|
if n > 0 and meta_log.get("chat"):
|
|
meta_log["chat"] = meta_log["chat"][-n:]
|
|
# Else return all messages except latest N
|
|
elif n < 0 and meta_log.get("chat"):
|
|
meta_log["chat"] = meta_log["chat"][:n]
|
|
|
|
update_telemetry_state(
|
|
request=request,
|
|
telemetry_type="api",
|
|
api="get_shared_chat_history",
|
|
**common.__dict__,
|
|
)
|
|
|
|
return {"status": "ok", "response": meta_log}
|
|
|
|
|
|
@api_chat.delete("/history")
|
|
@requires(["authenticated"])
|
|
async def clear_chat_history(
|
|
request: Request,
|
|
common: CommonQueryParams,
|
|
conversation_id: Optional[str] = None,
|
|
):
|
|
user = request.user.object
|
|
|
|
# Clear Conversation History
|
|
await ConversationAdapters.adelete_conversation_by_user(user, request.user.client_app, conversation_id)
|
|
|
|
update_telemetry_state(
|
|
request=request,
|
|
telemetry_type="api",
|
|
api="clear_chat_history",
|
|
**common.__dict__,
|
|
)
|
|
|
|
return {"status": "ok", "message": "Conversation history cleared"}
|
|
|
|
|
|
@api_chat.post("/share/fork")
|
|
@requires(["authenticated"])
|
|
def fork_public_conversation(
|
|
request: Request,
|
|
common: CommonQueryParams,
|
|
public_conversation_slug: str,
|
|
):
|
|
user = request.user.object
|
|
|
|
# Load Conversation History
|
|
public_conversation = PublicConversationAdapters.get_public_conversation_by_slug(public_conversation_slug)
|
|
|
|
# Duplicate Public Conversation to User's Private Conversation
|
|
new_conversation = ConversationAdapters.create_conversation_from_public_conversation(
|
|
user, public_conversation, request.user.client_app
|
|
)
|
|
|
|
chat_metadata = {"forked_conversation": public_conversation.slug}
|
|
|
|
update_telemetry_state(
|
|
request=request,
|
|
telemetry_type="api",
|
|
api="fork_public_conversation",
|
|
**common.__dict__,
|
|
metadata=chat_metadata,
|
|
)
|
|
|
|
redirect_uri = str(request.app.url_path_for("chat_page"))
|
|
|
|
return Response(
|
|
status_code=200,
|
|
content=json.dumps(
|
|
{
|
|
"status": "ok",
|
|
"next_url": redirect_uri,
|
|
"conversation_id": str(new_conversation.id),
|
|
}
|
|
),
|
|
)
|
|
|
|
|
|
@api_chat.post("/share")
|
|
@requires(["authenticated"])
|
|
def duplicate_chat_history_public_conversation(
|
|
request: Request,
|
|
common: CommonQueryParams,
|
|
conversation_id: str,
|
|
):
|
|
user = request.user.object
|
|
domain = request.headers.get("host")
|
|
scheme = request.url.scheme
|
|
|
|
# Throw unauthorized exception if domain not in ALLOWED_HOSTS
|
|
host_domain = domain.split(":")[0]
|
|
if host_domain not in ALLOWED_HOSTS:
|
|
raise HTTPException(status_code=401, detail="Unauthorized domain")
|
|
|
|
# Duplicate Conversation History to Public Conversation
|
|
conversation = ConversationAdapters.get_conversation_by_user(user, request.user.client_app, conversation_id)
|
|
public_conversation = ConversationAdapters.make_public_conversation_copy(conversation)
|
|
public_conversation_url = PublicConversationAdapters.get_public_conversation_url(public_conversation)
|
|
|
|
update_telemetry_state(
|
|
request=request,
|
|
telemetry_type="api",
|
|
api="post_chat_share",
|
|
**common.__dict__,
|
|
)
|
|
|
|
return Response(
|
|
status_code=200, content=json.dumps({"status": "ok", "url": f"{scheme}://{domain}{public_conversation_url}"})
|
|
)
|
|
|
|
|
|
@api_chat.delete("/share")
|
|
@requires(["authenticated"])
|
|
def delete_public_conversation(
|
|
request: Request,
|
|
common: CommonQueryParams,
|
|
public_conversation_slug: str,
|
|
):
|
|
user = request.user.object
|
|
|
|
# Delete Public Conversation
|
|
PublicConversationAdapters.delete_public_conversation_by_slug(user=user, slug=public_conversation_slug)
|
|
|
|
update_telemetry_state(
|
|
request=request,
|
|
telemetry_type="api",
|
|
api="delete_chat_share",
|
|
**common.__dict__,
|
|
)
|
|
|
|
# Redirect to the main chat page
|
|
redirect_uri = str(request.app.url_path_for("chat_page"))
|
|
return RedirectResponse(
|
|
url=redirect_uri,
|
|
status_code=301,
|
|
)
|
|
|
|
|
|
@api_chat.get("/sessions")
|
|
@requires(["authenticated"])
|
|
def chat_sessions(
|
|
request: Request,
|
|
common: CommonQueryParams,
|
|
recent: Optional[bool] = False,
|
|
):
|
|
user = request.user.object
|
|
|
|
# Load Conversation Sessions
|
|
conversations = ConversationAdapters.get_conversation_sessions(user, request.user.client_app)
|
|
if recent:
|
|
conversations = conversations[:8]
|
|
|
|
sessions = conversations.values_list(
|
|
"id",
|
|
"slug",
|
|
"title",
|
|
"agent__slug",
|
|
"agent__name",
|
|
"created_at",
|
|
"updated_at",
|
|
"agent__style_icon",
|
|
"agent__style_color",
|
|
"agent__is_hidden",
|
|
)
|
|
|
|
session_values = [
|
|
{
|
|
"conversation_id": str(session[0]),
|
|
"slug": session[2] or session[1],
|
|
"agent_name": session[4],
|
|
"created": session[5].strftime("%Y-%m-%d %H:%M:%S"),
|
|
"updated": session[6].strftime("%Y-%m-%d %H:%M:%S"),
|
|
"agent_icon": session[7],
|
|
"agent_color": session[8],
|
|
"agent_is_hidden": session[9],
|
|
}
|
|
for session in sessions
|
|
]
|
|
|
|
update_telemetry_state(
|
|
request=request,
|
|
telemetry_type="api",
|
|
api="chat_sessions",
|
|
**common.__dict__,
|
|
)
|
|
|
|
return Response(content=json.dumps(session_values), media_type="application/json", status_code=200)
|
|
|
|
|
|
@api_chat.post("/sessions")
|
|
@requires(["authenticated"])
|
|
async def create_chat_session(
|
|
request: Request,
|
|
common: CommonQueryParams,
|
|
agent_slug: Optional[str] = None,
|
|
# Add parameters here to create a custom hidden agent on the fly
|
|
):
|
|
user = request.user.object
|
|
|
|
# Create new Conversation Session
|
|
conversation = await ConversationAdapters.acreate_conversation_session(user, request.user.client_app, agent_slug)
|
|
|
|
response = {"conversation_id": str(conversation.id)}
|
|
|
|
conversation_metadata = {
|
|
"agent": agent_slug,
|
|
}
|
|
|
|
update_telemetry_state(
|
|
request=request,
|
|
telemetry_type="api",
|
|
api="create_chat_sessions",
|
|
metadata=conversation_metadata,
|
|
**common.__dict__,
|
|
)
|
|
|
|
return Response(content=json.dumps(response), media_type="application/json", status_code=200)
|
|
|
|
|
|
@api_chat.get("/options", response_class=Response)
|
|
async def chat_options(
|
|
request: Request,
|
|
common: CommonQueryParams,
|
|
) -> Response:
|
|
cmd_options = {}
|
|
for cmd in ConversationCommand:
|
|
if cmd in command_descriptions:
|
|
cmd_options[cmd.value] = command_descriptions[cmd]
|
|
|
|
update_telemetry_state(
|
|
request=request,
|
|
telemetry_type="api",
|
|
api="chat_options",
|
|
**common.__dict__,
|
|
)
|
|
return Response(content=json.dumps(cmd_options), media_type="application/json", status_code=200)
|
|
|
|
|
|
@api_chat.patch("/title", response_class=Response)
|
|
@requires(["authenticated"])
|
|
async def set_conversation_title(
|
|
request: Request,
|
|
common: CommonQueryParams,
|
|
title: str,
|
|
conversation_id: Optional[str] = None,
|
|
) -> Response:
|
|
user = request.user.object
|
|
title = title.strip()[:200]
|
|
|
|
# Set Conversation Title
|
|
conversation = await ConversationAdapters.aset_conversation_title(
|
|
user, request.user.client_app, conversation_id, title
|
|
)
|
|
|
|
success = True if conversation else False
|
|
|
|
update_telemetry_state(
|
|
request=request,
|
|
telemetry_type="api",
|
|
api="set_conversation_title",
|
|
**common.__dict__,
|
|
)
|
|
|
|
return Response(
|
|
content=json.dumps({"status": "ok", "success": success}), media_type="application/json", status_code=200
|
|
)
|
|
|
|
|
|
@api_chat.post("/title")
|
|
@requires(["authenticated"])
|
|
async def generate_chat_title(
|
|
request: Request,
|
|
common: CommonQueryParams,
|
|
conversation_id: str,
|
|
):
|
|
user: KhojUser = request.user.object
|
|
conversation = await ConversationAdapters.aget_conversation_by_user(user=user, conversation_id=conversation_id)
|
|
|
|
# Conversation.title is explicitly set by the user. Do not override.
|
|
if conversation.title:
|
|
return {"status": "ok", "title": conversation.title}
|
|
|
|
if not conversation:
|
|
raise HTTPException(status_code=404, detail="Conversation not found")
|
|
|
|
new_title = await acreate_title_from_history(request.user.object, conversation=conversation)
|
|
conversation.slug = new_title[:200]
|
|
|
|
await conversation.asave()
|
|
|
|
return {"status": "ok", "title": new_title}
|
|
|
|
|
|
@api_chat.delete("/conversation/message", response_class=Response)
|
|
@requires(["authenticated"])
|
|
def delete_message(request: Request, delete_request: DeleteMessageRequestBody) -> Response:
|
|
user = request.user.object
|
|
success = ConversationAdapters.delete_message_by_turn_id(
|
|
user, delete_request.conversation_id, delete_request.turn_id
|
|
)
|
|
if success:
|
|
return Response(content=json.dumps({"status": "ok"}), media_type="application/json", status_code=200)
|
|
else:
|
|
return Response(content=json.dumps({"status": "error", "message": "Message not found"}), status_code=404)
|
|
|
|
|
|
@api_chat.post("")
|
|
@requires(["authenticated"])
|
|
async def chat(
|
|
request: Request,
|
|
common: CommonQueryParams,
|
|
body: ChatRequestBody,
|
|
rate_limiter_per_minute=Depends(
|
|
ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute")
|
|
),
|
|
rate_limiter_per_day=Depends(
|
|
ApiUserRateLimiter(requests=100, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
|
|
),
|
|
image_rate_limiter=Depends(ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)),
|
|
):
|
|
# Access the parameters from the body
|
|
q = body.q
|
|
n = body.n
|
|
d = body.d
|
|
stream = body.stream
|
|
title = body.title
|
|
conversation_id = body.conversation_id
|
|
turn_id = str(body.turn_id or uuid.uuid4())
|
|
city = body.city
|
|
region = body.region
|
|
country = body.country or get_country_name_from_timezone(body.timezone)
|
|
country_code = body.country_code or get_country_code_from_timezone(body.timezone)
|
|
timezone = body.timezone
|
|
raw_images = body.images
|
|
raw_query_files = body.files
|
|
|
|
async def event_generator(q: str, images: list[str]):
|
|
start_time = time.perf_counter()
|
|
ttft = None
|
|
chat_metadata: dict = {}
|
|
connection_alive = True
|
|
user: KhojUser = request.user.object
|
|
is_subscribed = has_required_scope(request, ["premium"])
|
|
event_delimiter = "␃🔚␗"
|
|
q = unquote(q)
|
|
train_of_thought = []
|
|
nonlocal conversation_id
|
|
nonlocal raw_query_files
|
|
|
|
tracer: dict = {
|
|
"mid": turn_id,
|
|
"cid": conversation_id,
|
|
"uid": user.id,
|
|
"khoj_version": state.khoj_version,
|
|
}
|
|
|
|
uploaded_images: list[str] = []
|
|
if images:
|
|
for image in images:
|
|
decoded_string = unquote(image)
|
|
base64_data = decoded_string.split(",", 1)[1]
|
|
image_bytes = base64.b64decode(base64_data)
|
|
webp_image_bytes = convert_image_to_webp(image_bytes)
|
|
uploaded_image = upload_user_image_to_bucket(webp_image_bytes, request.user.object.id)
|
|
if not uploaded_image:
|
|
base64_webp_image = base64.b64encode(webp_image_bytes).decode("utf-8")
|
|
uploaded_image = f"data:image/webp;base64,{base64_webp_image}"
|
|
uploaded_images.append(uploaded_image)
|
|
|
|
query_files: Dict[str, str] = {}
|
|
if raw_query_files:
|
|
for file in raw_query_files:
|
|
query_files[file.name] = file.content
|
|
|
|
async def send_event(event_type: ChatEvent, data: str | dict):
|
|
nonlocal connection_alive, ttft, train_of_thought
|
|
if not connection_alive or await request.is_disconnected():
|
|
connection_alive = False
|
|
logger.warning(f"User {user} disconnected from {common.client} client")
|
|
return
|
|
try:
|
|
if event_type == ChatEvent.END_LLM_RESPONSE:
|
|
collect_telemetry()
|
|
elif event_type == ChatEvent.START_LLM_RESPONSE:
|
|
ttft = time.perf_counter() - start_time
|
|
elif event_type == ChatEvent.STATUS:
|
|
train_of_thought.append({"type": event_type.value, "data": data})
|
|
elif event_type == ChatEvent.THOUGHT:
|
|
# Append the data to the last thought as thoughts are streamed
|
|
if (
|
|
len(train_of_thought) > 0
|
|
and train_of_thought[-1]["type"] == ChatEvent.THOUGHT.value
|
|
and type(train_of_thought[-1]["data"]) == type(data) == str
|
|
):
|
|
train_of_thought[-1]["data"] += data
|
|
else:
|
|
train_of_thought.append({"type": event_type.value, "data": data})
|
|
|
|
if event_type == ChatEvent.MESSAGE:
|
|
yield data
|
|
elif event_type == ChatEvent.REFERENCES or ChatEvent.METADATA or stream:
|
|
yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False)
|
|
except asyncio.CancelledError as e:
|
|
connection_alive = False
|
|
logger.warn(f"User {user} disconnected from {common.client} client: {e}")
|
|
return
|
|
except Exception as e:
|
|
connection_alive = False
|
|
logger.error(f"Failed to stream chat API response to {user} on {common.client}: {e}", exc_info=True)
|
|
return
|
|
finally:
|
|
yield event_delimiter
|
|
|
|
async def send_llm_response(response: str, usage: dict = None):
|
|
# Send Chat Response
|
|
async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
|
|
yield result
|
|
async for result in send_event(ChatEvent.MESSAGE, response):
|
|
yield result
|
|
async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
|
|
yield result
|
|
# Send Usage Metadata once llm interactions are complete
|
|
if usage:
|
|
async for event in send_event(ChatEvent.USAGE, usage):
|
|
yield event
|
|
async for result in send_event(ChatEvent.END_RESPONSE, ""):
|
|
yield result
|
|
|
|
def collect_telemetry():
|
|
# Gather chat response telemetry
|
|
nonlocal chat_metadata
|
|
latency = time.perf_counter() - start_time
|
|
cmd_set = set([cmd.value for cmd in conversation_commands])
|
|
cost = (tracer.get("usage", {}) or {}).get("cost", 0)
|
|
chat_metadata = chat_metadata or {}
|
|
chat_metadata["conversation_command"] = cmd_set
|
|
chat_metadata["agent"] = conversation.agent.slug if conversation and conversation.agent else None
|
|
chat_metadata["latency"] = f"{latency:.3f}"
|
|
chat_metadata["ttft_latency"] = f"{ttft:.3f}"
|
|
chat_metadata["cost"] = f"{cost:.5f}"
|
|
|
|
logger.info(f"Chat response time to first token: {ttft:.3f} seconds")
|
|
logger.info(f"Chat response total time: {latency:.3f} seconds")
|
|
logger.info(f"Chat response cost: ${cost:.5f}")
|
|
update_telemetry_state(
|
|
request=request,
|
|
telemetry_type="api",
|
|
api="chat",
|
|
client=common.client,
|
|
user_agent=request.headers.get("user-agent"),
|
|
host=request.headers.get("host"),
|
|
metadata=chat_metadata,
|
|
)
|
|
|
|
if is_query_empty(q):
|
|
async for result in send_llm_response("Please ask your query to get started.", tracer.get("usage")):
|
|
yield result
|
|
return
|
|
|
|
# Automated tasks are handled before to allow mixing them with other conversation commands
|
|
cmds_to_rate_limit = []
|
|
is_automated_task = False
|
|
if q.startswith("/automated_task"):
|
|
is_automated_task = True
|
|
q = q.replace("/automated_task", "").lstrip()
|
|
cmds_to_rate_limit += [ConversationCommand.AutomatedTask]
|
|
|
|
# Extract conversation command from query
|
|
conversation_commands = [get_conversation_command(query=q)]
|
|
|
|
conversation = await ConversationAdapters.aget_conversation_by_user(
|
|
user,
|
|
client_application=request.user.client_app,
|
|
conversation_id=conversation_id,
|
|
title=title,
|
|
create_new=body.create_new,
|
|
)
|
|
if not conversation:
|
|
async for result in send_llm_response(f"Conversation {conversation_id} not found", tracer.get("usage")):
|
|
yield result
|
|
return
|
|
conversation_id = conversation.id
|
|
|
|
async for event in send_event(ChatEvent.METADATA, {"conversationId": str(conversation_id), "turnId": turn_id}):
|
|
yield event
|
|
|
|
agent: Agent | None = None
|
|
default_agent = await AgentAdapters.aget_default_agent()
|
|
if conversation.agent and conversation.agent != default_agent:
|
|
agent = conversation.agent
|
|
|
|
if not conversation.agent:
|
|
conversation.agent = default_agent
|
|
await conversation.asave()
|
|
agent = default_agent
|
|
|
|
await is_ready_to_chat(user)
|
|
user_name = await aget_user_name(user)
|
|
location = None
|
|
if city or region or country or country_code:
|
|
location = LocationData(city=city, region=region, country=country, country_code=country_code)
|
|
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
|
meta_log = conversation.conversation_log
|
|
|
|
researched_results = ""
|
|
online_results: Dict = dict()
|
|
code_results: Dict = dict()
|
|
generated_asset_results: Dict = dict()
|
|
## Extract Document References
|
|
compiled_references: List[Any] = []
|
|
inferred_queries: List[Any] = []
|
|
file_filters = conversation.file_filters if conversation and conversation.file_filters else []
|
|
attached_file_context = gather_raw_query_files(query_files)
|
|
|
|
generated_images: List[str] = []
|
|
generated_files: List[FileAttachment] = []
|
|
generated_mermaidjs_diagram: str = None
|
|
program_execution_context: List[str] = []
|
|
|
|
if conversation_commands == [ConversationCommand.Default]:
|
|
try:
|
|
chosen_io = await aget_data_sources_and_output_format(
|
|
q,
|
|
meta_log,
|
|
is_automated_task,
|
|
user=user,
|
|
query_images=uploaded_images,
|
|
agent=agent,
|
|
query_files=attached_file_context,
|
|
tracer=tracer,
|
|
)
|
|
except ValueError as e:
|
|
logger.error(f"Error getting data sources and output format: {e}. Falling back to default.")
|
|
conversation_commands = [ConversationCommand.General]
|
|
|
|
conversation_commands = chosen_io.get("sources") + [chosen_io.get("output")]
|
|
|
|
# If we're doing research, we don't want to do anything else
|
|
if ConversationCommand.Research in conversation_commands:
|
|
conversation_commands = [ConversationCommand.Research]
|
|
|
|
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
|
|
async for result in send_event(ChatEvent.STATUS, f"**Selected Tools:** {conversation_commands_str}"):
|
|
yield result
|
|
|
|
cmds_to_rate_limit += conversation_commands
|
|
for cmd in cmds_to_rate_limit:
|
|
try:
|
|
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
|
|
q = q.replace(f"/{cmd.value}", "").strip()
|
|
except HTTPException as e:
|
|
async for result in send_llm_response(str(e.detail), tracer.get("usage")):
|
|
yield result
|
|
return
|
|
|
|
defiltered_query = defilter_query(q)
|
|
|
|
if conversation_commands == [ConversationCommand.Research]:
|
|
async for research_result in execute_information_collection(
|
|
user=user,
|
|
query=defiltered_query,
|
|
conversation_id=conversation_id,
|
|
conversation_history=meta_log,
|
|
query_images=uploaded_images,
|
|
agent=agent,
|
|
send_status_func=partial(send_event, ChatEvent.STATUS),
|
|
user_name=user_name,
|
|
location=location,
|
|
file_filters=conversation.file_filters if conversation else [],
|
|
query_files=attached_file_context,
|
|
tracer=tracer,
|
|
):
|
|
if isinstance(research_result, InformationCollectionIteration):
|
|
if research_result.summarizedResult:
|
|
if research_result.onlineContext:
|
|
online_results.update(research_result.onlineContext)
|
|
if research_result.codeContext:
|
|
code_results.update(research_result.codeContext)
|
|
if research_result.context:
|
|
compiled_references.extend(research_result.context)
|
|
|
|
researched_results += research_result.summarizedResult
|
|
|
|
else:
|
|
yield research_result
|
|
|
|
# researched_results = await extract_relevant_info(q, researched_results, agent)
|
|
if state.verbose > 1:
|
|
logger.debug(f"Researched Results: {researched_results}")
|
|
|
|
used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
|
|
file_filters = conversation.file_filters if conversation else []
|
|
# Skip trying to summarize if
|
|
if (
|
|
# summarization intent was inferred
|
|
ConversationCommand.Summarize in conversation_commands
|
|
# and not triggered via slash command
|
|
and not used_slash_summarize
|
|
# but we can't actually summarize
|
|
and len(file_filters) == 0
|
|
):
|
|
conversation_commands.remove(ConversationCommand.Summarize)
|
|
elif ConversationCommand.Summarize in conversation_commands:
|
|
response_log = ""
|
|
agent_has_entries = await EntryAdapters.aagent_has_entries(agent)
|
|
if len(file_filters) == 0 and not agent_has_entries:
|
|
response_log = "No files selected for summarization. Please add files using the section on the left."
|
|
async for result in send_llm_response(response_log, tracer.get("usage")):
|
|
yield result
|
|
else:
|
|
async for response in generate_summary_from_files(
|
|
q=q,
|
|
user=user,
|
|
file_filters=file_filters,
|
|
meta_log=meta_log,
|
|
query_images=uploaded_images,
|
|
agent=agent,
|
|
send_status_func=partial(send_event, ChatEvent.STATUS),
|
|
query_files=attached_file_context,
|
|
tracer=tracer,
|
|
):
|
|
if isinstance(response, dict) and ChatEvent.STATUS in response:
|
|
yield response[ChatEvent.STATUS]
|
|
else:
|
|
if isinstance(response, str):
|
|
response_log = response
|
|
async for result in send_llm_response(response, tracer.get("usage")):
|
|
yield result
|
|
|
|
summarized_document = FileAttachment(
|
|
name="Summarized Document",
|
|
content=response_log,
|
|
type="text/plain",
|
|
size=len(response_log.encode("utf-8")),
|
|
)
|
|
|
|
async for result in send_event(ChatEvent.GENERATED_ASSETS, {"files": [summarized_document.model_dump()]}):
|
|
yield result
|
|
|
|
generated_files.append(summarized_document)
|
|
|
|
custom_filters = []
|
|
if conversation_commands == [ConversationCommand.Help]:
|
|
if not q:
|
|
chat_model = await ConversationAdapters.aget_user_chat_model(user)
|
|
if chat_model == None:
|
|
chat_model = await ConversationAdapters.aget_default_chat_model(user)
|
|
model_type = chat_model.model_type
|
|
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
|
|
async for result in send_llm_response(formatted_help, tracer.get("usage")):
|
|
yield result
|
|
return
|
|
# Adding specification to search online specifically on khoj.dev pages.
|
|
custom_filters.append("site:khoj.dev")
|
|
conversation_commands.append(ConversationCommand.Online)
|
|
|
|
if ConversationCommand.Automation in conversation_commands:
|
|
try:
|
|
automation, crontime, query_to_run, subject = await create_automation(
|
|
q, timezone, user, request.url, meta_log, tracer=tracer
|
|
)
|
|
except Exception as e:
|
|
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
|
|
error_message = f"Unable to create automation. Ensure the automation doesn't already exist."
|
|
async for result in send_llm_response(error_message, tracer.get("usage")):
|
|
yield result
|
|
return
|
|
|
|
llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
|
|
# Trigger task to save conversation to DB
|
|
asyncio.create_task(
|
|
save_to_conversation_log(
|
|
q,
|
|
llm_response,
|
|
user,
|
|
meta_log,
|
|
user_message_time,
|
|
intent_type="automation",
|
|
client_application=request.user.client_app,
|
|
conversation_id=conversation_id,
|
|
inferred_queries=[query_to_run],
|
|
automation_id=automation.id,
|
|
query_images=uploaded_images,
|
|
train_of_thought=train_of_thought,
|
|
raw_query_files=raw_query_files,
|
|
tracer=tracer,
|
|
)
|
|
)
|
|
# Send LLM Response
|
|
async for result in send_llm_response(llm_response, tracer.get("usage")):
|
|
yield result
|
|
return
|
|
|
|
# Gather Context
|
|
## Extract Document References
|
|
if not ConversationCommand.Research in conversation_commands:
|
|
try:
|
|
async for result in extract_references_and_questions(
|
|
user,
|
|
meta_log,
|
|
q,
|
|
(n or 7),
|
|
d,
|
|
conversation_id,
|
|
conversation_commands,
|
|
location,
|
|
partial(send_event, ChatEvent.STATUS),
|
|
query_images=uploaded_images,
|
|
agent=agent,
|
|
query_files=attached_file_context,
|
|
tracer=tracer,
|
|
):
|
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
|
yield result[ChatEvent.STATUS]
|
|
else:
|
|
compiled_references.extend(result[0])
|
|
inferred_queries.extend(result[1])
|
|
defiltered_query = result[2]
|
|
except Exception as e:
|
|
error_message = (
|
|
f"Error searching knowledge base: {e}. Attempting to respond without document references."
|
|
)
|
|
logger.error(error_message, exc_info=True)
|
|
async for result in send_event(
|
|
ChatEvent.STATUS, "Document search failed. I'll try respond without document references"
|
|
):
|
|
yield result
|
|
|
|
if not is_none_or_empty(compiled_references):
|
|
headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
|
|
# Strip only leading # from headings
|
|
headings = headings.replace("#", "")
|
|
async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"):
|
|
yield result
|
|
|
|
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
|
|
async for result in send_llm_response(f"{no_entries_found.format()}", tracer.get("usage")):
|
|
yield result
|
|
return
|
|
|
|
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
|
|
conversation_commands.remove(ConversationCommand.Notes)
|
|
|
|
## Gather Online References
|
|
if ConversationCommand.Online in conversation_commands:
|
|
try:
|
|
async for result in search_online(
|
|
defiltered_query,
|
|
meta_log,
|
|
location,
|
|
user,
|
|
partial(send_event, ChatEvent.STATUS),
|
|
custom_filters,
|
|
query_images=uploaded_images,
|
|
agent=agent,
|
|
query_files=attached_file_context,
|
|
tracer=tracer,
|
|
):
|
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
|
yield result[ChatEvent.STATUS]
|
|
else:
|
|
online_results = result
|
|
except Exception as e:
|
|
error_message = f"Error searching online: {e}. Attempting to respond without online results"
|
|
logger.warning(error_message)
|
|
async for result in send_event(
|
|
ChatEvent.STATUS, "Online search failed. I'll try respond without online references"
|
|
):
|
|
yield result
|
|
|
|
## Gather Webpage References
|
|
if ConversationCommand.Webpage in conversation_commands:
|
|
try:
|
|
async for result in read_webpages(
|
|
defiltered_query,
|
|
meta_log,
|
|
location,
|
|
user,
|
|
partial(send_event, ChatEvent.STATUS),
|
|
max_webpages_to_read=1,
|
|
query_images=uploaded_images,
|
|
agent=agent,
|
|
query_files=attached_file_context,
|
|
tracer=tracer,
|
|
):
|
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
|
yield result[ChatEvent.STATUS]
|
|
else:
|
|
direct_web_pages = result
|
|
webpages = []
|
|
for query in direct_web_pages:
|
|
if online_results.get(query):
|
|
online_results[query]["webpages"] = direct_web_pages[query]["webpages"]
|
|
else:
|
|
online_results[query] = {"webpages": direct_web_pages[query]["webpages"]}
|
|
|
|
for webpage in direct_web_pages[query]["webpages"]:
|
|
webpages.append(webpage["link"])
|
|
async for result in send_event(ChatEvent.STATUS, f"**Read web pages**: {webpages}"):
|
|
yield result
|
|
except Exception as e:
|
|
logger.warning(
|
|
f"Error reading webpages: {e}. Attempting to respond without webpage results",
|
|
exc_info=True,
|
|
)
|
|
async for result in send_event(
|
|
ChatEvent.STATUS, "Webpage read failed. I'll try respond without webpage references"
|
|
):
|
|
yield result
|
|
|
|
## Gather Code Results
|
|
if ConversationCommand.Code in conversation_commands:
|
|
try:
|
|
context = f"# Iteration 1:\n#---\nNotes:\n{compiled_references}\n\nOnline Results:{online_results}"
|
|
async for result in run_code(
|
|
defiltered_query,
|
|
meta_log,
|
|
context,
|
|
location,
|
|
user,
|
|
partial(send_event, ChatEvent.STATUS),
|
|
query_images=uploaded_images,
|
|
agent=agent,
|
|
query_files=attached_file_context,
|
|
tracer=tracer,
|
|
):
|
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
|
yield result[ChatEvent.STATUS]
|
|
else:
|
|
code_results = result
|
|
async for result in send_event(ChatEvent.STATUS, f"**Ran code snippets**: {len(code_results)}"):
|
|
yield result
|
|
except ValueError as e:
|
|
program_execution_context.append(f"Failed to run code")
|
|
logger.warning(
|
|
f"Failed to use code tool: {e}. Attempting to respond without code results",
|
|
exc_info=True,
|
|
)
|
|
|
|
## Send Gathered References
|
|
unique_online_results = deduplicate_organic_results(online_results)
|
|
async for result in send_event(
|
|
ChatEvent.REFERENCES,
|
|
{
|
|
"inferredQueries": inferred_queries,
|
|
"context": compiled_references,
|
|
"onlineContext": unique_online_results,
|
|
"codeContext": code_results,
|
|
},
|
|
):
|
|
yield result
|
|
|
|
# Generate Output
|
|
## Generate Image Output
|
|
if ConversationCommand.Image in conversation_commands:
|
|
async for result in text_to_image(
|
|
defiltered_query,
|
|
user,
|
|
meta_log,
|
|
location_data=location,
|
|
references=compiled_references,
|
|
online_results=online_results,
|
|
send_status_func=partial(send_event, ChatEvent.STATUS),
|
|
query_images=uploaded_images,
|
|
agent=agent,
|
|
query_files=attached_file_context,
|
|
tracer=tracer,
|
|
):
|
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
|
yield result[ChatEvent.STATUS]
|
|
else:
|
|
generated_image, status_code, improved_image_prompt = result
|
|
|
|
inferred_queries.append(improved_image_prompt)
|
|
if generated_image is None or status_code != 200:
|
|
program_execution_context.append(f"Failed to generate image with {improved_image_prompt}")
|
|
async for result in send_event(ChatEvent.STATUS, f"Failed to generate image"):
|
|
yield result
|
|
else:
|
|
generated_images.append(generated_image)
|
|
|
|
generated_asset_results["images"] = {
|
|
"query": improved_image_prompt,
|
|
}
|
|
|
|
async for result in send_event(
|
|
ChatEvent.GENERATED_ASSETS,
|
|
{
|
|
"images": [generated_image],
|
|
},
|
|
):
|
|
yield result
|
|
|
|
if ConversationCommand.Diagram in conversation_commands:
|
|
async for result in send_event(ChatEvent.STATUS, f"Creating diagram"):
|
|
yield result
|
|
|
|
inferred_queries = []
|
|
diagram_description = ""
|
|
|
|
async for result in generate_mermaidjs_diagram(
|
|
q=defiltered_query,
|
|
conversation_history=meta_log,
|
|
location_data=location,
|
|
note_references=compiled_references,
|
|
online_results=online_results,
|
|
query_images=uploaded_images,
|
|
user=user,
|
|
agent=agent,
|
|
send_status_func=partial(send_event, ChatEvent.STATUS),
|
|
query_files=attached_file_context,
|
|
tracer=tracer,
|
|
):
|
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
|
yield result[ChatEvent.STATUS]
|
|
else:
|
|
better_diagram_description_prompt, mermaidjs_diagram_description = result
|
|
if better_diagram_description_prompt and mermaidjs_diagram_description:
|
|
inferred_queries.append(better_diagram_description_prompt)
|
|
diagram_description = mermaidjs_diagram_description
|
|
|
|
generated_mermaidjs_diagram = diagram_description
|
|
|
|
generated_asset_results["diagrams"] = {
|
|
"query": better_diagram_description_prompt,
|
|
}
|
|
|
|
async for result in send_event(
|
|
ChatEvent.GENERATED_ASSETS,
|
|
{
|
|
"mermaidjsDiagram": mermaidjs_diagram_description,
|
|
},
|
|
):
|
|
yield result
|
|
else:
|
|
error_message = "Failed to generate diagram. Please try again later."
|
|
program_execution_context.append(
|
|
prompts.failed_diagram_generation.format(
|
|
attempted_diagram=better_diagram_description_prompt
|
|
)
|
|
)
|
|
|
|
async for result in send_event(ChatEvent.STATUS, error_message):
|
|
yield result
|
|
|
|
## Generate Text Output
|
|
async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"):
|
|
yield result
|
|
|
|
llm_response, chat_metadata = await agenerate_chat_response(
|
|
defiltered_query,
|
|
meta_log,
|
|
conversation,
|
|
compiled_references,
|
|
online_results,
|
|
code_results,
|
|
inferred_queries,
|
|
conversation_commands,
|
|
user,
|
|
request.user.client_app,
|
|
conversation_id,
|
|
location,
|
|
user_name,
|
|
researched_results,
|
|
uploaded_images,
|
|
train_of_thought,
|
|
attached_file_context,
|
|
raw_query_files,
|
|
generated_images,
|
|
generated_files,
|
|
generated_mermaidjs_diagram,
|
|
program_execution_context,
|
|
generated_asset_results,
|
|
is_subscribed,
|
|
tracer,
|
|
)
|
|
|
|
continue_stream = True
|
|
async for item in llm_response:
|
|
# Should not happen with async generator, end is signaled by loop exit. Skip.
|
|
if item is None:
|
|
continue
|
|
if not connection_alive or not continue_stream:
|
|
# Drain the generator if disconnected but keep processing internally
|
|
continue
|
|
message = item.response if isinstance(item, ResponseWithThought) else item
|
|
if isinstance(item, ResponseWithThought) and item.thought:
|
|
async for result in send_event(ChatEvent.THOUGHT, item.thought):
|
|
yield result
|
|
continue
|
|
|
|
# Start sending response
|
|
async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
|
|
yield result
|
|
|
|
try:
|
|
async for result in send_event(ChatEvent.MESSAGE, message):
|
|
yield result
|
|
except Exception as e:
|
|
continue_stream = False
|
|
logger.info(f"User {user} disconnected or error during streaming. Stopping send: {e}")
|
|
|
|
# Signal end of LLM response after the loop finishes
|
|
if connection_alive:
|
|
async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
|
|
yield result
|
|
# Send Usage Metadata once llm interactions are complete
|
|
if tracer.get("usage"):
|
|
async for event in send_event(ChatEvent.USAGE, tracer.get("usage")):
|
|
yield event
|
|
async for result in send_event(ChatEvent.END_RESPONSE, ""):
|
|
yield result
|
|
logger.debug("Finished streaming response")
|
|
|
|
## Stream Text Response
|
|
if stream:
|
|
return StreamingResponse(event_generator(q, images=raw_images), media_type="text/plain")
|
|
## Non-Streaming Text Response
|
|
else:
|
|
response_iterator = event_generator(q, images=raw_images)
|
|
response_data = await read_chat_stream(response_iterator)
|
|
return Response(content=json.dumps(response_data), media_type="application/json", status_code=200)
|