mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-08 05:39:13 +00:00
Simplify advanced streaming chat API, align params with normal chat API
This commit is contained in:
@@ -1,7 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
@@ -529,29 +528,47 @@ async def set_conversation_title(
|
|||||||
@api_chat.get("/stream")
|
@api_chat.get("/stream")
|
||||||
async def stream_chat(
|
async def stream_chat(
|
||||||
request: Request,
|
request: Request,
|
||||||
|
common: CommonQueryParams,
|
||||||
q: str,
|
q: str,
|
||||||
conversation_id: int,
|
n: int = 7,
|
||||||
|
d: float = 0.18,
|
||||||
|
title: Optional[str] = None,
|
||||||
|
conversation_id: Optional[int] = None,
|
||||||
city: Optional[str] = None,
|
city: Optional[str] = None,
|
||||||
region: Optional[str] = None,
|
region: Optional[str] = None,
|
||||||
country: Optional[str] = None,
|
country: Optional[str] = None,
|
||||||
timezone: Optional[str] = None,
|
timezone: Optional[str] = None,
|
||||||
|
rate_limiter_per_minute=Depends(
|
||||||
|
ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
|
||||||
|
),
|
||||||
|
rate_limiter_per_day=Depends(
|
||||||
|
ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
|
||||||
|
),
|
||||||
):
|
):
|
||||||
async def event_generator(q: str):
|
async def event_generator(q: str):
|
||||||
connection_alive = True
|
connection_alive = True
|
||||||
|
user: KhojUser = request.user.object
|
||||||
|
q = unquote(q)
|
||||||
|
|
||||||
async def send_event(event_type: str, data: str):
|
async def send_event(event_type: str, data: str):
|
||||||
nonlocal connection_alive
|
nonlocal connection_alive
|
||||||
if not connection_alive or await request.is_disconnected():
|
if not connection_alive or await request.is_disconnected():
|
||||||
connection_alive = False
|
connection_alive = False
|
||||||
|
logger.warn(f"User {user} disconnected from {common.client} client")
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
if event_type == "message":
|
if event_type == "message":
|
||||||
yield data
|
yield data
|
||||||
else:
|
else:
|
||||||
yield json.dumps({"type": event_type, "data": data})
|
yield json.dumps({"type": event_type, "data": data})
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
connection_alive = False
|
||||||
|
logger.warn(f"User {user} disconnected from {common.client} client")
|
||||||
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
connection_alive = False
|
connection_alive = False
|
||||||
logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}")
|
logger.error(f"Failed to stream chat API response to {user} on {common.client}: {e}", exc_info=True)
|
||||||
|
return
|
||||||
|
|
||||||
async def send_llm_response(response: str):
|
async def send_llm_response(response: str):
|
||||||
async for result in send_event("start_llm_response", ""):
|
async for result in send_event("start_llm_response", ""):
|
||||||
@@ -561,43 +578,23 @@ async def stream_chat(
|
|||||||
async for result in send_event("end_llm_response", ""):
|
async for result in send_event("end_llm_response", ""):
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
user: KhojUser = request.user.object
|
|
||||||
conversation = await ConversationAdapters.aget_conversation_by_user(
|
conversation = await ConversationAdapters.aget_conversation_by_user(
|
||||||
user, client_application=request.user.client_app, conversation_id=conversation_id
|
user, client_application=request.user.client_app, conversation_id=conversation_id, title=title
|
||||||
)
|
)
|
||||||
|
if not conversation:
|
||||||
hourly_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
|
async for result in send_llm_response(f"No Conversation id: {conversation_id} not found"):
|
||||||
|
yield result
|
||||||
daily_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
|
|
||||||
|
|
||||||
await is_ready_to_chat(user)
|
await is_ready_to_chat(user)
|
||||||
|
|
||||||
user_name = await aget_user_name(user)
|
user_name = await aget_user_name(user)
|
||||||
|
|
||||||
location = None
|
location = None
|
||||||
|
|
||||||
if city or region or country:
|
if city or region or country:
|
||||||
location = LocationData(city=city, region=region, country=country)
|
location = LocationData(city=city, region=region, country=country)
|
||||||
|
|
||||||
while connection_alive:
|
|
||||||
try:
|
|
||||||
if conversation:
|
|
||||||
await sync_to_async(conversation.refresh_from_db)(fields=["conversation_log"])
|
|
||||||
|
|
||||||
# Refresh these because the connection to the database might have been closed
|
|
||||||
await conversation.arefresh_from_db()
|
|
||||||
|
|
||||||
try:
|
|
||||||
await sync_to_async(hourly_limiter)(request)
|
|
||||||
await sync_to_async(daily_limiter)(request)
|
|
||||||
except HTTPException as e:
|
|
||||||
async for result in send_event("rate_limit", e.detail):
|
|
||||||
yield result
|
|
||||||
return
|
|
||||||
|
|
||||||
if is_query_empty(q):
|
if is_query_empty(q):
|
||||||
async for event in send_llm_response("Please ask your query to get started."):
|
async for result in send_llm_response("Please ask your query to get started."):
|
||||||
yield event
|
yield result
|
||||||
return
|
return
|
||||||
|
|
||||||
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
@@ -609,8 +606,6 @@ async def stream_chat(
|
|||||||
meta_log = conversation.conversation_log
|
meta_log = conversation.conversation_log
|
||||||
is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
|
is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
|
||||||
|
|
||||||
used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
|
|
||||||
|
|
||||||
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
|
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
|
||||||
conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
|
conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
|
||||||
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
|
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
|
||||||
@@ -629,6 +624,7 @@ async def stream_chat(
|
|||||||
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
|
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
|
||||||
q = q.replace(f"/{cmd.value}", "").strip()
|
q = q.replace(f"/{cmd.value}", "").strip()
|
||||||
|
|
||||||
|
used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
|
||||||
file_filters = conversation.file_filters if conversation else []
|
file_filters = conversation.file_filters if conversation else []
|
||||||
# Skip trying to summarize if
|
# Skip trying to summarize if
|
||||||
if (
|
if (
|
||||||
@@ -643,9 +639,7 @@ async def stream_chat(
|
|||||||
elif ConversationCommand.Summarize in conversation_commands:
|
elif ConversationCommand.Summarize in conversation_commands:
|
||||||
response_log = ""
|
response_log = ""
|
||||||
if len(file_filters) == 0:
|
if len(file_filters) == 0:
|
||||||
response_log = (
|
response_log = "No files selected for summarization. Please add files using the section on the left."
|
||||||
"No files selected for summarization. Please add files using the section on the left."
|
|
||||||
)
|
|
||||||
async for result in send_llm_response(response_log):
|
async for result in send_llm_response(response_log):
|
||||||
yield result
|
yield result
|
||||||
elif len(file_filters) > 1:
|
elif len(file_filters) > 1:
|
||||||
@@ -702,12 +696,11 @@ async def stream_chat(
|
|||||||
if conversation_config == None:
|
if conversation_config == None:
|
||||||
conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
||||||
model_type = conversation_config.model_type
|
model_type = conversation_config.model_type
|
||||||
formatted_help = help_message.format(
|
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
|
||||||
model=model_type, version=state.khoj_version, device=get_device()
|
|
||||||
)
|
|
||||||
async for result in send_llm_response(formatted_help):
|
async for result in send_llm_response(formatted_help):
|
||||||
yield result
|
yield result
|
||||||
return
|
return
|
||||||
|
# Adding specification to search online specifically on khoj.dev pages.
|
||||||
custom_filters.append("site:khoj.dev")
|
custom_filters.append("site:khoj.dev")
|
||||||
conversation_commands.append(ConversationCommand.Online)
|
conversation_commands.append(ConversationCommand.Online)
|
||||||
|
|
||||||
@@ -756,8 +749,8 @@ async def stream_chat(
|
|||||||
request,
|
request,
|
||||||
meta_log,
|
meta_log,
|
||||||
q,
|
q,
|
||||||
7,
|
(n or 7),
|
||||||
0.18,
|
(d or 0.18),
|
||||||
conversation_id,
|
conversation_id,
|
||||||
conversation_commands,
|
conversation_commands,
|
||||||
location,
|
location,
|
||||||
@@ -771,19 +764,15 @@ async def stream_chat(
|
|||||||
defiltered_query = result[2]
|
defiltered_query = result[2]
|
||||||
|
|
||||||
if not is_none_or_empty(compiled_references):
|
if not is_none_or_empty(compiled_references):
|
||||||
headings = "\n- " + "\n- ".join(
|
headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
|
||||||
set([c.get("compiled", c).split("\n")[0] for c in compiled_references])
|
|
||||||
)
|
|
||||||
async for result in send_event("status", f"**📜 Found Relevant Notes**: {headings}"):
|
async for result in send_event("status", f"**📜 Found Relevant Notes**: {headings}"):
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
online_results: Dict = dict()
|
online_results: Dict = dict()
|
||||||
|
|
||||||
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(
|
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
|
||||||
user
|
|
||||||
):
|
|
||||||
async for result in send_llm_response(f"{no_entries_found.format()}"):
|
async for result in send_llm_response(f"{no_entries_found.format()}"):
|
||||||
yield event
|
yield result
|
||||||
return
|
return
|
||||||
|
|
||||||
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
|
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
|
||||||
@@ -807,9 +796,7 @@ async def stream_chat(
|
|||||||
|
|
||||||
if ConversationCommand.Webpage in conversation_commands:
|
if ConversationCommand.Webpage in conversation_commands:
|
||||||
try:
|
try:
|
||||||
async for result in read_webpages(
|
async for result in read_webpages(defiltered_query, meta_log, location, partial(send_event, "status")):
|
||||||
defiltered_query, meta_log, location, partial(send_event, "status")
|
|
||||||
):
|
|
||||||
if isinstance(result, dict) and "status" in result:
|
if isinstance(result, dict) and "status" in result:
|
||||||
yield result["status"]
|
yield result["status"]
|
||||||
else:
|
else:
|
||||||
@@ -932,22 +919,17 @@ async def stream_chat(
|
|||||||
return
|
return
|
||||||
if not connection_alive or not continue_stream:
|
if not connection_alive or not continue_stream:
|
||||||
continue
|
continue
|
||||||
|
# Stop streaming after compiled references section of response starts
|
||||||
|
# References are being processed via the references event rather than the message event
|
||||||
|
if "### compiled references:" in item:
|
||||||
|
continue_stream = False
|
||||||
|
item = item.split("### compiled references:")[0]
|
||||||
try:
|
try:
|
||||||
async for result in send_event("message", f"{item}"):
|
async for result in send_event("message", f"{item}"):
|
||||||
yield result
|
yield result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
continue_stream = False
|
continue_stream = False
|
||||||
logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}")
|
logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}")
|
||||||
# Stop streaming after compiled references section of response starts
|
|
||||||
# References are being processed via the references event rather than the message event
|
|
||||||
if "### compiled references:" in item:
|
|
||||||
continue_stream = False
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
logger.error(f"Cancelled Error in API endpoint: {e}", exc_info=True)
|
|
||||||
return
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"General Error in API endpoint: {e}", exc_info=True)
|
|
||||||
return
|
|
||||||
|
|
||||||
return StreamingResponse(event_generator(q), media_type="text/plain")
|
return StreamingResponse(event_generator(q), media_type="text/plain")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user