diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py index 1f8a5c9e..c087de70 100644 --- a/src/khoj/processor/tools/online_search.py +++ b/src/khoj/processor/tools/online_search.py @@ -11,6 +11,7 @@ from bs4 import BeautifulSoup from markdownify import markdownify from khoj.routers.helpers import ( + ChatEvent, extract_relevant_info, generate_online_subqueries, infer_webpage_urls, @@ -68,7 +69,7 @@ async def search_online( if send_status_func: subqueries_str = "\n- " + "\n- ".join(list(subqueries)) async for event in send_status_func(f"**🌐 Searching the Internet for**: {subqueries_str}"): - yield {"status": event} + yield {ChatEvent.STATUS: event} with timer(f"Internet searches for {list(subqueries)} took", logger): search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina @@ -92,7 +93,7 @@ async def search_online( if send_status_func: webpage_links_str = "\n- " + "\n- ".join(list(webpage_links)) async for event in send_status_func(f"**📖 Reading web pages**: {webpage_links_str}"): - yield {"status": event} + yield {ChatEvent.STATUS: event} tasks = [read_webpage_and_extract_content(subquery, link, content) for link, subquery, content in webpages] results = await asyncio.gather(*tasks) @@ -131,14 +132,14 @@ async def read_webpages( logger.info(f"Inferring web pages to read") if send_status_func: async for event in send_status_func(f"**🧐 Inferring web pages to read**"): - yield {"status": event} + yield {ChatEvent.STATUS: event} urls = await infer_webpage_urls(query, conversation_history, location) logger.info(f"Reading web pages at: {urls}") if send_status_func: webpage_links_str = "\n- " + "\n- ".join(list(urls)) async for event in send_status_func(f"**📖 Reading web pages**: {webpage_links_str}"): - yield {"status": event} + yield {ChatEvent.STATUS: event} tasks = [read_webpage_and_extract_content(query, url) for url in urls] results = await asyncio.gather(*tasks) diff --git a/src/khoj/routers/api.py 
b/src/khoj/routers/api.py index 836b963f..81599dd6 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -36,6 +36,7 @@ from khoj.processor.conversation.openai.gpt import extract_questions from khoj.processor.conversation.openai.whisper import transcribe_audio from khoj.routers.helpers import ( ApiUserRateLimiter, + ChatEvent, CommonQueryParams, ConversationCommandRateLimiter, acreate_title_from_query, @@ -375,7 +376,7 @@ async def extract_references_and_questions( if send_status_func: inferred_queries_str = "\n- " + "\n- ".join(inferred_queries) async for event in send_status_func(f"**🔍 Searching Documents for:** {inferred_queries_str}"): - yield {"status": event} + yield {ChatEvent.STATUS: event} for query in inferred_queries: n_items = min(n, 3) if using_offline_chat else n search_results.extend( diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 5e1cb1a8..63529b8e 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -30,6 +30,7 @@ from khoj.processor.tools.online_search import read_webpages, search_online from khoj.routers.api import extract_references_and_questions from khoj.routers.helpers import ( ApiUserRateLimiter, + ChatEvent, CommonQueryParams, ConversationCommandRateLimiter, agenerate_chat_response, @@ -551,24 +552,24 @@ async def chat( event_delimiter = "␃🔚␗" q = unquote(q) - async def send_event(event_type: str, data: str | dict): + async def send_event(event_type: ChatEvent, data: str | dict): nonlocal connection_alive, ttft if not connection_alive or await request.is_disconnected(): connection_alive = False logger.warn(f"User {user} disconnected from {common.client} client") return try: - if event_type == "end_llm_response": + if event_type == ChatEvent.END_LLM_RESPONSE: collect_telemetry() - if event_type == "start_llm_response": + if event_type == ChatEvent.START_LLM_RESPONSE: ttft = time.perf_counter() - start_time - if event_type == "message": + if event_type == 
ChatEvent.MESSAGE: yield data - elif event_type == "references" or stream: - yield json.dumps({"type": event_type, "data": data}, ensure_ascii=False) - except asyncio.CancelledError: + elif event_type == ChatEvent.REFERENCES or stream: + yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False) + except asyncio.CancelledError as e: connection_alive = False - logger.warn(f"User {user} disconnected from {common.client} client") + logger.warn(f"User {user} disconnected from {common.client} client: {e}") return except Exception as e: connection_alive = False @@ -579,11 +580,11 @@ async def chat( yield event_delimiter async def send_llm_response(response: str): - async for result in send_event("start_llm_response", ""): + async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""): yield result - async for result in send_event("message", response): + async for result in send_event(ChatEvent.MESSAGE, response): yield result - async for result in send_event("end_llm_response", ""): + async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""): yield result def collect_telemetry(): @@ -632,7 +633,7 @@ async def chat( user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") conversation_commands = [get_conversation_command(query=q, any_references=True)] - async for result in send_event("status", f"**👀 Understanding Query**: {q}"): + async for result in send_event(ChatEvent.STATUS, f"**👀 Understanding Query**: {q}"): yield result meta_log = conversation.conversation_log @@ -642,12 +643,12 @@ async def chat( conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task) conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands]) async for result in send_event( - "status", f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}" + ChatEvent.STATUS, f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}" ): yield result mode = await aget_relevant_output_modes(q, 
meta_log, is_automated_task) - async for result in send_event("status", f"**🧑🏾‍💻 Decided Response Mode:** {mode.value}"): + async for result in send_event(ChatEvent.STATUS, f"**🧑🏾‍💻 Decided Response Mode:** {mode.value}"): yield result if mode not in conversation_commands: conversation_commands.append(mode) @@ -690,7 +691,7 @@ async def chat( if not q: q = "Create a general summary of the file" async for result in send_event( - "status", f"**🧑🏾‍💻 Constructing Summary Using:** {file_object[0].file_name}" + ChatEvent.STATUS, f"**🧑🏾‍💻 Constructing Summary Using:** {file_object[0].file_name}" ): yield result @@ -771,10 +772,10 @@ async def chat( conversation_id, conversation_commands, location, - partial(send_event, "status"), + partial(send_event, ChatEvent.STATUS), ): - if isinstance(result, dict) and "status" in result: - yield result["status"] + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] else: compiled_references.extend(result[0]) inferred_queries.extend(result[1]) @@ -782,7 +783,7 @@ async def chat( if not is_none_or_empty(compiled_references): headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references])) - async for result in send_event("status", f"**📜 Found Relevant Notes**: {headings}"): + async for result in send_event(ChatEvent.STATUS, f"**📜 Found Relevant Notes**: {headings}"): yield result online_results: Dict = dict() @@ -799,10 +800,10 @@ async def chat( if ConversationCommand.Online in conversation_commands: try: async for result in search_online( - defiltered_query, meta_log, location, partial(send_event, "status"), custom_filters + defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS), custom_filters ): - if isinstance(result, dict) and "status" in result: - yield result["status"] + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] else: online_results = result except ValueError as e: @@ -815,9 
+816,11 @@ async def chat( ## Gather Webpage References if ConversationCommand.Webpage in conversation_commands: try: - async for result in read_webpages(defiltered_query, meta_log, location, partial(send_event, "status")): - if isinstance(result, dict) and "status" in result: - yield result["status"] + async for result in read_webpages( + defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS) + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] else: direct_web_pages = result webpages = [] @@ -829,7 +832,7 @@ async def chat( for webpage in direct_web_pages[query]["webpages"]: webpages.append(webpage["link"]) - async for result in send_event("status", f"**📚 Read web pages**: {webpages}"): + async for result in send_event(ChatEvent.STATUS, f"**📚 Read web pages**: {webpages}"): yield result except ValueError as e: logger.warning( @@ -839,7 +842,7 @@ async def chat( ## Send Gathered References async for result in send_event( - "references", + ChatEvent.REFERENCES, { "inferredQueries": inferred_queries, "context": compiled_references, @@ -858,10 +861,10 @@ async def chat( location_data=location, references=compiled_references, online_results=online_results, - send_status_func=partial(send_event, "status"), + send_status_func=partial(send_event, ChatEvent.STATUS), ): - if isinstance(result, dict) and "status" in result: - yield result["status"] + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] else: image, status_code, improved_image_prompt, intent_type = result @@ -899,7 +902,7 @@ async def chat( return ## Generate Text Output - async for result in send_event("status", f"**💭 Generating a well-informed response**"): + async for result in send_event(ChatEvent.STATUS, f"**💭 Generating a well-informed response**"): yield result llm_response, chat_metadata = await agenerate_chat_response( defiltered_query, @@ -917,21 +920,21 @@ async def chat( ) # Send Response - 
async for result in send_event("start_llm_response", ""): + async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""): yield result continue_stream = True iterator = AsyncIteratorWrapper(llm_response) async for item in iterator: if item is None: - async for result in send_event("end_llm_response", ""): + async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""): yield result logger.debug("Finished streaming response") return if not connection_alive or not continue_stream: continue try: - async for result in send_event("message", f"{item}"): + async for result in send_event(ChatEvent.MESSAGE, f"{item}"): yield result except Exception as e: continue_stream = False @@ -949,7 +952,7 @@ async def chat( async for item in iterator: try: item_json = json.loads(item) - if "type" in item_json and item_json["type"] == "references": + if "type" in item_json and item_json["type"] == ChatEvent.REFERENCES.value: response_obj = item_json["data"] except: actual_response += item diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 7b8af5d9..538b571b 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -8,6 +8,7 @@ import math import re from concurrent.futures import ThreadPoolExecutor from datetime import datetime, timedelta, timezone +from enum import Enum from functools import partial from random import random from typing import ( @@ -782,7 +783,7 @@ async def text_to_image( if send_status_func: async for event in send_status_func("**✍🏽 Enhancing the Painting Prompt**"): - yield {"status": event} + yield {ChatEvent.STATUS: event} improved_image_prompt = await generate_better_image_prompt( message, chat_history, @@ -794,7 +795,7 @@ async def text_to_image( if send_status_func: async for event in send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}"): - yield {"status": event} + yield {ChatEvent.STATUS: event} if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI: with 
timer("Generate image with OpenAI", logger):
@@ -1191,3 +1192,11 @@ def construct_automation_created_message(automation: Job, crontime: str, query_t
 Manage your automations [here](/automations).
 """.strip()
+
+
+class ChatEvent(Enum):  # Event types for the chat response stream; `.value` is the string sent as the JSON "type" field.
+    START_LLM_RESPONSE = "start_llm_response"  # Sent before the response text; the chat API records `ttft` when it sees this.
+    END_LLM_RESPONSE = "end_llm_response"  # Sent after the response text; triggers telemetry collection in the chat API.
+    MESSAGE = "message"  # Response text; the chat API yields its data as-is, without the JSON envelope.
+    REFERENCES = "references"  # Gathered references payload; JSON-wrapped by the chat API even when not streaming.
+    STATUS = "status"  # Progress update; also the dict key used when generators yield status events.