Simplify advanced streaming chat API, align params with normal chat API

Debanjum Singh Solanky
2024-07-22 17:09:41 +05:30
parent b8d3e3669a
commit 6b9550238f


@@ -1,7 +1,6 @@
import asyncio
import json
import logging
-import math
from datetime import datetime
from functools import partial
from typing import Any, Dict, List, Optional
@@ -529,29 +528,47 @@ async def set_conversation_title(
@api_chat.get("/stream")
async def stream_chat(
request: Request,
+common: CommonQueryParams,
q: str,
-conversation_id: int,
+n: int = 7,
+d: float = 0.18,
+title: Optional[str] = None,
+conversation_id: Optional[int] = None,
city: Optional[str] = None,
region: Optional[str] = None,
country: Optional[str] = None,
timezone: Optional[str] = None,
+rate_limiter_per_minute=Depends(
+ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
+),
+rate_limiter_per_day=Depends(
+ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
+),
):
async def event_generator(q: str):
connection_alive = True
+user: KhojUser = request.user.object
+q = unquote(q)
async def send_event(event_type: str, data: str):
nonlocal connection_alive
if not connection_alive or await request.is_disconnected():
connection_alive = False
+logger.warn(f"User {user} disconnected from {common.client} client")
return
try:
if event_type == "message":
yield data
else:
yield json.dumps({"type": event_type, "data": data})
+except asyncio.CancelledError:
+connection_alive = False
+logger.warn(f"User {user} disconnected from {common.client} client")
+return
except Exception as e:
connection_alive = False
-logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}")
+logger.error(f"Failed to stream chat API response to {user} on {common.client}: {e}", exc_info=True)
+return
async def send_llm_response(response: str):
async for result in send_event("start_llm_response", ""):
@@ -561,393 +578,358 @@ async def stream_chat(
async for result in send_event("end_llm_response", ""):
yield result
-user: KhojUser = request.user.object
conversation = await ConversationAdapters.aget_conversation_by_user(
-user, client_application=request.user.client_app, conversation_id=conversation_id
+user, client_application=request.user.client_app, conversation_id=conversation_id, title=title
)
-hourly_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
-daily_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
+if not conversation:
+async for result in send_llm_response(f"No Conversation id: {conversation_id} not found"):
+yield result
await is_ready_to_chat(user)
user_name = await aget_user_name(user)
location = None
if city or region or country:
location = LocationData(city=city, region=region, country=country)
-while connection_alive:
-try:
-if conversation:
-await sync_to_async(conversation.refresh_from_db)(fields=["conversation_log"])
-# Refresh these because the connection to the database might have been closed
-await conversation.arefresh_from_db()
-try:
-await sync_to_async(hourly_limiter)(request)
-await sync_to_async(daily_limiter)(request)
-except HTTPException as e:
-async for result in send_event("rate_limit", e.detail):
-yield result
-return
-if is_query_empty(q):
-async for event in send_llm_response("Please ask your query to get started."):
-yield event
-return
-user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-conversation_commands = [get_conversation_command(query=q, any_references=True)]
-async for result in send_event("status", f"**👀 Understanding Query**: {q}"):
+if is_query_empty(q):
+async for result in send_llm_response("Please ask your query to get started."):
+yield result
+return
+user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+conversation_commands = [get_conversation_command(query=q, any_references=True)]
+async for result in send_event("status", f"**👀 Understanding Query**: {q}"):
+yield result
+meta_log = conversation.conversation_log
+is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
+if conversation_commands == [ConversationCommand.Default] or is_automated_task:
+conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
+conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
+async for result in send_event(
+"status", f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}"
+):
+yield result
+mode = await aget_relevant_output_modes(q, meta_log, is_automated_task)
+async for result in send_event("status", f"**🧑🏾‍💻 Decided Response Mode:** {mode.value}"):
+yield result
if mode not in conversation_commands:
conversation_commands.append(mode)
for cmd in conversation_commands:
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
q = q.replace(f"/{cmd.value}", "").strip()
used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
file_filters = conversation.file_filters if conversation else []
# Skip trying to summarize if
if (
# summarization intent was inferred
ConversationCommand.Summarize in conversation_commands
# and not triggered via slash command
and not used_slash_summarize
# but we can't actually summarize
and len(file_filters) != 1
):
conversation_commands.remove(ConversationCommand.Summarize)
elif ConversationCommand.Summarize in conversation_commands:
response_log = ""
if len(file_filters) == 0:
response_log = "No files selected for summarization. Please add files using the section on the left."
async for result in send_llm_response(response_log):
+yield result
+elif len(file_filters) > 1:
+response_log = "Only one file can be selected for summarization."
+async for result in send_llm_response(response_log):
+yield result
+else:
+try:
+file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
+if len(file_object) == 0:
+response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
+async for result in send_llm_response(response_log):
+yield result
+return
+contextual_data = " ".join([file.raw_text for file in file_object])
+if not q:
+q = "Create a general summary of the file"
+async for result in send_event(
+"status", f"**🧑🏾‍💻 Constructing Summary Using:** {file_object[0].file_name}"
+):
+yield result
+response = await extract_relevant_summary(q, contextual_data)
+response_log = str(response)
+async for result in send_llm_response(response_log):
+yield result
+except Exception as e:
+response_log = "Error summarizing file."
+logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
+async for result in send_llm_response(response_log):
-yield result
-meta_log = conversation.conversation_log
-is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
-used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
-if conversation_commands == [ConversationCommand.Default] or is_automated_task:
-conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
-conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
-async for result in send_event(
-"status", f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}"
-):
-yield result
-mode = await aget_relevant_output_modes(q, meta_log, is_automated_task)
-async for result in send_event("status", f"**🧑🏾‍💻 Decided Response Mode:** {mode.value}"):
-yield result
-if mode not in conversation_commands:
-conversation_commands.append(mode)
-for cmd in conversation_commands:
-await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
-q = q.replace(f"/{cmd.value}", "").strip()
file_filters = conversation.file_filters if conversation else []
# Skip trying to summarize if
if (
# summarization intent was inferred
ConversationCommand.Summarize in conversation_commands
# and not triggered via slash command
and not used_slash_summarize
# but we can't actually summarize
and len(file_filters) != 1
):
conversation_commands.remove(ConversationCommand.Summarize)
elif ConversationCommand.Summarize in conversation_commands:
response_log = ""
if len(file_filters) == 0:
response_log = (
"No files selected for summarization. Please add files using the section on the left."
)
async for result in send_llm_response(response_log):
yield result
elif len(file_filters) > 1:
response_log = "Only one file can be selected for summarization."
async for result in send_llm_response(response_log):
yield result
else:
try:
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
if len(file_object) == 0:
response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
async for result in send_llm_response(response_log):
yield result
return
contextual_data = " ".join([file.raw_text for file in file_object])
if not q:
q = "Create a general summary of the file"
async for result in send_event(
"status", f"**🧑🏾‍💻 Constructing Summary Using:** {file_object[0].file_name}"
):
yield result
response = await extract_relevant_summary(q, contextual_data)
response_log = str(response)
async for result in send_llm_response(response_log):
yield result
except Exception as e:
response_log = "Error summarizing file."
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
async for result in send_llm_response(response_log):
yield result
await sync_to_async(save_to_conversation_log)(
q,
response_log,
user,
meta_log,
user_message_time,
intent_type="summarize",
client_application=request.user.client_app,
conversation_id=conversation_id,
)
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
metadata={"conversation_command": conversation_commands[0].value},
)
return
custom_filters = []
if conversation_commands == [ConversationCommand.Help]:
if not q:
conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
if conversation_config == None:
conversation_config = await ConversationAdapters.aget_default_conversation_config()
model_type = conversation_config.model_type
formatted_help = help_message.format(
model=model_type, version=state.khoj_version, device=get_device()
)
async for result in send_llm_response(formatted_help):
yield result
return
custom_filters.append("site:khoj.dev")
conversation_commands.append(ConversationCommand.Online)
if ConversationCommand.Automation in conversation_commands:
try:
automation, crontime, query_to_run, subject = await create_automation(
q, timezone, user, request.url, meta_log
)
except Exception as e:
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
error_message = f"Unable to create automation. Ensure the automation doesn't already exist."
async for result in send_llm_response(error_message):
yield result
return
llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
await sync_to_async(save_to_conversation_log)(
q,
llm_response,
user,
meta_log,
user_message_time,
intent_type="automation",
client_application=request.user.client_app,
conversation_id=conversation_id,
inferred_queries=[query_to_run],
automation_id=automation.id,
)
common = CommonQueryParamsClass(
client=request.user.client_app,
user_agent=request.headers.get("user-agent"),
host=request.headers.get("host"),
)
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
**common.__dict__,
)
async for result in send_llm_response(llm_response):
-yield result
-return
+yield result
+await sync_to_async(save_to_conversation_log)(
q,
response_log,
user,
meta_log,
user_message_time,
intent_type="summarize",
client_application=request.user.client_app,
conversation_id=conversation_id,
)
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
metadata={"conversation_command": conversation_commands[0].value},
)
return
-compiled_references, inferred_queries, defiltered_query = [], [], None
-async for result in extract_references_and_questions(
-request,
-meta_log,
-q,
-7,
-0.18,
-conversation_id,
-conversation_commands,
-location,
-partial(send_event, "status"),
+custom_filters = []
+if conversation_commands == [ConversationCommand.Help]:
+if not q:
+conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
+if conversation_config == None:
+conversation_config = await ConversationAdapters.aget_default_conversation_config()
+model_type = conversation_config.model_type
+formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
+async for result in send_llm_response(formatted_help):
+yield result
+return
# Adding specification to search online specifically on khoj.dev pages.
custom_filters.append("site:khoj.dev")
conversation_commands.append(ConversationCommand.Online)
if ConversationCommand.Automation in conversation_commands:
try:
automation, crontime, query_to_run, subject = await create_automation(
q, timezone, user, request.url, meta_log
)
except Exception as e:
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
error_message = f"Unable to create automation. Ensure the automation doesn't already exist."
async for result in send_llm_response(error_message):
yield result
return
llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
await sync_to_async(save_to_conversation_log)(
q,
llm_response,
user,
meta_log,
user_message_time,
intent_type="automation",
client_application=request.user.client_app,
conversation_id=conversation_id,
inferred_queries=[query_to_run],
automation_id=automation.id,
)
common = CommonQueryParamsClass(
client=request.user.client_app,
user_agent=request.headers.get("user-agent"),
host=request.headers.get("host"),
)
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
**common.__dict__,
)
async for result in send_llm_response(llm_response):
yield result
return
compiled_references, inferred_queries, defiltered_query = [], [], None
async for result in extract_references_and_questions(
request,
meta_log,
q,
(n or 7),
(d or 0.18),
conversation_id,
conversation_commands,
location,
partial(send_event, "status"),
):
if isinstance(result, dict) and "status" in result:
yield result["status"]
else:
compiled_references.extend(result[0])
inferred_queries.extend(result[1])
defiltered_query = result[2]
if not is_none_or_empty(compiled_references):
headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
async for result in send_event("status", f"**📜 Found Relevant Notes**: {headings}"):
yield result
online_results: Dict = dict()
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
async for result in send_llm_response(f"{no_entries_found.format()}"):
yield result
return
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
conversation_commands.remove(ConversationCommand.Notes)
if ConversationCommand.Online in conversation_commands:
try:
async for result in search_online(
defiltered_query, meta_log, location, partial(send_event, "status"), custom_filters
+):
+if isinstance(result, dict) and "status" in result:
+yield result["status"]
+else:
+online_results = result
+except ValueError as e:
+error_message = f"Error searching online: {e}. Attempting to respond without online results"
+logger.warning(error_message)
+async for result in send_llm_response(error_message):
-):
-if isinstance(result, dict) and "status" in result:
-yield result["status"]
-else:
-compiled_references.extend(result[0])
-inferred_queries.extend(result[1])
-defiltered_query = result[2]
-if not is_none_or_empty(compiled_references):
headings = "\n- " + "\n- ".join(
set([c.get("compiled", c).split("\n")[0] for c in compiled_references])
)
async for result in send_event("status", f"**📜 Found Relevant Notes**: {headings}"):
yield result
online_results: Dict = dict()
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(
user
):
async for result in send_llm_response(f"{no_entries_found.format()}"):
yield event
return
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
conversation_commands.remove(ConversationCommand.Notes)
if ConversationCommand.Online in conversation_commands:
try:
async for result in search_online(
defiltered_query, meta_log, location, partial(send_event, "status"), custom_filters
):
if isinstance(result, dict) and "status" in result:
yield result["status"]
else:
online_results = result
except ValueError as e:
error_message = f"Error searching online: {e}. Attempting to respond without online results"
logger.warning(error_message)
async for result in send_llm_response(error_message):
yield result
return
if ConversationCommand.Webpage in conversation_commands:
try:
async for result in read_webpages(
defiltered_query, meta_log, location, partial(send_event, "status")
):
if isinstance(result, dict) and "status" in result:
yield result["status"]
else:
direct_web_pages = result
webpages = []
for query in direct_web_pages:
if online_results.get(query):
online_results[query]["webpages"] = direct_web_pages[query]["webpages"]
else:
online_results[query] = {"webpages": direct_web_pages[query]["webpages"]}
for webpage in direct_web_pages[query]["webpages"]:
webpages.append(webpage["link"])
async for result in send_event("status", f"**📚 Read web pages**: {webpages}"):
yield result
except ValueError as e:
logger.warning(
f"Error directly reading webpages: {e}. Attempting to respond without online results",
exc_info=True,
)
if ConversationCommand.Image in conversation_commands:
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
metadata={"conversation_command": conversation_commands[0].value},
)
async for result in text_to_image(
q,
user,
meta_log,
location_data=location,
references=compiled_references,
online_results=online_results,
send_status_func=partial(send_event, "status"),
):
if isinstance(result, dict) and "status" in result:
yield result["status"]
else:
image, status_code, improved_image_prompt, intent_type = result
if image is None or status_code != 200:
content_obj = {
"content-type": "application/json",
"intentType": intent_type,
"detail": improved_image_prompt,
"image": image,
}
async for result in send_llm_response(json.dumps(content_obj)):
yield result
return
await sync_to_async(save_to_conversation_log)(
q,
image,
user,
meta_log,
user_message_time,
intent_type=intent_type,
inferred_queries=[improved_image_prompt],
client_application=request.user.client_app,
conversation_id=conversation_id,
compiled_references=compiled_references,
online_results=online_results,
)
content_obj = {
"content-type": "application/json",
"intentType": intent_type,
"context": compiled_references,
"online_results": online_results,
"inferredQueries": [improved_image_prompt],
"image": image,
}
async for result in send_llm_response(json.dumps(content_obj)):
yield result
return
async for result in send_event(
"references", json.dumps({"context": compiled_references, "online_results": online_results})
):
yield result
async for result in send_event("status", f"**💭 Generating a well-informed response**"):
yield result
llm_response, chat_metadata = await agenerate_chat_response(
defiltered_query,
meta_log,
conversation,
compiled_references,
online_results,
inferred_queries,
conversation_commands,
user,
request.user.client_app,
conversation_id,
location,
user_name,
)
chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
metadata=chat_metadata,
)
iterator = AsyncIteratorWrapper(llm_response)
async for result in send_event("start_llm_response", ""):
yield result
continue_stream = True
async for item in iterator:
if item is None:
async for result in send_event("end_llm_response", ""):
yield result
logger.debug("Finished streaming response")
return
if not connection_alive or not continue_stream:
continue
try:
async for result in send_event("message", f"{item}"):
yield result
except Exception as e:
continue_stream = False
logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}")
# Stop streaming after compiled references section of response starts
# References are being processed via the references event rather than the message event
if "### compiled references:" in item:
continue_stream = False
except asyncio.CancelledError:
logger.error(f"Cancelled Error in API endpoint: {e}", exc_info=True)
return return
if ConversationCommand.Webpage in conversation_commands:
try:
async for result in read_webpages(defiltered_query, meta_log, location, partial(send_event, "status")):
if isinstance(result, dict) and "status" in result:
yield result["status"]
else:
direct_web_pages = result
webpages = []
for query in direct_web_pages:
if online_results.get(query):
online_results[query]["webpages"] = direct_web_pages[query]["webpages"]
else:
online_results[query] = {"webpages": direct_web_pages[query]["webpages"]}
for webpage in direct_web_pages[query]["webpages"]:
webpages.append(webpage["link"])
async for result in send_event("status", f"**📚 Read web pages**: {webpages}"):
yield result
except ValueError as e:
logger.warning(
f"Error directly reading webpages: {e}. Attempting to respond without online results",
exc_info=True,
)
if ConversationCommand.Image in conversation_commands:
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
metadata={"conversation_command": conversation_commands[0].value},
)
async for result in text_to_image(
q,
user,
meta_log,
location_data=location,
references=compiled_references,
online_results=online_results,
send_status_func=partial(send_event, "status"),
):
if isinstance(result, dict) and "status" in result:
yield result["status"]
else:
image, status_code, improved_image_prompt, intent_type = result
if image is None or status_code != 200:
content_obj = {
"content-type": "application/json",
"intentType": intent_type,
"detail": improved_image_prompt,
"image": image,
}
async for result in send_llm_response(json.dumps(content_obj)):
yield result
return
await sync_to_async(save_to_conversation_log)(
q,
image,
user,
meta_log,
user_message_time,
intent_type=intent_type,
inferred_queries=[improved_image_prompt],
client_application=request.user.client_app,
conversation_id=conversation_id,
compiled_references=compiled_references,
online_results=online_results,
)
content_obj = {
"content-type": "application/json",
"intentType": intent_type,
"context": compiled_references,
"online_results": online_results,
"inferredQueries": [improved_image_prompt],
"image": image,
}
async for result in send_llm_response(json.dumps(content_obj)):
yield result
return
async for result in send_event(
"references", json.dumps({"context": compiled_references, "online_results": online_results})
):
yield result
async for result in send_event("status", f"**💭 Generating a well-informed response**"):
yield result
llm_response, chat_metadata = await agenerate_chat_response(
defiltered_query,
meta_log,
conversation,
compiled_references,
online_results,
inferred_queries,
conversation_commands,
user,
request.user.client_app,
conversation_id,
location,
user_name,
)
chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
metadata=chat_metadata,
)
iterator = AsyncIteratorWrapper(llm_response)
async for result in send_event("start_llm_response", ""):
yield result
continue_stream = True
async for item in iterator:
if item is None:
async for result in send_event("end_llm_response", ""):
yield result
logger.debug("Finished streaming response")
return
if not connection_alive or not continue_stream:
continue
# Stop streaming after compiled references section of response starts
# References are being processed via the references event rather than the message event
if "### compiled references:" in item:
continue_stream = False
item = item.split("### compiled references:")[0]
try:
async for result in send_event("message", f"{item}"):
yield result
+except Exception as e:
+continue_stream = False
+logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}")
-except Exception as e:
-logger.error(f"General Error in API endpoint: {e}", exc_info=True)
-return
return StreamingResponse(event_generator(q), media_type="text/plain")
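
For context, a minimal sketch of how a client might call the streaming chat endpoint after this change. The server address, the /api/chat/stream path, and the bearer-token header are illustrative assumptions, not part of this commit; the query parameters mirror the new stream_chat() signature above (q, n, d, title, conversation_id, city, region, country, timezone), and the response is read as a plain-text stream in which status, references and start/end_llm_response events arrive as JSON objects with "type" and "data" fields while message chunks arrive as raw text.

# Illustrative client for the simplified streaming chat API (a sketch, not the project's own client).
# Assumptions: a Khoj server on localhost:42110 and bearer-token auth.
import json

import requests

KHOJ_URL = "http://localhost:42110"  # assumed server address
API_KEY = "your-khoj-api-key"  # assumed credential

params = {
    "q": "What did I write about gardening last week?",
    "n": 7,  # number of search results to use as context (previously hardcoded)
    "d": 0.18,  # maximum search distance (previously hardcoded)
    "conversation_id": 1,  # now optional; a conversation title can be passed instead
    "timezone": "Asia/Kolkata",
}

with requests.get(
    f"{KHOJ_URL}/api/chat/stream",
    params=params,
    headers={"Authorization": f"Bearer {API_KEY}"},
    stream=True,
    timeout=60,
) as response:
    response.raise_for_status()
    for line in response.iter_lines(decode_unicode=True):
        if not line:
            continue
        try:
            # Non-message events are emitted as JSON objects with "type" and "data" fields.
            event = json.loads(line)
            print(f"[{event.get('type')}] {event.get('data')}")
        except json.JSONDecodeError:
            # Message chunks from the LLM are streamed as raw text.
            print(line, end="", flush=True)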