Simplify advanced streaming chat API, align params with normal chat API

This commit is contained in:
Debanjum Singh Solanky
2024-07-22 17:09:41 +05:30
parent b8d3e3669a
commit 6b9550238f

View File

@@ -1,7 +1,6 @@
import asyncio
import json
import logging
import math
from datetime import datetime
from functools import partial
from typing import Any, Dict, List, Optional
@@ -529,29 +528,47 @@ async def set_conversation_title(
@api_chat.get("/stream")
async def stream_chat(
request: Request,
common: CommonQueryParams,
q: str,
conversation_id: int,
n: int = 7,
d: float = 0.18,
title: Optional[str] = None,
conversation_id: Optional[int] = None,
city: Optional[str] = None,
region: Optional[str] = None,
country: Optional[str] = None,
timezone: Optional[str] = None,
rate_limiter_per_minute=Depends(
ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
),
rate_limiter_per_day=Depends(
ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
),
):
async def event_generator(q: str):
connection_alive = True
user: KhojUser = request.user.object
q = unquote(q)
async def send_event(event_type: str, data: str):
nonlocal connection_alive
if not connection_alive or await request.is_disconnected():
connection_alive = False
logger.warn(f"User {user} disconnected from {common.client} client")
return
try:
if event_type == "message":
yield data
else:
yield json.dumps({"type": event_type, "data": data})
except asyncio.CancelledError:
connection_alive = False
logger.warn(f"User {user} disconnected from {common.client} client")
return
except Exception as e:
connection_alive = False
logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}")
logger.error(f"Failed to stream chat API response to {user} on {common.client}: {e}", exc_info=True)
return
async def send_llm_response(response: str):
async for result in send_event("start_llm_response", ""):
@@ -561,393 +578,358 @@ async def stream_chat(
async for result in send_event("end_llm_response", ""):
yield result
user: KhojUser = request.user.object
conversation = await ConversationAdapters.aget_conversation_by_user(
user, client_application=request.user.client_app, conversation_id=conversation_id
user, client_application=request.user.client_app, conversation_id=conversation_id, title=title
)
hourly_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
daily_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
if not conversation:
async for result in send_llm_response(f"No Conversation id: {conversation_id} not found"):
yield result
await is_ready_to_chat(user)
user_name = await aget_user_name(user)
location = None
if city or region or country:
location = LocationData(city=city, region=region, country=country)
while connection_alive:
try:
if conversation:
await sync_to_async(conversation.refresh_from_db)(fields=["conversation_log"])
if is_query_empty(q):
async for result in send_llm_response("Please ask your query to get started."):
yield result
return
# Refresh these because the connection to the database might have been closed
await conversation.arefresh_from_db()
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
conversation_commands = [get_conversation_command(query=q, any_references=True)]
try:
await sync_to_async(hourly_limiter)(request)
await sync_to_async(daily_limiter)(request)
except HTTPException as e:
async for result in send_event("rate_limit", e.detail):
yield result
return
async for result in send_event("status", f"**👀 Understanding Query**: {q}"):
yield result
if is_query_empty(q):
async for event in send_llm_response("Please ask your query to get started."):
yield event
return
meta_log = conversation.conversation_log
is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
conversation_commands = [get_conversation_command(query=q, any_references=True)]
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
async for result in send_event(
"status", f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}"
):
yield result
async for result in send_event("status", f"**👀 Understanding Query**: {q}"):
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task)
async for result in send_event("status", f"**🧑🏾‍💻 Decided Response Mode:** {mode.value}"):
yield result
if mode not in conversation_commands:
conversation_commands.append(mode)
for cmd in conversation_commands:
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
q = q.replace(f"/{cmd.value}", "").strip()
used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
file_filters = conversation.file_filters if conversation else []
# Skip trying to summarize if
if (
# summarization intent was inferred
ConversationCommand.Summarize in conversation_commands
# and not triggered via slash command
and not used_slash_summarize
# but we can't actually summarize
and len(file_filters) != 1
):
conversation_commands.remove(ConversationCommand.Summarize)
elif ConversationCommand.Summarize in conversation_commands:
response_log = ""
if len(file_filters) == 0:
response_log = "No files selected for summarization. Please add files using the section on the left."
async for result in send_llm_response(response_log):
yield result
meta_log = conversation.conversation_log
is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
elif len(file_filters) > 1:
response_log = "Only one file can be selected for summarization."
async for result in send_llm_response(response_log):
yield result
else:
try:
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
if len(file_object) == 0:
response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
async for result in send_llm_response(response_log):
yield result
return
contextual_data = " ".join([file.raw_text for file in file_object])
if not q:
q = "Create a general summary of the file"
async for result in send_event(
"status", f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}"
"status", f"**🧑🏾‍💻 Constructing Summary Using:** {file_object[0].file_name}"
):
yield result
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task)
async for result in send_event("status", f"**🧑🏾‍💻 Decided Response Mode:** {mode.value}"):
response = await extract_relevant_summary(q, contextual_data)
response_log = str(response)
async for result in send_llm_response(response_log):
yield result
if mode not in conversation_commands:
conversation_commands.append(mode)
for cmd in conversation_commands:
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
q = q.replace(f"/{cmd.value}", "").strip()
file_filters = conversation.file_filters if conversation else []
# Skip trying to summarize if
if (
# summarization intent was inferred
ConversationCommand.Summarize in conversation_commands
# and not triggered via slash command
and not used_slash_summarize
# but we can't actually summarize
and len(file_filters) != 1
):
conversation_commands.remove(ConversationCommand.Summarize)
elif ConversationCommand.Summarize in conversation_commands:
response_log = ""
if len(file_filters) == 0:
response_log = (
"No files selected for summarization. Please add files using the section on the left."
)
async for result in send_llm_response(response_log):
yield result
elif len(file_filters) > 1:
response_log = "Only one file can be selected for summarization."
async for result in send_llm_response(response_log):
yield result
else:
try:
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
if len(file_object) == 0:
response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
async for result in send_llm_response(response_log):
yield result
return
contextual_data = " ".join([file.raw_text for file in file_object])
if not q:
q = "Create a general summary of the file"
async for result in send_event(
"status", f"**🧑🏾‍💻 Constructing Summary Using:** {file_object[0].file_name}"
):
yield result
response = await extract_relevant_summary(q, contextual_data)
response_log = str(response)
async for result in send_llm_response(response_log):
yield result
except Exception as e:
response_log = "Error summarizing file."
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
async for result in send_llm_response(response_log):
yield result
await sync_to_async(save_to_conversation_log)(
q,
response_log,
user,
meta_log,
user_message_time,
intent_type="summarize",
client_application=request.user.client_app,
conversation_id=conversation_id,
)
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
metadata={"conversation_command": conversation_commands[0].value},
)
return
custom_filters = []
if conversation_commands == [ConversationCommand.Help]:
if not q:
conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
if conversation_config == None:
conversation_config = await ConversationAdapters.aget_default_conversation_config()
model_type = conversation_config.model_type
formatted_help = help_message.format(
model=model_type, version=state.khoj_version, device=get_device()
)
async for result in send_llm_response(formatted_help):
yield result
return
custom_filters.append("site:khoj.dev")
conversation_commands.append(ConversationCommand.Online)
if ConversationCommand.Automation in conversation_commands:
try:
automation, crontime, query_to_run, subject = await create_automation(
q, timezone, user, request.url, meta_log
)
except Exception as e:
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
error_message = f"Unable to create automation. Ensure the automation doesn't already exist."
async for result in send_llm_response(error_message):
yield result
return
llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
await sync_to_async(save_to_conversation_log)(
q,
llm_response,
user,
meta_log,
user_message_time,
intent_type="automation",
client_application=request.user.client_app,
conversation_id=conversation_id,
inferred_queries=[query_to_run],
automation_id=automation.id,
)
common = CommonQueryParamsClass(
client=request.user.client_app,
user_agent=request.headers.get("user-agent"),
host=request.headers.get("host"),
)
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
**common.__dict__,
)
async for result in send_llm_response(llm_response):
except Exception as e:
response_log = "Error summarizing file."
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
async for result in send_llm_response(response_log):
yield result
return
await sync_to_async(save_to_conversation_log)(
q,
response_log,
user,
meta_log,
user_message_time,
intent_type="summarize",
client_application=request.user.client_app,
conversation_id=conversation_id,
)
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
metadata={"conversation_command": conversation_commands[0].value},
)
return
compiled_references, inferred_queries, defiltered_query = [], [], None
async for result in extract_references_and_questions(
request,
meta_log,
q,
7,
0.18,
conversation_id,
conversation_commands,
location,
partial(send_event, "status"),
custom_filters = []
if conversation_commands == [ConversationCommand.Help]:
if not q:
conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
if conversation_config == None:
conversation_config = await ConversationAdapters.aget_default_conversation_config()
model_type = conversation_config.model_type
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
async for result in send_llm_response(formatted_help):
yield result
return
# Adding specification to search online specifically on khoj.dev pages.
custom_filters.append("site:khoj.dev")
conversation_commands.append(ConversationCommand.Online)
if ConversationCommand.Automation in conversation_commands:
try:
automation, crontime, query_to_run, subject = await create_automation(
q, timezone, user, request.url, meta_log
)
except Exception as e:
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
error_message = f"Unable to create automation. Ensure the automation doesn't already exist."
async for result in send_llm_response(error_message):
yield result
return
llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
await sync_to_async(save_to_conversation_log)(
q,
llm_response,
user,
meta_log,
user_message_time,
intent_type="automation",
client_application=request.user.client_app,
conversation_id=conversation_id,
inferred_queries=[query_to_run],
automation_id=automation.id,
)
common = CommonQueryParamsClass(
client=request.user.client_app,
user_agent=request.headers.get("user-agent"),
host=request.headers.get("host"),
)
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
**common.__dict__,
)
async for result in send_llm_response(llm_response):
yield result
return
compiled_references, inferred_queries, defiltered_query = [], [], None
async for result in extract_references_and_questions(
request,
meta_log,
q,
(n or 7),
(d or 0.18),
conversation_id,
conversation_commands,
location,
partial(send_event, "status"),
):
if isinstance(result, dict) and "status" in result:
yield result["status"]
else:
compiled_references.extend(result[0])
inferred_queries.extend(result[1])
defiltered_query = result[2]
if not is_none_or_empty(compiled_references):
headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
async for result in send_event("status", f"**📜 Found Relevant Notes**: {headings}"):
yield result
online_results: Dict = dict()
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
async for result in send_llm_response(f"{no_entries_found.format()}"):
yield result
return
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
conversation_commands.remove(ConversationCommand.Notes)
if ConversationCommand.Online in conversation_commands:
try:
async for result in search_online(
defiltered_query, meta_log, location, partial(send_event, "status"), custom_filters
):
if isinstance(result, dict) and "status" in result:
yield result["status"]
else:
compiled_references.extend(result[0])
inferred_queries.extend(result[1])
defiltered_query = result[2]
if not is_none_or_empty(compiled_references):
headings = "\n- " + "\n- ".join(
set([c.get("compiled", c).split("\n")[0] for c in compiled_references])
)
async for result in send_event("status", f"**📜 Found Relevant Notes**: {headings}"):
yield result
online_results: Dict = dict()
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(
user
):
async for result in send_llm_response(f"{no_entries_found.format()}"):
yield event
return
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
conversation_commands.remove(ConversationCommand.Notes)
if ConversationCommand.Online in conversation_commands:
try:
async for result in search_online(
defiltered_query, meta_log, location, partial(send_event, "status"), custom_filters
):
if isinstance(result, dict) and "status" in result:
yield result["status"]
else:
online_results = result
except ValueError as e:
error_message = f"Error searching online: {e}. Attempting to respond without online results"
logger.warning(error_message)
async for result in send_llm_response(error_message):
yield result
return
if ConversationCommand.Webpage in conversation_commands:
try:
async for result in read_webpages(
defiltered_query, meta_log, location, partial(send_event, "status")
):
if isinstance(result, dict) and "status" in result:
yield result["status"]
else:
direct_web_pages = result
webpages = []
for query in direct_web_pages:
if online_results.get(query):
online_results[query]["webpages"] = direct_web_pages[query]["webpages"]
else:
online_results[query] = {"webpages": direct_web_pages[query]["webpages"]}
for webpage in direct_web_pages[query]["webpages"]:
webpages.append(webpage["link"])
async for result in send_event("status", f"**📚 Read web pages**: {webpages}"):
yield result
except ValueError as e:
logger.warning(
f"Error directly reading webpages: {e}. Attempting to respond without online results",
exc_info=True,
)
if ConversationCommand.Image in conversation_commands:
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
metadata={"conversation_command": conversation_commands[0].value},
)
async for result in text_to_image(
q,
user,
meta_log,
location_data=location,
references=compiled_references,
online_results=online_results,
send_status_func=partial(send_event, "status"),
):
if isinstance(result, dict) and "status" in result:
yield result["status"]
else:
image, status_code, improved_image_prompt, intent_type = result
if image is None or status_code != 200:
content_obj = {
"content-type": "application/json",
"intentType": intent_type,
"detail": improved_image_prompt,
"image": image,
}
async for result in send_llm_response(json.dumps(content_obj)):
yield result
return
await sync_to_async(save_to_conversation_log)(
q,
image,
user,
meta_log,
user_message_time,
intent_type=intent_type,
inferred_queries=[improved_image_prompt],
client_application=request.user.client_app,
conversation_id=conversation_id,
compiled_references=compiled_references,
online_results=online_results,
)
content_obj = {
"content-type": "application/json",
"intentType": intent_type,
"context": compiled_references,
"online_results": online_results,
"inferredQueries": [improved_image_prompt],
"image": image,
}
async for result in send_llm_response(json.dumps(content_obj)):
yield result
return
async for result in send_event(
"references", json.dumps({"context": compiled_references, "online_results": online_results})
):
online_results = result
except ValueError as e:
error_message = f"Error searching online: {e}. Attempting to respond without online results"
logger.warning(error_message)
async for result in send_llm_response(error_message):
yield result
async for result in send_event("status", f"**💭 Generating a well-informed response**"):
yield result
llm_response, chat_metadata = await agenerate_chat_response(
defiltered_query,
meta_log,
conversation,
compiled_references,
online_results,
inferred_queries,
conversation_commands,
user,
request.user.client_app,
conversation_id,
location,
user_name,
)
chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
metadata=chat_metadata,
)
iterator = AsyncIteratorWrapper(llm_response)
async for result in send_event("start_llm_response", ""):
yield result
continue_stream = True
async for item in iterator:
if item is None:
async for result in send_event("end_llm_response", ""):
yield result
logger.debug("Finished streaming response")
return
if not connection_alive or not continue_stream:
continue
try:
async for result in send_event("message", f"{item}"):
yield result
except Exception as e:
continue_stream = False
logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}")
# Stop streaming after compiled references section of response starts
# References are being processed via the references event rather than the message event
if "### compiled references:" in item:
continue_stream = False
except asyncio.CancelledError:
logger.error(f"Cancelled Error in API endpoint: {e}", exc_info=True)
return
if ConversationCommand.Webpage in conversation_commands:
try:
async for result in read_webpages(defiltered_query, meta_log, location, partial(send_event, "status")):
if isinstance(result, dict) and "status" in result:
yield result["status"]
else:
direct_web_pages = result
webpages = []
for query in direct_web_pages:
if online_results.get(query):
online_results[query]["webpages"] = direct_web_pages[query]["webpages"]
else:
online_results[query] = {"webpages": direct_web_pages[query]["webpages"]}
for webpage in direct_web_pages[query]["webpages"]:
webpages.append(webpage["link"])
async for result in send_event("status", f"**📚 Read web pages**: {webpages}"):
yield result
except ValueError as e:
logger.warning(
f"Error directly reading webpages: {e}. Attempting to respond without online results",
exc_info=True,
)
if ConversationCommand.Image in conversation_commands:
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
metadata={"conversation_command": conversation_commands[0].value},
)
async for result in text_to_image(
q,
user,
meta_log,
location_data=location,
references=compiled_references,
online_results=online_results,
send_status_func=partial(send_event, "status"),
):
if isinstance(result, dict) and "status" in result:
yield result["status"]
else:
image, status_code, improved_image_prompt, intent_type = result
if image is None or status_code != 200:
content_obj = {
"content-type": "application/json",
"intentType": intent_type,
"detail": improved_image_prompt,
"image": image,
}
async for result in send_llm_response(json.dumps(content_obj)):
yield result
return
await sync_to_async(save_to_conversation_log)(
q,
image,
user,
meta_log,
user_message_time,
intent_type=intent_type,
inferred_queries=[improved_image_prompt],
client_application=request.user.client_app,
conversation_id=conversation_id,
compiled_references=compiled_references,
online_results=online_results,
)
content_obj = {
"content-type": "application/json",
"intentType": intent_type,
"context": compiled_references,
"online_results": online_results,
"inferredQueries": [improved_image_prompt],
"image": image,
}
async for result in send_llm_response(json.dumps(content_obj)):
yield result
return
async for result in send_event(
"references", json.dumps({"context": compiled_references, "online_results": online_results})
):
yield result
async for result in send_event("status", f"**💭 Generating a well-informed response**"):
yield result
llm_response, chat_metadata = await agenerate_chat_response(
defiltered_query,
meta_log,
conversation,
compiled_references,
online_results,
inferred_queries,
conversation_commands,
user,
request.user.client_app,
conversation_id,
location,
user_name,
)
chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
metadata=chat_metadata,
)
iterator = AsyncIteratorWrapper(llm_response)
async for result in send_event("start_llm_response", ""):
yield result
continue_stream = True
async for item in iterator:
if item is None:
async for result in send_event("end_llm_response", ""):
yield result
logger.debug("Finished streaming response")
return
if not connection_alive or not continue_stream:
continue
# Stop streaming after compiled references section of response starts
# References are being processed via the references event rather than the message event
if "### compiled references:" in item:
continue_stream = False
item = item.split("### compiled references:")[0]
try:
async for result in send_event("message", f"{item}"):
yield result
except Exception as e:
logger.error(f"General Error in API endpoint: {e}", exc_info=True)
return
continue_stream = False
logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}")
return StreamingResponse(event_generator(q), media_type="text/plain")