Use enum to track chat stream event types in chat API router

This commit is contained in:
Debanjum Singh Solanky
2024-07-26 00:18:37 +05:30
parent ebe92ef16d
commit 778c571288
4 changed files with 56 additions and 42 deletions

View File

@@ -11,6 +11,7 @@ from bs4 import BeautifulSoup
from markdownify import markdownify from markdownify import markdownify
from khoj.routers.helpers import ( from khoj.routers.helpers import (
ChatEvent,
extract_relevant_info, extract_relevant_info,
generate_online_subqueries, generate_online_subqueries,
infer_webpage_urls, infer_webpage_urls,
@@ -68,7 +69,7 @@ async def search_online(
if send_status_func: if send_status_func:
subqueries_str = "\n- " + "\n- ".join(list(subqueries)) subqueries_str = "\n- " + "\n- ".join(list(subqueries))
async for event in send_status_func(f"**🌐 Searching the Internet for**: {subqueries_str}"): async for event in send_status_func(f"**🌐 Searching the Internet for**: {subqueries_str}"):
yield {"status": event} yield {ChatEvent.STATUS: event}
with timer(f"Internet searches for {list(subqueries)} took", logger): with timer(f"Internet searches for {list(subqueries)} took", logger):
search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina
@@ -92,7 +93,7 @@ async def search_online(
if send_status_func: if send_status_func:
webpage_links_str = "\n- " + "\n- ".join(list(webpage_links)) webpage_links_str = "\n- " + "\n- ".join(list(webpage_links))
async for event in send_status_func(f"**📖 Reading web pages**: {webpage_links_str}"): async for event in send_status_func(f"**📖 Reading web pages**: {webpage_links_str}"):
yield {"status": event} yield {ChatEvent.STATUS: event}
tasks = [read_webpage_and_extract_content(subquery, link, content) for link, subquery, content in webpages] tasks = [read_webpage_and_extract_content(subquery, link, content) for link, subquery, content in webpages]
results = await asyncio.gather(*tasks) results = await asyncio.gather(*tasks)
@@ -131,14 +132,14 @@ async def read_webpages(
logger.info(f"Inferring web pages to read") logger.info(f"Inferring web pages to read")
if send_status_func: if send_status_func:
async for event in send_status_func(f"**🧐 Inferring web pages to read**"): async for event in send_status_func(f"**🧐 Inferring web pages to read**"):
yield {"status": event} yield {ChatEvent.STATUS: event}
urls = await infer_webpage_urls(query, conversation_history, location) urls = await infer_webpage_urls(query, conversation_history, location)
logger.info(f"Reading web pages at: {urls}") logger.info(f"Reading web pages at: {urls}")
if send_status_func: if send_status_func:
webpage_links_str = "\n- " + "\n- ".join(list(urls)) webpage_links_str = "\n- " + "\n- ".join(list(urls))
async for event in send_status_func(f"**📖 Reading web pages**: {webpage_links_str}"): async for event in send_status_func(f"**📖 Reading web pages**: {webpage_links_str}"):
yield {"status": event} yield {ChatEvent.STATUS: event}
tasks = [read_webpage_and_extract_content(query, url) for url in urls] tasks = [read_webpage_and_extract_content(query, url) for url in urls]
results = await asyncio.gather(*tasks) results = await asyncio.gather(*tasks)

View File

@@ -36,6 +36,7 @@ from khoj.processor.conversation.openai.gpt import extract_questions
from khoj.processor.conversation.openai.whisper import transcribe_audio from khoj.processor.conversation.openai.whisper import transcribe_audio
from khoj.routers.helpers import ( from khoj.routers.helpers import (
ApiUserRateLimiter, ApiUserRateLimiter,
ChatEvent,
CommonQueryParams, CommonQueryParams,
ConversationCommandRateLimiter, ConversationCommandRateLimiter,
acreate_title_from_query, acreate_title_from_query,
@@ -375,7 +376,7 @@ async def extract_references_and_questions(
if send_status_func: if send_status_func:
inferred_queries_str = "\n- " + "\n- ".join(inferred_queries) inferred_queries_str = "\n- " + "\n- ".join(inferred_queries)
async for event in send_status_func(f"**🔍 Searching Documents for:** {inferred_queries_str}"): async for event in send_status_func(f"**🔍 Searching Documents for:** {inferred_queries_str}"):
yield {"status": event} yield {ChatEvent.STATUS: event}
for query in inferred_queries: for query in inferred_queries:
n_items = min(n, 3) if using_offline_chat else n n_items = min(n, 3) if using_offline_chat else n
search_results.extend( search_results.extend(

View File

@@ -30,6 +30,7 @@ from khoj.processor.tools.online_search import read_webpages, search_online
from khoj.routers.api import extract_references_and_questions from khoj.routers.api import extract_references_and_questions
from khoj.routers.helpers import ( from khoj.routers.helpers import (
ApiUserRateLimiter, ApiUserRateLimiter,
ChatEvent,
CommonQueryParams, CommonQueryParams,
ConversationCommandRateLimiter, ConversationCommandRateLimiter,
agenerate_chat_response, agenerate_chat_response,
@@ -551,24 +552,24 @@ async def chat(
event_delimiter = "␃🔚␗" event_delimiter = "␃🔚␗"
q = unquote(q) q = unquote(q)
async def send_event(event_type: str, data: str | dict): async def send_event(event_type: ChatEvent, data: str | dict):
nonlocal connection_alive, ttft nonlocal connection_alive, ttft
if not connection_alive or await request.is_disconnected(): if not connection_alive or await request.is_disconnected():
connection_alive = False connection_alive = False
logger.warn(f"User {user} disconnected from {common.client} client") logger.warn(f"User {user} disconnected from {common.client} client")
return return
try: try:
if event_type == "end_llm_response": if event_type == ChatEvent.END_LLM_RESPONSE:
collect_telemetry() collect_telemetry()
if event_type == "start_llm_response": if event_type == ChatEvent.START_LLM_RESPONSE:
ttft = time.perf_counter() - start_time ttft = time.perf_counter() - start_time
if event_type == "message": if event_type == ChatEvent.MESSAGE:
yield data yield data
elif event_type == "references" or stream: elif event_type == ChatEvent.REFERENCES or stream:
yield json.dumps({"type": event_type, "data": data}, ensure_ascii=False) yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False)
except asyncio.CancelledError: except asyncio.CancelledError as e:
connection_alive = False connection_alive = False
logger.warn(f"User {user} disconnected from {common.client} client") logger.warn(f"User {user} disconnected from {common.client} client: {e}")
return return
except Exception as e: except Exception as e:
connection_alive = False connection_alive = False
@@ -579,11 +580,11 @@ async def chat(
yield event_delimiter yield event_delimiter
async def send_llm_response(response: str): async def send_llm_response(response: str):
async for result in send_event("start_llm_response", ""): async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
yield result yield result
async for result in send_event("message", response): async for result in send_event(ChatEvent.MESSAGE, response):
yield result yield result
async for result in send_event("end_llm_response", ""): async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
yield result yield result
def collect_telemetry(): def collect_telemetry():
@@ -632,7 +633,7 @@ async def chat(
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
conversation_commands = [get_conversation_command(query=q, any_references=True)] conversation_commands = [get_conversation_command(query=q, any_references=True)]
async for result in send_event("status", f"**👀 Understanding Query**: {q}"): async for result in send_event(ChatEvent.STATUS, f"**👀 Understanding Query**: {q}"):
yield result yield result
meta_log = conversation.conversation_log meta_log = conversation.conversation_log
@@ -642,12 +643,12 @@ async def chat(
conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task) conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands]) conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
async for result in send_event( async for result in send_event(
"status", f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}" ChatEvent.STATUS, f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}"
): ):
yield result yield result
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task) mode = await aget_relevant_output_modes(q, meta_log, is_automated_task)
async for result in send_event("status", f"**🧑🏾‍💻 Decided Response Mode:** {mode.value}"): async for result in send_event(ChatEvent.STATUS, f"**🧑🏾‍💻 Decided Response Mode:** {mode.value}"):
yield result yield result
if mode not in conversation_commands: if mode not in conversation_commands:
conversation_commands.append(mode) conversation_commands.append(mode)
@@ -690,7 +691,7 @@ async def chat(
if not q: if not q:
q = "Create a general summary of the file" q = "Create a general summary of the file"
async for result in send_event( async for result in send_event(
"status", f"**🧑🏾‍💻 Constructing Summary Using:** {file_object[0].file_name}" ChatEvent.STATUS, f"**🧑🏾‍💻 Constructing Summary Using:** {file_object[0].file_name}"
): ):
yield result yield result
@@ -771,10 +772,10 @@ async def chat(
conversation_id, conversation_id,
conversation_commands, conversation_commands,
location, location,
partial(send_event, "status"), partial(send_event, ChatEvent.STATUS),
): ):
if isinstance(result, dict) and "status" in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result["status"] yield result[ChatEvent.STATUS]
else: else:
compiled_references.extend(result[0]) compiled_references.extend(result[0])
inferred_queries.extend(result[1]) inferred_queries.extend(result[1])
@@ -782,7 +783,7 @@ async def chat(
if not is_none_or_empty(compiled_references): if not is_none_or_empty(compiled_references):
headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references])) headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
async for result in send_event("status", f"**📜 Found Relevant Notes**: {headings}"): async for result in send_event(ChatEvent.STATUS, f"**📜 Found Relevant Notes**: {headings}"):
yield result yield result
online_results: Dict = dict() online_results: Dict = dict()
@@ -799,10 +800,10 @@ async def chat(
if ConversationCommand.Online in conversation_commands: if ConversationCommand.Online in conversation_commands:
try: try:
async for result in search_online( async for result in search_online(
defiltered_query, meta_log, location, partial(send_event, "status"), custom_filters defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS), custom_filters
): ):
if isinstance(result, dict) and "status" in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result["status"] yield result[ChatEvent.STATUS]
else: else:
online_results = result online_results = result
except ValueError as e: except ValueError as e:
@@ -815,9 +816,11 @@ async def chat(
## Gather Webpage References ## Gather Webpage References
if ConversationCommand.Webpage in conversation_commands: if ConversationCommand.Webpage in conversation_commands:
try: try:
async for result in read_webpages(defiltered_query, meta_log, location, partial(send_event, "status")): async for result in read_webpages(
if isinstance(result, dict) and "status" in result: defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS)
yield result["status"] ):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else: else:
direct_web_pages = result direct_web_pages = result
webpages = [] webpages = []
@@ -829,7 +832,7 @@ async def chat(
for webpage in direct_web_pages[query]["webpages"]: for webpage in direct_web_pages[query]["webpages"]:
webpages.append(webpage["link"]) webpages.append(webpage["link"])
async for result in send_event("status", f"**📚 Read web pages**: {webpages}"): async for result in send_event(ChatEvent.STATUS, f"**📚 Read web pages**: {webpages}"):
yield result yield result
except ValueError as e: except ValueError as e:
logger.warning( logger.warning(
@@ -839,7 +842,7 @@ async def chat(
## Send Gathered References ## Send Gathered References
async for result in send_event( async for result in send_event(
"references", ChatEvent.REFERENCES,
{ {
"inferredQueries": inferred_queries, "inferredQueries": inferred_queries,
"context": compiled_references, "context": compiled_references,
@@ -858,10 +861,10 @@ async def chat(
location_data=location, location_data=location,
references=compiled_references, references=compiled_references,
online_results=online_results, online_results=online_results,
send_status_func=partial(send_event, "status"), send_status_func=partial(send_event, ChatEvent.STATUS),
): ):
if isinstance(result, dict) and "status" in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result["status"] yield result[ChatEvent.STATUS]
else: else:
image, status_code, improved_image_prompt, intent_type = result image, status_code, improved_image_prompt, intent_type = result
@@ -899,7 +902,7 @@ async def chat(
return return
## Generate Text Output ## Generate Text Output
async for result in send_event("status", f"**💭 Generating a well-informed response**"): async for result in send_event(ChatEvent.STATUS, f"**💭 Generating a well-informed response**"):
yield result yield result
llm_response, chat_metadata = await agenerate_chat_response( llm_response, chat_metadata = await agenerate_chat_response(
defiltered_query, defiltered_query,
@@ -917,21 +920,21 @@ async def chat(
) )
# Send Response # Send Response
async for result in send_event("start_llm_response", ""): async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
yield result yield result
continue_stream = True continue_stream = True
iterator = AsyncIteratorWrapper(llm_response) iterator = AsyncIteratorWrapper(llm_response)
async for item in iterator: async for item in iterator:
if item is None: if item is None:
async for result in send_event("end_llm_response", ""): async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
yield result yield result
logger.debug("Finished streaming response") logger.debug("Finished streaming response")
return return
if not connection_alive or not continue_stream: if not connection_alive or not continue_stream:
continue continue
try: try:
async for result in send_event("message", f"{item}"): async for result in send_event(ChatEvent.MESSAGE, f"{item}"):
yield result yield result
except Exception as e: except Exception as e:
continue_stream = False continue_stream = False
@@ -949,7 +952,7 @@ async def chat(
async for item in iterator: async for item in iterator:
try: try:
item_json = json.loads(item) item_json = json.loads(item)
if "type" in item_json and item_json["type"] == "references": if "type" in item_json and item_json["type"] == ChatEvent.REFERENCES.value:
response_obj = item_json["data"] response_obj = item_json["data"]
except: except:
actual_response += item actual_response += item

View File

@@ -8,6 +8,7 @@ import math
import re import re
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from enum import Enum
from functools import partial from functools import partial
from random import random from random import random
from typing import ( from typing import (
@@ -782,7 +783,7 @@ async def text_to_image(
if send_status_func: if send_status_func:
async for event in send_status_func("**✍🏽 Enhancing the Painting Prompt**"): async for event in send_status_func("**✍🏽 Enhancing the Painting Prompt**"):
yield {"status": event} yield {ChatEvent.STATUS: event}
improved_image_prompt = await generate_better_image_prompt( improved_image_prompt = await generate_better_image_prompt(
message, message,
chat_history, chat_history,
@@ -794,7 +795,7 @@ async def text_to_image(
if send_status_func: if send_status_func:
async for event in send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}"): async for event in send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}"):
yield {"status": event} yield {ChatEvent.STATUS: event}
if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI: if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
with timer("Generate image with OpenAI", logger): with timer("Generate image with OpenAI", logger):
@@ -1191,3 +1192,11 @@ def construct_automation_created_message(automation: Job, crontime: str, query_t
Manage your automations [here](/automations). Manage your automations [here](/automations).
""".strip() """.strip()
class ChatEvent(Enum):
    """Types of events emitted over the chat streaming API.

    The enum values are the string tags serialized into the JSON events
    sent to clients, replacing the previously hard-coded string literals.
    """

    # Marks the beginning of the LLM-generated response stream
    START_LLM_RESPONSE = "start_llm_response"
    # Marks the end of the LLM-generated response stream
    END_LLM_RESPONSE = "end_llm_response"
    # A chunk of the LLM response message itself
    MESSAGE = "message"
    # References (inferred queries, document context, online results) backing the response
    REFERENCES = "references"
    # Intermediate status updates shown while the response is being prepared
    STATUS = "status"