mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-10 13:26:13 +00:00
Use enum to track chat stream event types in chat api router
This commit is contained in:
@@ -11,6 +11,7 @@ from bs4 import BeautifulSoup
|
|||||||
from markdownify import markdownify
|
from markdownify import markdownify
|
||||||
|
|
||||||
from khoj.routers.helpers import (
|
from khoj.routers.helpers import (
|
||||||
|
ChatEvent,
|
||||||
extract_relevant_info,
|
extract_relevant_info,
|
||||||
generate_online_subqueries,
|
generate_online_subqueries,
|
||||||
infer_webpage_urls,
|
infer_webpage_urls,
|
||||||
@@ -68,7 +69,7 @@ async def search_online(
|
|||||||
if send_status_func:
|
if send_status_func:
|
||||||
subqueries_str = "\n- " + "\n- ".join(list(subqueries))
|
subqueries_str = "\n- " + "\n- ".join(list(subqueries))
|
||||||
async for event in send_status_func(f"**🌐 Searching the Internet for**: {subqueries_str}"):
|
async for event in send_status_func(f"**🌐 Searching the Internet for**: {subqueries_str}"):
|
||||||
yield {"status": event}
|
yield {ChatEvent.STATUS: event}
|
||||||
|
|
||||||
with timer(f"Internet searches for {list(subqueries)} took", logger):
|
with timer(f"Internet searches for {list(subqueries)} took", logger):
|
||||||
search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina
|
search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina
|
||||||
@@ -92,7 +93,7 @@ async def search_online(
|
|||||||
if send_status_func:
|
if send_status_func:
|
||||||
webpage_links_str = "\n- " + "\n- ".join(list(webpage_links))
|
webpage_links_str = "\n- " + "\n- ".join(list(webpage_links))
|
||||||
async for event in send_status_func(f"**📖 Reading web pages**: {webpage_links_str}"):
|
async for event in send_status_func(f"**📖 Reading web pages**: {webpage_links_str}"):
|
||||||
yield {"status": event}
|
yield {ChatEvent.STATUS: event}
|
||||||
tasks = [read_webpage_and_extract_content(subquery, link, content) for link, subquery, content in webpages]
|
tasks = [read_webpage_and_extract_content(subquery, link, content) for link, subquery, content in webpages]
|
||||||
results = await asyncio.gather(*tasks)
|
results = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
@@ -131,14 +132,14 @@ async def read_webpages(
|
|||||||
logger.info(f"Inferring web pages to read")
|
logger.info(f"Inferring web pages to read")
|
||||||
if send_status_func:
|
if send_status_func:
|
||||||
async for event in send_status_func(f"**🧐 Inferring web pages to read**"):
|
async for event in send_status_func(f"**🧐 Inferring web pages to read**"):
|
||||||
yield {"status": event}
|
yield {ChatEvent.STATUS: event}
|
||||||
urls = await infer_webpage_urls(query, conversation_history, location)
|
urls = await infer_webpage_urls(query, conversation_history, location)
|
||||||
|
|
||||||
logger.info(f"Reading web pages at: {urls}")
|
logger.info(f"Reading web pages at: {urls}")
|
||||||
if send_status_func:
|
if send_status_func:
|
||||||
webpage_links_str = "\n- " + "\n- ".join(list(urls))
|
webpage_links_str = "\n- " + "\n- ".join(list(urls))
|
||||||
async for event in send_status_func(f"**📖 Reading web pages**: {webpage_links_str}"):
|
async for event in send_status_func(f"**📖 Reading web pages**: {webpage_links_str}"):
|
||||||
yield {"status": event}
|
yield {ChatEvent.STATUS: event}
|
||||||
tasks = [read_webpage_and_extract_content(query, url) for url in urls]
|
tasks = [read_webpage_and_extract_content(query, url) for url in urls]
|
||||||
results = await asyncio.gather(*tasks)
|
results = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ from khoj.processor.conversation.openai.gpt import extract_questions
|
|||||||
from khoj.processor.conversation.openai.whisper import transcribe_audio
|
from khoj.processor.conversation.openai.whisper import transcribe_audio
|
||||||
from khoj.routers.helpers import (
|
from khoj.routers.helpers import (
|
||||||
ApiUserRateLimiter,
|
ApiUserRateLimiter,
|
||||||
|
ChatEvent,
|
||||||
CommonQueryParams,
|
CommonQueryParams,
|
||||||
ConversationCommandRateLimiter,
|
ConversationCommandRateLimiter,
|
||||||
acreate_title_from_query,
|
acreate_title_from_query,
|
||||||
@@ -375,7 +376,7 @@ async def extract_references_and_questions(
|
|||||||
if send_status_func:
|
if send_status_func:
|
||||||
inferred_queries_str = "\n- " + "\n- ".join(inferred_queries)
|
inferred_queries_str = "\n- " + "\n- ".join(inferred_queries)
|
||||||
async for event in send_status_func(f"**🔍 Searching Documents for:** {inferred_queries_str}"):
|
async for event in send_status_func(f"**🔍 Searching Documents for:** {inferred_queries_str}"):
|
||||||
yield {"status": event}
|
yield {ChatEvent.STATUS: event}
|
||||||
for query in inferred_queries:
|
for query in inferred_queries:
|
||||||
n_items = min(n, 3) if using_offline_chat else n
|
n_items = min(n, 3) if using_offline_chat else n
|
||||||
search_results.extend(
|
search_results.extend(
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from khoj.processor.tools.online_search import read_webpages, search_online
|
|||||||
from khoj.routers.api import extract_references_and_questions
|
from khoj.routers.api import extract_references_and_questions
|
||||||
from khoj.routers.helpers import (
|
from khoj.routers.helpers import (
|
||||||
ApiUserRateLimiter,
|
ApiUserRateLimiter,
|
||||||
|
ChatEvent,
|
||||||
CommonQueryParams,
|
CommonQueryParams,
|
||||||
ConversationCommandRateLimiter,
|
ConversationCommandRateLimiter,
|
||||||
agenerate_chat_response,
|
agenerate_chat_response,
|
||||||
@@ -551,24 +552,24 @@ async def chat(
|
|||||||
event_delimiter = "␃🔚␗"
|
event_delimiter = "␃🔚␗"
|
||||||
q = unquote(q)
|
q = unquote(q)
|
||||||
|
|
||||||
async def send_event(event_type: str, data: str | dict):
|
async def send_event(event_type: ChatEvent, data: str | dict):
|
||||||
nonlocal connection_alive, ttft
|
nonlocal connection_alive, ttft
|
||||||
if not connection_alive or await request.is_disconnected():
|
if not connection_alive or await request.is_disconnected():
|
||||||
connection_alive = False
|
connection_alive = False
|
||||||
logger.warn(f"User {user} disconnected from {common.client} client")
|
logger.warn(f"User {user} disconnected from {common.client} client")
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
if event_type == "end_llm_response":
|
if event_type == ChatEvent.END_LLM_RESPONSE:
|
||||||
collect_telemetry()
|
collect_telemetry()
|
||||||
if event_type == "start_llm_response":
|
if event_type == ChatEvent.START_LLM_RESPONSE:
|
||||||
ttft = time.perf_counter() - start_time
|
ttft = time.perf_counter() - start_time
|
||||||
if event_type == "message":
|
if event_type == ChatEvent.MESSAGE:
|
||||||
yield data
|
yield data
|
||||||
elif event_type == "references" or stream:
|
elif event_type == ChatEvent.REFERENCES or stream:
|
||||||
yield json.dumps({"type": event_type, "data": data}, ensure_ascii=False)
|
yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False)
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError as e:
|
||||||
connection_alive = False
|
connection_alive = False
|
||||||
logger.warn(f"User {user} disconnected from {common.client} client")
|
logger.warn(f"User {user} disconnected from {common.client} client: {e}")
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
connection_alive = False
|
connection_alive = False
|
||||||
@@ -579,11 +580,11 @@ async def chat(
|
|||||||
yield event_delimiter
|
yield event_delimiter
|
||||||
|
|
||||||
async def send_llm_response(response: str):
|
async def send_llm_response(response: str):
|
||||||
async for result in send_event("start_llm_response", ""):
|
async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
|
||||||
yield result
|
yield result
|
||||||
async for result in send_event("message", response):
|
async for result in send_event(ChatEvent.MESSAGE, response):
|
||||||
yield result
|
yield result
|
||||||
async for result in send_event("end_llm_response", ""):
|
async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
def collect_telemetry():
|
def collect_telemetry():
|
||||||
@@ -632,7 +633,7 @@ async def chat(
|
|||||||
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
conversation_commands = [get_conversation_command(query=q, any_references=True)]
|
conversation_commands = [get_conversation_command(query=q, any_references=True)]
|
||||||
|
|
||||||
async for result in send_event("status", f"**👀 Understanding Query**: {q}"):
|
async for result in send_event(ChatEvent.STATUS, f"**👀 Understanding Query**: {q}"):
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
meta_log = conversation.conversation_log
|
meta_log = conversation.conversation_log
|
||||||
@@ -642,12 +643,12 @@ async def chat(
|
|||||||
conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
|
conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
|
||||||
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
|
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
|
||||||
async for result in send_event(
|
async for result in send_event(
|
||||||
"status", f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}"
|
ChatEvent.STATUS, f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}"
|
||||||
):
|
):
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task)
|
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task)
|
||||||
async for result in send_event("status", f"**🧑🏾💻 Decided Response Mode:** {mode.value}"):
|
async for result in send_event(ChatEvent.STATUS, f"**🧑🏾💻 Decided Response Mode:** {mode.value}"):
|
||||||
yield result
|
yield result
|
||||||
if mode not in conversation_commands:
|
if mode not in conversation_commands:
|
||||||
conversation_commands.append(mode)
|
conversation_commands.append(mode)
|
||||||
@@ -690,7 +691,7 @@ async def chat(
|
|||||||
if not q:
|
if not q:
|
||||||
q = "Create a general summary of the file"
|
q = "Create a general summary of the file"
|
||||||
async for result in send_event(
|
async for result in send_event(
|
||||||
"status", f"**🧑🏾💻 Constructing Summary Using:** {file_object[0].file_name}"
|
ChatEvent.STATUS, f"**🧑🏾💻 Constructing Summary Using:** {file_object[0].file_name}"
|
||||||
):
|
):
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
@@ -771,10 +772,10 @@ async def chat(
|
|||||||
conversation_id,
|
conversation_id,
|
||||||
conversation_commands,
|
conversation_commands,
|
||||||
location,
|
location,
|
||||||
partial(send_event, "status"),
|
partial(send_event, ChatEvent.STATUS),
|
||||||
):
|
):
|
||||||
if isinstance(result, dict) and "status" in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
yield result["status"]
|
yield result[ChatEvent.STATUS]
|
||||||
else:
|
else:
|
||||||
compiled_references.extend(result[0])
|
compiled_references.extend(result[0])
|
||||||
inferred_queries.extend(result[1])
|
inferred_queries.extend(result[1])
|
||||||
@@ -782,7 +783,7 @@ async def chat(
|
|||||||
|
|
||||||
if not is_none_or_empty(compiled_references):
|
if not is_none_or_empty(compiled_references):
|
||||||
headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
|
headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
|
||||||
async for result in send_event("status", f"**📜 Found Relevant Notes**: {headings}"):
|
async for result in send_event(ChatEvent.STATUS, f"**📜 Found Relevant Notes**: {headings}"):
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
online_results: Dict = dict()
|
online_results: Dict = dict()
|
||||||
@@ -799,10 +800,10 @@ async def chat(
|
|||||||
if ConversationCommand.Online in conversation_commands:
|
if ConversationCommand.Online in conversation_commands:
|
||||||
try:
|
try:
|
||||||
async for result in search_online(
|
async for result in search_online(
|
||||||
defiltered_query, meta_log, location, partial(send_event, "status"), custom_filters
|
defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS), custom_filters
|
||||||
):
|
):
|
||||||
if isinstance(result, dict) and "status" in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
yield result["status"]
|
yield result[ChatEvent.STATUS]
|
||||||
else:
|
else:
|
||||||
online_results = result
|
online_results = result
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
@@ -815,9 +816,11 @@ async def chat(
|
|||||||
## Gather Webpage References
|
## Gather Webpage References
|
||||||
if ConversationCommand.Webpage in conversation_commands:
|
if ConversationCommand.Webpage in conversation_commands:
|
||||||
try:
|
try:
|
||||||
async for result in read_webpages(defiltered_query, meta_log, location, partial(send_event, "status")):
|
async for result in read_webpages(
|
||||||
if isinstance(result, dict) and "status" in result:
|
defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS)
|
||||||
yield result["status"]
|
):
|
||||||
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
|
yield result[ChatEvent.STATUS]
|
||||||
else:
|
else:
|
||||||
direct_web_pages = result
|
direct_web_pages = result
|
||||||
webpages = []
|
webpages = []
|
||||||
@@ -829,7 +832,7 @@ async def chat(
|
|||||||
|
|
||||||
for webpage in direct_web_pages[query]["webpages"]:
|
for webpage in direct_web_pages[query]["webpages"]:
|
||||||
webpages.append(webpage["link"])
|
webpages.append(webpage["link"])
|
||||||
async for result in send_event("status", f"**📚 Read web pages**: {webpages}"):
|
async for result in send_event(ChatEvent.STATUS, f"**📚 Read web pages**: {webpages}"):
|
||||||
yield result
|
yield result
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -839,7 +842,7 @@ async def chat(
|
|||||||
|
|
||||||
## Send Gathered References
|
## Send Gathered References
|
||||||
async for result in send_event(
|
async for result in send_event(
|
||||||
"references",
|
ChatEvent.REFERENCES,
|
||||||
{
|
{
|
||||||
"inferredQueries": inferred_queries,
|
"inferredQueries": inferred_queries,
|
||||||
"context": compiled_references,
|
"context": compiled_references,
|
||||||
@@ -858,10 +861,10 @@ async def chat(
|
|||||||
location_data=location,
|
location_data=location,
|
||||||
references=compiled_references,
|
references=compiled_references,
|
||||||
online_results=online_results,
|
online_results=online_results,
|
||||||
send_status_func=partial(send_event, "status"),
|
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||||
):
|
):
|
||||||
if isinstance(result, dict) and "status" in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
yield result["status"]
|
yield result[ChatEvent.STATUS]
|
||||||
else:
|
else:
|
||||||
image, status_code, improved_image_prompt, intent_type = result
|
image, status_code, improved_image_prompt, intent_type = result
|
||||||
|
|
||||||
@@ -899,7 +902,7 @@ async def chat(
|
|||||||
return
|
return
|
||||||
|
|
||||||
## Generate Text Output
|
## Generate Text Output
|
||||||
async for result in send_event("status", f"**💭 Generating a well-informed response**"):
|
async for result in send_event(ChatEvent.STATUS, f"**💭 Generating a well-informed response**"):
|
||||||
yield result
|
yield result
|
||||||
llm_response, chat_metadata = await agenerate_chat_response(
|
llm_response, chat_metadata = await agenerate_chat_response(
|
||||||
defiltered_query,
|
defiltered_query,
|
||||||
@@ -917,21 +920,21 @@ async def chat(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Send Response
|
# Send Response
|
||||||
async for result in send_event("start_llm_response", ""):
|
async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
continue_stream = True
|
continue_stream = True
|
||||||
iterator = AsyncIteratorWrapper(llm_response)
|
iterator = AsyncIteratorWrapper(llm_response)
|
||||||
async for item in iterator:
|
async for item in iterator:
|
||||||
if item is None:
|
if item is None:
|
||||||
async for result in send_event("end_llm_response", ""):
|
async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
|
||||||
yield result
|
yield result
|
||||||
logger.debug("Finished streaming response")
|
logger.debug("Finished streaming response")
|
||||||
return
|
return
|
||||||
if not connection_alive or not continue_stream:
|
if not connection_alive or not continue_stream:
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
async for result in send_event("message", f"{item}"):
|
async for result in send_event(ChatEvent.MESSAGE, f"{item}"):
|
||||||
yield result
|
yield result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
continue_stream = False
|
continue_stream = False
|
||||||
@@ -949,7 +952,7 @@ async def chat(
|
|||||||
async for item in iterator:
|
async for item in iterator:
|
||||||
try:
|
try:
|
||||||
item_json = json.loads(item)
|
item_json = json.loads(item)
|
||||||
if "type" in item_json and item_json["type"] == "references":
|
if "type" in item_json and item_json["type"] == ChatEvent.REFERENCES.value:
|
||||||
response_obj = item_json["data"]
|
response_obj = item_json["data"]
|
||||||
except:
|
except:
|
||||||
actual_response += item
|
actual_response += item
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import math
|
|||||||
import re
|
import re
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from enum import Enum
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from random import random
|
from random import random
|
||||||
from typing import (
|
from typing import (
|
||||||
@@ -782,7 +783,7 @@ async def text_to_image(
|
|||||||
|
|
||||||
if send_status_func:
|
if send_status_func:
|
||||||
async for event in send_status_func("**✍🏽 Enhancing the Painting Prompt**"):
|
async for event in send_status_func("**✍🏽 Enhancing the Painting Prompt**"):
|
||||||
yield {"status": event}
|
yield {ChatEvent.STATUS: event}
|
||||||
improved_image_prompt = await generate_better_image_prompt(
|
improved_image_prompt = await generate_better_image_prompt(
|
||||||
message,
|
message,
|
||||||
chat_history,
|
chat_history,
|
||||||
@@ -794,7 +795,7 @@ async def text_to_image(
|
|||||||
|
|
||||||
if send_status_func:
|
if send_status_func:
|
||||||
async for event in send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}"):
|
async for event in send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}"):
|
||||||
yield {"status": event}
|
yield {ChatEvent.STATUS: event}
|
||||||
|
|
||||||
if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
|
if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
|
||||||
with timer("Generate image with OpenAI", logger):
|
with timer("Generate image with OpenAI", logger):
|
||||||
@@ -1191,3 +1192,11 @@ def construct_automation_created_message(automation: Job, crontime: str, query_t
|
|||||||
|
|
||||||
Manage your automations [here](/automations).
|
Manage your automations [here](/automations).
|
||||||
""".strip()
|
""".strip()
|
||||||
|
|
||||||
|
|
||||||
|
class ChatEvent(Enum):
|
||||||
|
START_LLM_RESPONSE = "start_llm_response"
|
||||||
|
END_LLM_RESPONSE = "end_llm_response"
|
||||||
|
MESSAGE = "message"
|
||||||
|
REFERENCES = "references"
|
||||||
|
STATUS = "status"
|
||||||
|
|||||||
Reference in New Issue
Block a user