Simplify advanced streaming chat API, align params with normal chat API

This commit is contained in:
Debanjum Singh Solanky
2024-07-22 17:09:41 +05:30
parent b8d3e3669a
commit 6b9550238f

View File

@@ -1,7 +1,6 @@
import asyncio import asyncio
import json import json
import logging import logging
import math
from datetime import datetime from datetime import datetime
from functools import partial from functools import partial
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
@@ -529,29 +528,47 @@ async def set_conversation_title(
@api_chat.get("/stream") @api_chat.get("/stream")
async def stream_chat( async def stream_chat(
request: Request, request: Request,
common: CommonQueryParams,
q: str, q: str,
conversation_id: int, n: int = 7,
d: float = 0.18,
title: Optional[str] = None,
conversation_id: Optional[int] = None,
city: Optional[str] = None, city: Optional[str] = None,
region: Optional[str] = None, region: Optional[str] = None,
country: Optional[str] = None, country: Optional[str] = None,
timezone: Optional[str] = None, timezone: Optional[str] = None,
rate_limiter_per_minute=Depends(
ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
),
rate_limiter_per_day=Depends(
ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
),
): ):
async def event_generator(q: str): async def event_generator(q: str):
connection_alive = True connection_alive = True
user: KhojUser = request.user.object
q = unquote(q)
async def send_event(event_type: str, data: str): async def send_event(event_type: str, data: str):
nonlocal connection_alive nonlocal connection_alive
if not connection_alive or await request.is_disconnected(): if not connection_alive or await request.is_disconnected():
connection_alive = False connection_alive = False
logger.warn(f"User {user} disconnected from {common.client} client")
return return
try: try:
if event_type == "message": if event_type == "message":
yield data yield data
else: else:
yield json.dumps({"type": event_type, "data": data}) yield json.dumps({"type": event_type, "data": data})
except asyncio.CancelledError:
connection_alive = False
logger.warn(f"User {user} disconnected from {common.client} client")
return
except Exception as e: except Exception as e:
connection_alive = False connection_alive = False
logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}") logger.error(f"Failed to stream chat API response to {user} on {common.client}: {e}", exc_info=True)
return
async def send_llm_response(response: str): async def send_llm_response(response: str):
async for result in send_event("start_llm_response", ""): async for result in send_event("start_llm_response", ""):
@@ -561,43 +578,23 @@ async def stream_chat(
async for result in send_event("end_llm_response", ""): async for result in send_event("end_llm_response", ""):
yield result yield result
user: KhojUser = request.user.object
conversation = await ConversationAdapters.aget_conversation_by_user( conversation = await ConversationAdapters.aget_conversation_by_user(
user, client_application=request.user.client_app, conversation_id=conversation_id user, client_application=request.user.client_app, conversation_id=conversation_id, title=title
) )
if not conversation:
hourly_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute") async for result in send_llm_response(f"No Conversation id: {conversation_id} not found"):
yield result
daily_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
await is_ready_to_chat(user) await is_ready_to_chat(user)
user_name = await aget_user_name(user) user_name = await aget_user_name(user)
location = None location = None
if city or region or country: if city or region or country:
location = LocationData(city=city, region=region, country=country) location = LocationData(city=city, region=region, country=country)
while connection_alive:
try:
if conversation:
await sync_to_async(conversation.refresh_from_db)(fields=["conversation_log"])
# Refresh these because the connection to the database might have been closed
await conversation.arefresh_from_db()
try:
await sync_to_async(hourly_limiter)(request)
await sync_to_async(daily_limiter)(request)
except HTTPException as e:
async for result in send_event("rate_limit", e.detail):
yield result
return
if is_query_empty(q): if is_query_empty(q):
async for event in send_llm_response("Please ask your query to get started."): async for result in send_llm_response("Please ask your query to get started."):
yield event yield result
return return
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@@ -609,8 +606,6 @@ async def stream_chat(
meta_log = conversation.conversation_log meta_log = conversation.conversation_log
is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask] is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
if conversation_commands == [ConversationCommand.Default] or is_automated_task: if conversation_commands == [ConversationCommand.Default] or is_automated_task:
conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task) conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands]) conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
@@ -629,6 +624,7 @@ async def stream_chat(
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd) await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
q = q.replace(f"/{cmd.value}", "").strip() q = q.replace(f"/{cmd.value}", "").strip()
used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
file_filters = conversation.file_filters if conversation else [] file_filters = conversation.file_filters if conversation else []
# Skip trying to summarize if # Skip trying to summarize if
if ( if (
@@ -643,9 +639,7 @@ async def stream_chat(
elif ConversationCommand.Summarize in conversation_commands: elif ConversationCommand.Summarize in conversation_commands:
response_log = "" response_log = ""
if len(file_filters) == 0: if len(file_filters) == 0:
response_log = ( response_log = "No files selected for summarization. Please add files using the section on the left."
"No files selected for summarization. Please add files using the section on the left."
)
async for result in send_llm_response(response_log): async for result in send_llm_response(response_log):
yield result yield result
elif len(file_filters) > 1: elif len(file_filters) > 1:
@@ -702,12 +696,11 @@ async def stream_chat(
if conversation_config == None: if conversation_config == None:
conversation_config = await ConversationAdapters.aget_default_conversation_config() conversation_config = await ConversationAdapters.aget_default_conversation_config()
model_type = conversation_config.model_type model_type = conversation_config.model_type
formatted_help = help_message.format( formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
model=model_type, version=state.khoj_version, device=get_device()
)
async for result in send_llm_response(formatted_help): async for result in send_llm_response(formatted_help):
yield result yield result
return return
# Adding specification to search online specifically on khoj.dev pages.
custom_filters.append("site:khoj.dev") custom_filters.append("site:khoj.dev")
conversation_commands.append(ConversationCommand.Online) conversation_commands.append(ConversationCommand.Online)
@@ -756,8 +749,8 @@ async def stream_chat(
request, request,
meta_log, meta_log,
q, q,
7, (n or 7),
0.18, (d or 0.18),
conversation_id, conversation_id,
conversation_commands, conversation_commands,
location, location,
@@ -771,19 +764,15 @@ async def stream_chat(
defiltered_query = result[2] defiltered_query = result[2]
if not is_none_or_empty(compiled_references): if not is_none_or_empty(compiled_references):
headings = "\n- " + "\n- ".join( headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
set([c.get("compiled", c).split("\n")[0] for c in compiled_references])
)
async for result in send_event("status", f"**📜 Found Relevant Notes**: {headings}"): async for result in send_event("status", f"**📜 Found Relevant Notes**: {headings}"):
yield result yield result
online_results: Dict = dict() online_results: Dict = dict()
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries( if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
user
):
async for result in send_llm_response(f"{no_entries_found.format()}"): async for result in send_llm_response(f"{no_entries_found.format()}"):
yield event yield result
return return
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references): if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
@@ -807,9 +796,7 @@ async def stream_chat(
if ConversationCommand.Webpage in conversation_commands: if ConversationCommand.Webpage in conversation_commands:
try: try:
async for result in read_webpages( async for result in read_webpages(defiltered_query, meta_log, location, partial(send_event, "status")):
defiltered_query, meta_log, location, partial(send_event, "status")
):
if isinstance(result, dict) and "status" in result: if isinstance(result, dict) and "status" in result:
yield result["status"] yield result["status"]
else: else:
@@ -932,22 +919,17 @@ async def stream_chat(
return return
if not connection_alive or not continue_stream: if not connection_alive or not continue_stream:
continue continue
# Stop streaming after compiled references section of response starts
# References are being processed via the references event rather than the message event
if "### compiled references:" in item:
continue_stream = False
item = item.split("### compiled references:")[0]
try: try:
async for result in send_event("message", f"{item}"): async for result in send_event("message", f"{item}"):
yield result yield result
except Exception as e: except Exception as e:
continue_stream = False continue_stream = False
logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}") logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}")
# Stop streaming after compiled references section of response starts
# References are being processed via the references event rather than the message event
if "### compiled references:" in item:
continue_stream = False
except asyncio.CancelledError:
logger.error(f"Cancelled Error in API endpoint: {e}", exc_info=True)
return
except Exception as e:
logger.error(f"General Error in API endpoint: {e}", exc_info=True)
return
return StreamingResponse(event_generator(q), media_type="text/plain") return StreamingResponse(event_generator(q), media_type="text/plain")