Simplify advanced streaming chat API, align params with normal chat API

This commit is contained in:
Debanjum Singh Solanky
2024-07-22 17:09:41 +05:30
parent b8d3e3669a
commit 6b9550238f

View File

@@ -1,7 +1,6 @@
import asyncio
import json
import logging
import math
from datetime import datetime
from functools import partial
from typing import Any, Dict, List, Optional
@@ -529,29 +528,47 @@ async def set_conversation_title(
@api_chat.get("/stream")
async def stream_chat(
request: Request,
common: CommonQueryParams,
q: str,
conversation_id: int,
n: int = 7,
d: float = 0.18,
title: Optional[str] = None,
conversation_id: Optional[int] = None,
city: Optional[str] = None,
region: Optional[str] = None,
country: Optional[str] = None,
timezone: Optional[str] = None,
rate_limiter_per_minute=Depends(
ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
),
rate_limiter_per_day=Depends(
ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
),
):
async def event_generator(q: str):
connection_alive = True
user: KhojUser = request.user.object
q = unquote(q)
async def send_event(event_type: str, data: str):
nonlocal connection_alive
if not connection_alive or await request.is_disconnected():
connection_alive = False
logger.warn(f"User {user} disconnected from {common.client} client")
return
try:
if event_type == "message":
yield data
else:
yield json.dumps({"type": event_type, "data": data})
except asyncio.CancelledError:
connection_alive = False
logger.warn(f"User {user} disconnected from {common.client} client")
return
except Exception as e:
connection_alive = False
logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}")
logger.error(f"Failed to stream chat API response to {user} on {common.client}: {e}", exc_info=True)
return
async def send_llm_response(response: str):
async for result in send_event("start_llm_response", ""):
@@ -561,43 +578,23 @@ async def stream_chat(
async for result in send_event("end_llm_response", ""):
yield result
user: KhojUser = request.user.object
conversation = await ConversationAdapters.aget_conversation_by_user(
user, client_application=request.user.client_app, conversation_id=conversation_id
user, client_application=request.user.client_app, conversation_id=conversation_id, title=title
)
hourly_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
daily_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
if not conversation:
async for result in send_llm_response(f"No Conversation id: {conversation_id} not found"):
yield result
await is_ready_to_chat(user)
user_name = await aget_user_name(user)
location = None
if city or region or country:
location = LocationData(city=city, region=region, country=country)
while connection_alive:
try:
if conversation:
await sync_to_async(conversation.refresh_from_db)(fields=["conversation_log"])
# Refresh these because the connection to the database might have been closed
await conversation.arefresh_from_db()
try:
await sync_to_async(hourly_limiter)(request)
await sync_to_async(daily_limiter)(request)
except HTTPException as e:
async for result in send_event("rate_limit", e.detail):
yield result
return
if is_query_empty(q):
async for event in send_llm_response("Please ask your query to get started."):
yield event
async for result in send_llm_response("Please ask your query to get started."):
yield result
return
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@@ -609,8 +606,6 @@ async def stream_chat(
meta_log = conversation.conversation_log
is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
@@ -629,6 +624,7 @@ async def stream_chat(
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
q = q.replace(f"/{cmd.value}", "").strip()
used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
file_filters = conversation.file_filters if conversation else []
# Skip trying to summarize if
if (
@@ -643,9 +639,7 @@ async def stream_chat(
elif ConversationCommand.Summarize in conversation_commands:
response_log = ""
if len(file_filters) == 0:
response_log = (
"No files selected for summarization. Please add files using the section on the left."
)
response_log = "No files selected for summarization. Please add files using the section on the left."
async for result in send_llm_response(response_log):
yield result
elif len(file_filters) > 1:
@@ -702,12 +696,11 @@ async def stream_chat(
if conversation_config == None:
conversation_config = await ConversationAdapters.aget_default_conversation_config()
model_type = conversation_config.model_type
formatted_help = help_message.format(
model=model_type, version=state.khoj_version, device=get_device()
)
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
async for result in send_llm_response(formatted_help):
yield result
return
# Adding specification to search online specifically on khoj.dev pages.
custom_filters.append("site:khoj.dev")
conversation_commands.append(ConversationCommand.Online)
@@ -756,8 +749,8 @@ async def stream_chat(
request,
meta_log,
q,
7,
0.18,
(n or 7),
(d or 0.18),
conversation_id,
conversation_commands,
location,
@@ -771,19 +764,15 @@ async def stream_chat(
defiltered_query = result[2]
if not is_none_or_empty(compiled_references):
headings = "\n- " + "\n- ".join(
set([c.get("compiled", c).split("\n")[0] for c in compiled_references])
)
headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
async for result in send_event("status", f"**📜 Found Relevant Notes**: {headings}"):
yield result
online_results: Dict = dict()
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(
user
):
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
async for result in send_llm_response(f"{no_entries_found.format()}"):
yield event
yield result
return
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
@@ -807,9 +796,7 @@ async def stream_chat(
if ConversationCommand.Webpage in conversation_commands:
try:
async for result in read_webpages(
defiltered_query, meta_log, location, partial(send_event, "status")
):
async for result in read_webpages(defiltered_query, meta_log, location, partial(send_event, "status")):
if isinstance(result, dict) and "status" in result:
yield result["status"]
else:
@@ -932,22 +919,17 @@ async def stream_chat(
return
if not connection_alive or not continue_stream:
continue
# Stop streaming after compiled references section of response starts
# References are being processed via the references event rather than the message event
if "### compiled references:" in item:
continue_stream = False
item = item.split("### compiled references:")[0]
try:
async for result in send_event("message", f"{item}"):
yield result
except Exception as e:
continue_stream = False
logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}")
# Stop streaming after compiled references section of response starts
# References are being processed via the references event rather than the message event
if "### compiled references:" in item:
continue_stream = False
except asyncio.CancelledError:
logger.error(f"Cancelled Error in API endpoint: {e}", exc_info=True)
return
except Exception as e:
logger.error(f"General Error in API endpoint: {e}", exc_info=True)
return
return StreamingResponse(event_generator(q), media_type="text/plain")