From 38dd85c91fd53c216a90acee432e816cf0082912 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Wed, 18 Jun 2025 16:45:09 -0700 Subject: [PATCH 1/6] Add websocket chat api endpoint to ease bi-directional communication --- src/khoj/main.py | 12 +- src/khoj/routers/api_chat.py | 1543 ++++++++++++++++++---------------- src/khoj/routers/helpers.py | 92 +- 3 files changed, 939 insertions(+), 708 deletions(-) diff --git a/src/khoj/main.py b/src/khoj/main.py index 76895891..50da2624 100644 --- a/src/khoj/main.py +++ b/src/khoj/main.py @@ -220,7 +220,16 @@ def set_state(args): def start_server(app, host=None, port=None, socket=None): logger.info("🌖 Khoj is ready to engage") if socket: - uvicorn.run(app, proxy_headers=True, uds=socket, log_level="debug", use_colors=True, log_config=None) + uvicorn.run( + app, + proxy_headers=True, + uds=socket, + log_level="debug" if state.verbose > 1 else "info", + use_colors=True, + log_config=None, + ws_ping_timeout=300, + timeout_keep_alive=60, + ) else: uvicorn.run( app, @@ -229,6 +238,7 @@ def start_server(app, host=None, port=None, socket=None): log_level="debug" if state.verbose > 1 else "info", use_colors=True, log_config=None, + ws_ping_timeout=300, timeout_keep_alive=60, **state.ssl_config if state.ssl_config else {}, ) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index b1ea3ece..04b50bfd 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -10,9 +10,18 @@ from typing import Any, Dict, List, Optional from urllib.parse import unquote from asgiref.sync import sync_to_async -from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi import ( + APIRouter, + Depends, + HTTPException, + Request, + WebSocket, + WebSocketDisconnect, +) from fastapi.responses import RedirectResponse, Response, StreamingResponse +from fastapi.websockets import WebSocketState from starlette.authentication import has_required_scope, requires +from starlette.requests import Headers from 
khoj.app.settings import ALLOWED_HOSTS from khoj.database.adapters import ( @@ -657,19 +666,12 @@ def delete_message(request: Request, delete_request: DeleteMessageRequestBody) - return Response(content=json.dumps({"status": "error", "message": "Message not found"}), status_code=404) -@api_chat.post("") -@requires(["authenticated"]) -async def chat( - request: Request, - common: CommonQueryParams, +async def event_generator( body: ChatRequestBody, - rate_limiter_per_minute=Depends( - ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute") - ), - rate_limiter_per_day=Depends( - ApiUserRateLimiter(requests=100, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day") - ), - image_rate_limiter=Depends(ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)), + user_scope: Any, + common: CommonQueryParams, + headers: Headers, + request_obj: Request | WebSocket, ): # Access the parameters from the body q = body.q @@ -688,65 +690,62 @@ async def chat( raw_query_files = body.files interrupt_flag = body.interrupt - async def event_generator(q: str, images: list[str]): - start_time = time.perf_counter() - ttft = None - chat_metadata: dict = {} - conversation = None - user: KhojUser = request.user.object - is_subscribed = has_required_scope(request, ["premium"]) - q = unquote(q) - train_of_thought = [] - nonlocal conversation_id - nonlocal raw_query_files - cancellation_event = asyncio.Event() + start_time = time.perf_counter() + ttft = None + chat_metadata: dict = {} + conversation = None + user: KhojUser = user_scope.object + is_subscribed = has_required_scope(request_obj, ["premium"]) + q = unquote(q) + train_of_thought = [] + cancellation_event = asyncio.Event() - tracer: dict = { - "mid": turn_id, - "cid": conversation_id, - "uid": user.id, - "khoj_version": state.khoj_version, - } + tracer: dict = { + "mid": turn_id, + "cid": conversation_id, + "uid": user.id, + "khoj_version": state.khoj_version, + } - uploaded_images: 
list[str] = [] - if images: - for image in images: - decoded_string = unquote(image) - base64_data = decoded_string.split(",", 1)[1] - image_bytes = base64.b64decode(base64_data) - webp_image_bytes = convert_image_to_webp(image_bytes) - uploaded_image = upload_user_image_to_bucket(webp_image_bytes, request.user.object.id) - if not uploaded_image: - base64_webp_image = base64.b64encode(webp_image_bytes).decode("utf-8") - uploaded_image = f"data:image/webp;base64,{base64_webp_image}" - uploaded_images.append(uploaded_image) + uploaded_images: list[str] = [] + if raw_images: + for image in raw_images: + decoded_string = unquote(image) + base64_data = decoded_string.split(",", 1)[1] + image_bytes = base64.b64decode(base64_data) + webp_image_bytes = convert_image_to_webp(image_bytes) + uploaded_image = upload_user_image_to_bucket(webp_image_bytes, user.id) + if not uploaded_image: + base64_webp_image = base64.b64encode(webp_image_bytes).decode("utf-8") + uploaded_image = f"data:image/webp;base64,{base64_webp_image}" + uploaded_images.append(uploaded_image) - query_files: Dict[str, str] = {} - if raw_query_files: - for file in raw_query_files: - query_files[file.name] = file.content + query_files: Dict[str, str] = {} + if raw_query_files: + for file in raw_query_files: + query_files[file.name] = file.content - research_results: List[ResearchIteration] = [] - online_results: Dict = dict() - code_results: Dict = dict() - operator_results: List[OperatorRun] = [] - compiled_references: List[Any] = [] - inferred_queries: List[Any] = [] - attached_file_context = gather_raw_query_files(query_files) + research_results: List[ResearchIteration] = [] + online_results: Dict = dict() + code_results: Dict = dict() + operator_results: List[OperatorRun] = [] + compiled_references: List[Any] = [] + inferred_queries: List[Any] = [] + attached_file_context = gather_raw_query_files(query_files) - generated_images: List[str] = [] - generated_files: List[FileAttachment] = [] - 
generated_mermaidjs_diagram: str = None - generated_asset_results: Dict = dict() - program_execution_context: List[str] = [] - chat_history: List[ChatMessageModel] = [] + generated_images: List[str] = [] + generated_files: List[FileAttachment] = [] + generated_mermaidjs_diagram: str = None + generated_asset_results: Dict = dict() + program_execution_context: List[str] = [] - # Create a task to monitor for disconnections - disconnect_monitor_task = None + # Create a task to monitor for disconnections + disconnect_monitor_task = None - async def monitor_disconnection(): + async def monitor_disconnection(): + if isinstance(request_obj, Request): try: - msg = await request.receive() + msg = await request_obj.receive() if msg["type"] == "http.disconnect": logger.debug(f"Request cancelled. User {user} disconnected from {common.client} client.") cancellation_event.set() @@ -765,7 +764,7 @@ async def chat( operator_results=operator_results, research_results=research_results, inferred_queries=inferred_queries, - client_application=request.user.client_app, + client_application=user_scope.client_app, conversation_id=conversation_id, query_images=uploaded_images, train_of_thought=train_of_thought, @@ -778,683 +777,817 @@ async def chat( ) except Exception as e: logger.error(f"Error in disconnect monitor: {e}") + elif isinstance(request_obj, WebSocket): + while request_obj.client_state == WebSocketState.CONNECTED: + await asyncio.sleep(1) - # Cancel the disconnect monitor task if it is still running - async def cancel_disconnect_monitor(): - if disconnect_monitor_task and not disconnect_monitor_task.done(): - logger.debug(f"Cancelling disconnect monitor task for user {user}") - disconnect_monitor_task.cancel() - try: - await disconnect_monitor_task - except asyncio.CancelledError: - pass - - async def send_event(event_type: ChatEvent, data: str | dict): - nonlocal ttft, train_of_thought - event_delimiter = "␃🔚␗" - if cancellation_event.is_set(): - return - try: - if event_type 
== ChatEvent.END_LLM_RESPONSE: - collect_telemetry() - elif event_type == ChatEvent.START_LLM_RESPONSE: - ttft = time.perf_counter() - start_time - elif event_type == ChatEvent.STATUS: - train_of_thought.append({"type": event_type.value, "data": data}) - elif event_type == ChatEvent.THOUGHT: - # Append the data to the last thought as thoughts are streamed - if ( - len(train_of_thought) > 0 - and train_of_thought[-1]["type"] == ChatEvent.THOUGHT.value - and type(train_of_thought[-1]["data"]) == type(data) == str - ): - train_of_thought[-1]["data"] += data - else: - train_of_thought.append({"type": event_type.value, "data": data}) - - if event_type == ChatEvent.MESSAGE: - yield data - elif event_type == ChatEvent.REFERENCES or ChatEvent.METADATA or stream: - yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False) - except Exception as e: - if not cancellation_event.is_set(): - logger.error( - f"Failed to stream chat API response to {user} on {common.client}: {e}.", - exc_info=True, + logger.debug(f"WebSocket disconnected. 
User {user} from {common.client} client.") + cancellation_event.set() + if conversation: + await asyncio.shield( + save_to_conversation_log( + q, + chat_response="", + user=user, + chat_history=chat_history, + compiled_references=compiled_references, + online_results=online_results, + code_results=code_results, + operator_results=operator_results, + research_results=research_results, + inferred_queries=inferred_queries, + client_application=user_scope.client_app, + conversation_id=conversation_id, + query_images=uploaded_images, + train_of_thought=train_of_thought, + raw_query_files=raw_query_files, + generated_images=generated_images, + raw_generated_files=generated_asset_results, + generated_mermaidjs_diagram=generated_mermaidjs_diagram, + tracer=tracer, ) - finally: - if not cancellation_event.is_set(): - yield event_delimiter - # Cancel the disconnect monitor task if it is still running - if cancellation_event.is_set() or event_type == ChatEvent.END_RESPONSE: - await cancel_disconnect_monitor() - - async def send_llm_response(response: str, usage: dict = None): - # Check if the client is still connected - if cancellation_event.is_set(): - return - # Send Chat Response - async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""): - yield result - async for result in send_event(ChatEvent.MESSAGE, response): - yield result - async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""): - yield result - # Send Usage Metadata once llm interactions are complete - if usage: - async for event in send_event(ChatEvent.USAGE, usage): - yield event - async for result in send_event(ChatEvent.END_RESPONSE, ""): - yield result - - def collect_telemetry(): - # Gather chat response telemetry - nonlocal chat_metadata - latency = time.perf_counter() - start_time - cmd_set = set([cmd.value for cmd in conversation_commands]) - cost = (tracer.get("usage", {}) or {}).get("cost", 0) - chat_metadata = chat_metadata or {} - chat_metadata["conversation_command"] = cmd_set - 
chat_metadata["agent"] = conversation.agent.slug if conversation and conversation.agent else None - chat_metadata["cost"] = f"{cost:.5f}" - chat_metadata["latency"] = f"{latency:.3f}" - if ttft: - chat_metadata["ttft_latency"] = f"{ttft:.3f}" - logger.info(f"Chat response time to first token: {ttft:.3f} seconds") - logger.info(f"Chat response total time: {latency:.3f} seconds") - logger.info(f"Chat response cost: ${cost:.5f}") - update_telemetry_state( - request=request, - telemetry_type="api", - api="chat", - client=common.client, - user_agent=request.headers.get("user-agent"), - host=request.headers.get("host"), - metadata=chat_metadata, - ) - - # Start the disconnect monitor in the background - disconnect_monitor_task = asyncio.create_task(monitor_disconnection()) - - if is_query_empty(q): - async for result in send_llm_response("Please ask your query to get started.", tracer.get("usage")): - yield result - return - - # Automated tasks are handled before to allow mixing them with other conversation commands - cmds_to_rate_limit = [] - is_automated_task = False - if q.startswith("/automated_task"): - is_automated_task = True - q = q.replace("/automated_task", "").lstrip() - cmds_to_rate_limit += [ConversationCommand.AutomatedTask] - - # Extract conversation command from query - conversation_commands = [get_conversation_command(query=q)] - - conversation = await ConversationAdapters.aget_conversation_by_user( - user, - client_application=request.user.client_app, - conversation_id=conversation_id, - title=title, - create_new=body.create_new, - ) - if not conversation: - async for result in send_llm_response(f"Conversation {conversation_id} not found", tracer.get("usage")): - yield result - return - conversation_id = str(conversation.id) - - async for event in send_event(ChatEvent.METADATA, {"conversationId": conversation_id, "turnId": turn_id}): - yield event - - agent: Agent | None = None - default_agent = await AgentAdapters.aget_default_agent() - if 
conversation.agent and conversation.agent != default_agent: - agent = conversation.agent - - if not conversation.agent: - conversation.agent = default_agent - await conversation.asave() - agent = default_agent - - await is_ready_to_chat(user) - user_name = await aget_user_name(user) - location = None - if city or region or country or country_code: - location = LocationData(city=city, region=region, country=country, country_code=country_code) - user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - chat_history = conversation.messages - - # If interrupt flag is set, wait for the previous turn to be saved before proceeding - if interrupt_flag: - max_wait_time = 20.0 # seconds - wait_interval = 0.3 # seconds - wait_start = wait_current = time.time() - while wait_current - wait_start < max_wait_time: - # Refresh conversation to check if interrupted message saved to DB - conversation = await ConversationAdapters.aget_conversation_by_user( - user, - client_application=request.user.client_app, - conversation_id=conversation_id, ) + + # Cancel the disconnect monitor task if it is still running + async def cancel_disconnect_monitor(): + if disconnect_monitor_task and not disconnect_monitor_task.done(): + logger.debug(f"Cancelling disconnect monitor task for user {user}") + disconnect_monitor_task.cancel() + try: + await disconnect_monitor_task + except asyncio.CancelledError: + pass + + async def send_event(event_type: ChatEvent, data: str | dict): + nonlocal ttft, train_of_thought + event_delimiter = "␃🔚␗" + if cancellation_event.is_set(): + return + try: + if event_type == ChatEvent.END_LLM_RESPONSE: + collect_telemetry() + elif event_type == ChatEvent.START_LLM_RESPONSE: + ttft = time.perf_counter() - start_time + elif event_type == ChatEvent.STATUS: + train_of_thought.append({"type": event_type.value, "data": data}) + elif event_type == ChatEvent.THOUGHT: + # Append the data to the last thought as thoughts are streamed if ( - conversation - and 
conversation.messages - and conversation.messages[-1].by == "khoj" - and not conversation.messages[-1].message + len(train_of_thought) > 0 + and train_of_thought[-1]["type"] == ChatEvent.THOUGHT.value + and type(train_of_thought[-1]["data"]) == type(data) == str ): - logger.info(f"Detected interrupted message save to conversation {conversation_id}.") - break - await asyncio.sleep(wait_interval) - wait_current = time.time() - - if wait_current - wait_start >= max_wait_time: - logger.warning( - f"Timeout waiting to load interrupted context from conversation {conversation_id}. Proceed without previous context." - ) - - # If interrupted message in DB - if ( - conversation - and conversation.messages - and conversation.messages[-1].by == "khoj" - and not conversation.messages[-1].message - ): - # Populate context from interrupted message - last_message = conversation.messages[-1] - online_results = {key: val.model_dump() for key, val in last_message.onlineContext.items() or []} - code_results = {key: val.model_dump() for key, val in last_message.codeContext.items() or []} - compiled_references = [ref.model_dump() for ref in last_message.context or []] - research_results = [ - ResearchIteration(**iter_dict) - for iter_dict in last_message.researchContext or [] - if iter_dict.get("summarizedResult") - ] - operator_results = [OperatorRun(**iter_dict) for iter_dict in last_message.operatorContext or []] - train_of_thought = [thought.model_dump() for thought in last_message.trainOfThought or []] - # Drop the interrupted message from conversation history - chat_history.pop() - logger.info(f"Loaded interrupted partial context from conversation {conversation_id}.") - - if conversation_commands == [ConversationCommand.Default]: - try: - chosen_io = await aget_data_sources_and_output_format( - q, - chat_history, - is_automated_task, - user=user, - query_images=uploaded_images, - agent=agent, - query_files=attached_file_context, - tracer=tracer, - ) - except ValueError as e: - 
logger.error(f"Error getting data sources and output format: {e}. Falling back to default.") - conversation_commands = [ConversationCommand.General] - - conversation_commands = chosen_io.get("sources") + [chosen_io.get("output")] - - # If we're doing research, we don't want to do anything else - if ConversationCommand.Research in conversation_commands: - conversation_commands = [ConversationCommand.Research] - - conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands]) - async for result in send_event(ChatEvent.STATUS, f"**Selected Tools:** {conversation_commands_str}"): - yield result - - cmds_to_rate_limit += conversation_commands - for cmd in cmds_to_rate_limit: - try: - await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd) - q = q.replace(f"/{cmd.value}", "").strip() - except HTTPException as e: - async for result in send_llm_response(str(e.detail), tracer.get("usage")): - yield result - return - - defiltered_query = defilter_query(q) - file_filters = conversation.file_filters if conversation and conversation.file_filters else [] - - if conversation_commands == [ConversationCommand.Research]: - async for research_result in research( - user=user, - query=defiltered_query, - conversation_id=conversation_id, - conversation_history=chat_history, - previous_iterations=list(research_results), - query_images=uploaded_images, - agent=agent, - send_status_func=partial(send_event, ChatEvent.STATUS), - user_name=user_name, - location=location, - file_filters=file_filters, - query_files=attached_file_context, - tracer=tracer, - cancellation_event=cancellation_event, - ): - if isinstance(research_result, ResearchIteration): - if research_result.summarizedResult: - if research_result.onlineContext: - online_results.update(research_result.onlineContext) - if research_result.codeContext: - code_results.update(research_result.codeContext) - if research_result.context: - compiled_references.extend(research_result.context) - if 
not research_results or research_results[-1] is not research_result: - research_results.append(research_result) + train_of_thought[-1]["data"] += data else: - yield research_result + train_of_thought.append({"type": event_type.value, "data": data}) - # Track operator results across research and operator iterations - # This relies on two conditions: - # 1. Check to append new (partial) operator results - # Relies on triggering this check on every status updates. - # Status updates cascade up from operator to research to chat api on every step. - # 2. Keep operator results in sync with each research operator step - # Relies on python object references to ensure operator results - # are implicitly kept in sync after the initial append - if ( - research_results - and research_results[-1].operatorContext - and (not operator_results or operator_results[-1] is not research_results[-1].operatorContext) - ): - operator_results.append(research_results[-1].operatorContext) - - # researched_results = await extract_relevant_info(q, researched_results, agent) - if state.verbose > 1: - logger.debug(f'Researched Results: {"".join(r.summarizedResult or "" for r in research_results)}') - - # Gather Context - ## Extract Document References - if not ConversationCommand.Research in conversation_commands: - try: - async for result in search_documents( - q, - (n or 7), - d, - user, - chat_history, - conversation_id, - conversation_commands, - location, - partial(send_event, ChatEvent.STATUS), - query_images=uploaded_images, - agent=agent, - query_files=attached_file_context, - tracer=tracer, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - compiled_references.extend(result[0]) - inferred_queries.extend(result[1]) - defiltered_query = result[2] - except Exception as e: - error_message = ( - f"Error searching knowledge base: {e}. Attempting to respond without document references." 
- ) - logger.error(error_message, exc_info=True) - async for result in send_event( - ChatEvent.STATUS, "Document search failed. I'll try respond without document references" - ): - yield result - - if not is_none_or_empty(compiled_references): - headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references])) - # Strip only leading # from headings - headings = headings.replace("#", "") - async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"): - yield result - - if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user): - async for result in send_llm_response(f"{no_entries_found.format()}", tracer.get("usage")): - yield result - return - - if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references): - conversation_commands.remove(ConversationCommand.Notes) - - ## Gather Online References - if ConversationCommand.Online in conversation_commands: - try: - async for result in search_online( - defiltered_query, - chat_history, - location, - user, - partial(send_event, ChatEvent.STATUS), - custom_filters=[], - max_online_searches=3, - query_images=uploaded_images, - query_files=attached_file_context, - agent=agent, - tracer=tracer, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - online_results = result - except Exception as e: - error_message = f"Error searching online: {e}. Attempting to respond without online results" - logger.warning(error_message) - async for result in send_event( - ChatEvent.STATUS, "Online search failed. 
I'll try respond without online references" - ): - yield result - - ## Gather Webpage References - if ConversationCommand.Webpage in conversation_commands: - try: - async for result in read_webpages( - defiltered_query, - chat_history, - location, - user, - partial(send_event, ChatEvent.STATUS), - max_webpages_to_read=1, - query_images=uploaded_images, - agent=agent, - query_files=attached_file_context, - tracer=tracer, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - direct_web_pages = result - webpages = [] - for query in direct_web_pages: - if online_results.get(query): - online_results[query]["webpages"] = direct_web_pages[query]["webpages"] - else: - online_results[query] = {"webpages": direct_web_pages[query]["webpages"]} - - for webpage in direct_web_pages[query]["webpages"]: - webpages.append(webpage["link"]) - async for result in send_event(ChatEvent.STATUS, f"**Read web pages**: {webpages}"): - yield result - except Exception as e: - logger.warning( - f"Error reading webpages: {e}. Attempting to respond without webpage results", + if event_type == ChatEvent.MESSAGE: + yield data + elif event_type == ChatEvent.REFERENCES or ChatEvent.METADATA or stream: + yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False) + except Exception as e: + if not cancellation_event.is_set(): + logger.error( + f"Failed to stream chat API response to {user} on {common.client}: {e}.", exc_info=True, ) - async for result in send_event( - ChatEvent.STATUS, "Webpage read failed. 
I'll try respond without webpage references" - ): - yield result + finally: + if not cancellation_event.is_set(): + yield event_delimiter + # Cancel the disconnect monitor task if it is still running + if cancellation_event.is_set() or event_type == ChatEvent.END_RESPONSE: + await cancel_disconnect_monitor() - ## Gather Code Results - if ConversationCommand.Code in conversation_commands: - try: - context = f"# Iteration 1:\n#---\nNotes:\n{compiled_references}\n\nOnline Results:{online_results}" - async for result in run_code( - defiltered_query, - chat_history, - context, - location, - user, - partial(send_event, ChatEvent.STATUS), - query_images=uploaded_images, - agent=agent, - query_files=attached_file_context, - tracer=tracer, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - code_results = result - except ValueError as e: - program_execution_context.append(f"Failed to run code") - logger.warning( - f"Failed to use code tool: {e}. 
Attempting to respond without code results", - exc_info=True, - ) - if ConversationCommand.Operator in conversation_commands: - try: - async for result in operate_environment( - defiltered_query, - user, - chat_history, - location, - list(operator_results)[-1] if operator_results else None, - query_images=uploaded_images, - query_files=attached_file_context, - send_status_func=partial(send_event, ChatEvent.STATUS), - agent=agent, - cancellation_event=cancellation_event, - tracer=tracer, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - elif isinstance(result, OperatorRun): - if not operator_results or operator_results[-1] is not result: - operator_results.append(result) - # Add webpages visited while operating browser to references - if result.webpages: - if not online_results.get(defiltered_query): - online_results[defiltered_query] = {"webpages": result.webpages} - elif not online_results[defiltered_query].get("webpages"): - online_results[defiltered_query]["webpages"] = result.webpages - else: - online_results[defiltered_query]["webpages"] += result.webpages - except ValueError as e: - program_execution_context.append(f"Browser operation error: {e}") - logger.warning(f"Failed to operate browser with {e}", exc_info=True) - async for result in send_event( - ChatEvent.STATUS, "Operating browser failed. 
I'll try respond appropriately" - ): - yield result - - ## Send Gathered References - unique_online_results = deduplicate_organic_results(online_results) - async for result in send_event( - ChatEvent.REFERENCES, - { - "inferredQueries": inferred_queries, - "context": compiled_references, - "onlineContext": unique_online_results, - "codeContext": code_results, - }, - ): + async def send_llm_response(response: str, usage: dict = None): + # Check if the client is still connected + if cancellation_event.is_set(): + return + # Send Chat Response + async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""): + yield result + async for result in send_event(ChatEvent.MESSAGE, response): + yield result + async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""): + yield result + # Send Usage Metadata once llm interactions are complete + if usage: + async for event in send_event(ChatEvent.USAGE, usage): + yield event + async for result in send_event(ChatEvent.END_RESPONSE, ""): yield result - # Generate Output - ## Generate Image Output - if ConversationCommand.Image in conversation_commands: - async for result in text_to_image( + def collect_telemetry(): + # Gather chat response telemetry + nonlocal chat_metadata + latency = time.perf_counter() - start_time + cmd_set = set([cmd.value for cmd in conversation_commands]) + cost = (tracer.get("usage", {}) or {}).get("cost", 0) + chat_metadata = chat_metadata or {} + chat_metadata["conversation_command"] = cmd_set + chat_metadata["agent"] = conversation.agent.slug if conversation and conversation.agent else None + chat_metadata["cost"] = f"{cost:.5f}" + chat_metadata["latency"] = f"{latency:.3f}" + if ttft: + chat_metadata["ttft_latency"] = f"{ttft:.3f}" + logger.info(f"Chat response time to first token: {ttft:.3f} seconds") + logger.info(f"Chat response total time: {latency:.3f} seconds") + logger.info(f"Chat response cost: ${cost:.5f}") + update_telemetry_state( + request=request_obj, + telemetry_type="api", + 
api="chat", + client=common.client, + user_agent=headers.get("user-agent"), + host=headers.get("host"), + metadata=chat_metadata, + ) + + # Start the disconnect monitor in the background + disconnect_monitor_task = asyncio.create_task(monitor_disconnection()) + + if is_query_empty(q): + async for result in send_llm_response("Please ask your query to get started.", tracer.get("usage")): + yield result + return + + # Automated tasks are handled before to allow mixing them with other conversation commands + cmds_to_rate_limit = [] + is_automated_task = False + if q.startswith("/automated_task"): + is_automated_task = True + q = q.replace("/automated_task", "").lstrip() + cmds_to_rate_limit += [ConversationCommand.AutomatedTask] + + # Extract conversation command from query + conversation_commands = [get_conversation_command(query=q)] + + conversation = await ConversationAdapters.aget_conversation_by_user( + user, + client_application=user_scope.client_app, + conversation_id=conversation_id, + title=title, + create_new=body.create_new, + ) + if not conversation: + async for result in send_llm_response(f"Conversation {conversation_id} not found", tracer.get("usage")): + yield result + return + conversation_id = str(conversation.id) + + async for event in send_event(ChatEvent.METADATA, {"conversationId": conversation_id, "turnId": turn_id}): + yield event + + agent: Agent | None = None + default_agent = await AgentAdapters.aget_default_agent() + if conversation.agent and conversation.agent != default_agent: + agent = conversation.agent + + if not conversation.agent: + conversation.agent = default_agent + await conversation.asave() + agent = default_agent + + await is_ready_to_chat(user) + user_name = await aget_user_name(user) + location = None + if city or region or country or country_code: + location = LocationData(city=city, region=region, country=country, country_code=country_code) + user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + chat_history = 
conversation.messages + + # If interrupt flag is set, wait for the previous turn to be saved before proceeding + if interrupt_flag: + max_wait_time = 20.0 # seconds + wait_interval = 0.3 # seconds + wait_start = wait_current = time.time() + while wait_current - wait_start < max_wait_time: + # Refresh conversation to check if interrupted message saved to DB + conversation = await ConversationAdapters.aget_conversation_by_user( + user, + client_application=user_scope.client_app, + conversation_id=conversation_id, + ) + if ( + conversation + and conversation.messages + and conversation.messages[-1].by == "khoj" + and not conversation.messages[-1].message + ): + logger.info(f"Detected interrupted message save to conversation {conversation_id}.") + break + await asyncio.sleep(wait_interval) + wait_current = time.time() + + if wait_current - wait_start >= max_wait_time: + logger.warning( + f"Timeout waiting to load interrupted context from conversation {conversation_id}. Proceed without previous context." 
+ ) + + # If interrupted message in DB + if ( + conversation + and conversation.messages + and conversation.messages[-1].by == "khoj" + and not conversation.messages[-1].message + ): + # Populate context from interrupted message + last_message = conversation.messages[-1] + online_results = {key: val.model_dump() for key, val in last_message.onlineContext.items() or []} + code_results = {key: val.model_dump() for key, val in last_message.codeContext.items() or []} + compiled_references = [ref.model_dump() for ref in last_message.context or []] + research_results = [ + ResearchIteration(**iter_dict) + for iter_dict in last_message.researchContext or [] + if iter_dict.get("summarizedResult") + ] + operator_results = [OperatorRun(**iter_dict) for iter_dict in last_message.operatorContext or []] + train_of_thought = [thought.model_dump() for thought in last_message.trainOfThought or []] + # Drop the interrupted message from conversation history + chat_history.pop() + logger.info(f"Loaded interrupted partial context from conversation {conversation_id}.") + + if conversation_commands == [ConversationCommand.Default]: + try: + chosen_io = await aget_data_sources_and_output_format( + q, + chat_history, + is_automated_task, + user=user, + query_images=uploaded_images, + agent=agent, + query_files=attached_file_context, + tracer=tracer, + ) + except ValueError as e: + logger.error(f"Error getting data sources and output format: {e}. 
Falling back to default.") + conversation_commands = [ConversationCommand.General] + + conversation_commands = chosen_io.get("sources") + [chosen_io.get("output")] + + # If we're doing research, we don't want to do anything else + if ConversationCommand.Research in conversation_commands: + conversation_commands = [ConversationCommand.Research] + + conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands]) + async for result in send_event(ChatEvent.STATUS, f"**Selected Tools:** {conversation_commands_str}"): + yield result + + cmds_to_rate_limit += conversation_commands + for cmd in cmds_to_rate_limit: + try: + await conversation_command_rate_limiter.update_and_check_if_valid(request_obj, cmd) + q = q.replace(f"/{cmd.value}", "").strip() + except HTTPException as e: + async for result in send_llm_response(str(e.detail), tracer.get("usage")): + yield result + return + + defiltered_query = defilter_query(q) + file_filters = conversation.file_filters if conversation and conversation.file_filters else [] + + if conversation_commands == [ConversationCommand.Research]: + async for research_result in research( + user=user, + query=defiltered_query, + conversation_id=conversation_id, + conversation_history=chat_history, + previous_iterations=list(research_results), + query_images=uploaded_images, + agent=agent, + send_status_func=partial(send_event, ChatEvent.STATUS), + user_name=user_name, + location=location, + file_filters=file_filters, + query_files=attached_file_context, + tracer=tracer, + cancellation_event=cancellation_event, + ): + if isinstance(research_result, ResearchIteration): + if research_result.summarizedResult: + if research_result.onlineContext: + online_results.update(research_result.onlineContext) + if research_result.codeContext: + code_results.update(research_result.codeContext) + if research_result.context: + compiled_references.extend(research_result.context) + if not research_results or research_results[-1] is not 
research_result: + research_results.append(research_result) + else: + yield research_result + + # Track operator results across research and operator iterations + # This relies on two conditions: + # 1. Check to append new (partial) operator results + # Relies on triggering this check on every status updates. + # Status updates cascade up from operator to research to chat api on every step. + # 2. Keep operator results in sync with each research operator step + # Relies on python object references to ensure operator results + # are implicitly kept in sync after the initial append + if ( + research_results + and research_results[-1].operatorContext + and (not operator_results or operator_results[-1] is not research_results[-1].operatorContext) + ): + operator_results.append(research_results[-1].operatorContext) + + # researched_results = await extract_relevant_info(q, researched_results, agent) + if state.verbose > 1: + logger.debug(f'Researched Results: {"".join(r.summarizedResult or "" for r in research_results)}') + + # Gather Context + ## Extract Document References + if not ConversationCommand.Research in conversation_commands: + try: + async for result in search_documents( + q, + (n or 7), + d, + user, + chat_history, + conversation_id, + conversation_commands, + location, + partial(send_event, ChatEvent.STATUS), + query_images=uploaded_images, + agent=agent, + query_files=attached_file_context, + tracer=tracer, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + compiled_references.extend(result[0]) + inferred_queries.extend(result[1]) + defiltered_query = result[2] + except Exception as e: + error_message = f"Error searching knowledge base: {e}. Attempting to respond without document references." + logger.error(error_message, exc_info=True) + async for result in send_event( + ChatEvent.STATUS, "Document search failed. 
I'll try respond without document references" + ): + yield result + + if not is_none_or_empty(compiled_references): + headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references])) + # Strip only leading # from headings + headings = headings.replace("#", "") + async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"): + yield result + + if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user): + async for result in send_llm_response(f"{no_entries_found.format()}", tracer.get("usage")): + yield result + return + + if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references): + conversation_commands.remove(ConversationCommand.Notes) + + ## Gather Online References + if ConversationCommand.Online in conversation_commands: + try: + async for result in search_online( + defiltered_query, + chat_history, + location, + user, + partial(send_event, ChatEvent.STATUS), + custom_filters=[], + max_online_searches=3, + query_images=uploaded_images, + query_files=attached_file_context, + agent=agent, + tracer=tracer, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + online_results = result + except Exception as e: + error_message = f"Error searching online: {e}. Attempting to respond without online results" + logger.warning(error_message) + async for result in send_event( + ChatEvent.STATUS, "Online search failed. 
I'll try respond without online references" + ): + yield result + + ## Gather Webpage References + if ConversationCommand.Webpage in conversation_commands: + try: + async for result in read_webpages( + defiltered_query, + chat_history, + location, + user, + partial(send_event, ChatEvent.STATUS), + max_webpages_to_read=1, + query_images=uploaded_images, + agent=agent, + query_files=attached_file_context, + tracer=tracer, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + direct_web_pages = result + webpages = [] + for query in direct_web_pages: + if online_results.get(query): + online_results[query]["webpages"] = direct_web_pages[query]["webpages"] + else: + online_results[query] = {"webpages": direct_web_pages[query]["webpages"]} + + for webpage in direct_web_pages[query]["webpages"]: + webpages.append(webpage["link"]) + async for result in send_event(ChatEvent.STATUS, f"**Read web pages**: {webpages}"): + yield result + except Exception as e: + logger.warning( + f"Error reading webpages: {e}. Attempting to respond without webpage results", + exc_info=True, + ) + async for result in send_event( + ChatEvent.STATUS, "Webpage read failed. I'll try respond without webpage references" + ): + yield result + + ## Gather Code Results + if ConversationCommand.Code in conversation_commands: + try: + context = f"# Iteration 1:\n#---\nNotes:\n{compiled_references}\n\nOnline Results:{online_results}" + async for result in run_code( + defiltered_query, + chat_history, + context, + location, + user, + partial(send_event, ChatEvent.STATUS), + query_images=uploaded_images, + agent=agent, + query_files=attached_file_context, + tracer=tracer, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + code_results = result + except ValueError as e: + program_execution_context.append(f"Failed to run code") + logger.warning( + f"Failed to use code tool: {e}. 
Attempting to respond without code results", + exc_info=True, + ) + if ConversationCommand.Operator in conversation_commands: + try: + async for result in operate_environment( defiltered_query, user, chat_history, - location_data=location, - references=compiled_references, - online_results=online_results, - send_status_func=partial(send_event, ChatEvent.STATUS), + location, + list(operator_results)[-1] if operator_results else None, query_images=uploaded_images, - agent=agent, query_files=attached_file_context, + send_status_func=partial(send_event, ChatEvent.STATUS), + agent=agent, + cancellation_event=cancellation_event, tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] - else: - generated_image, status_code, improved_image_prompt = result - - inferred_queries.append(improved_image_prompt) - if generated_image is None or status_code != 200: - program_execution_context.append(f"Failed to generate image with {improved_image_prompt}") - async for result in send_event(ChatEvent.STATUS, f"Failed to generate image"): - yield result - else: - generated_images.append(generated_image) - - generated_asset_results["images"] = { - "query": improved_image_prompt, - } - - async for result in send_event( - ChatEvent.GENERATED_ASSETS, - { - "images": [generated_image], - }, - ): - yield result - - if ConversationCommand.Diagram in conversation_commands: - async for result in send_event(ChatEvent.STATUS, f"Creating diagram"): + elif isinstance(result, OperatorRun): + if not operator_results or operator_results[-1] is not result: + operator_results.append(result) + # Add webpages visited while operating browser to references + if result.webpages: + if not online_results.get(defiltered_query): + online_results[defiltered_query] = {"webpages": result.webpages} + elif not online_results[defiltered_query].get("webpages"): + online_results[defiltered_query]["webpages"] = result.webpages + else: + 
online_results[defiltered_query]["webpages"] += result.webpages + except ValueError as e: + program_execution_context.append(f"Browser operation error: {e}") + logger.warning(f"Failed to operate browser with {e}", exc_info=True) + async for result in send_event( + ChatEvent.STATUS, "Operating browser failed. I'll try respond appropriately" + ): yield result - inferred_queries = [] - diagram_description = "" + ## Send Gathered References + unique_online_results = deduplicate_organic_results(online_results) + async for result in send_event( + ChatEvent.REFERENCES, + { + "inferredQueries": inferred_queries, + "context": compiled_references, + "onlineContext": unique_online_results, + "codeContext": code_results, + }, + ): + yield result - async for result in generate_mermaidjs_diagram( - q=defiltered_query, - chat_history=chat_history, - location_data=location, - note_references=compiled_references, - online_results=online_results, - query_images=uploaded_images, - user=user, - agent=agent, - send_status_func=partial(send_event, ChatEvent.STATUS), - query_files=attached_file_context, - tracer=tracer, + # Generate Output + ## Generate Image Output + if ConversationCommand.Image in conversation_commands: + async for result in text_to_image( + defiltered_query, + user, + chat_history, + location_data=location, + references=compiled_references, + online_results=online_results, + send_status_func=partial(send_event, ChatEvent.STATUS), + query_images=uploaded_images, + agent=agent, + query_files=attached_file_context, + tracer=tracer, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + generated_image, status_code, improved_image_prompt = result + + inferred_queries.append(improved_image_prompt) + if generated_image is None or status_code != 200: + program_execution_context.append(f"Failed to generate image with {improved_image_prompt}") + async for result in send_event(ChatEvent.STATUS, f"Failed to generate image"): 
+ yield result + else: + generated_images.append(generated_image) + + generated_asset_results["images"] = { + "query": improved_image_prompt, + } + + async for result in send_event( + ChatEvent.GENERATED_ASSETS, + { + "images": [generated_image], + }, ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - better_diagram_description_prompt, mermaidjs_diagram_description = result - if better_diagram_description_prompt and mermaidjs_diagram_description: - inferred_queries.append(better_diagram_description_prompt) - diagram_description = mermaidjs_diagram_description + yield result - generated_mermaidjs_diagram = diagram_description - - generated_asset_results["diagrams"] = { - "query": better_diagram_description_prompt, - } - - async for result in send_event( - ChatEvent.GENERATED_ASSETS, - { - "mermaidjsDiagram": mermaidjs_diagram_description, - }, - ): - yield result - else: - error_message = "Failed to generate diagram. Please try again later." 
- program_execution_context.append( - prompts.failed_diagram_generation.format( - attempted_diagram=better_diagram_description_prompt - ) - ) - - async for result in send_event(ChatEvent.STATUS, error_message): - yield result - - # Check if the user has disconnected - if cancellation_event.is_set(): - logger.debug(f"Stopping LLM response to user {user} on {common.client} client.") - # Cancel the disconnect monitor task if it is still running - await cancel_disconnect_monitor() - return - - ## Generate Text Output - async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"): + if ConversationCommand.Diagram in conversation_commands: + async for result in send_event(ChatEvent.STATUS, f"Creating diagram"): yield result - llm_response, chat_metadata = await agenerate_chat_response( - defiltered_query, - chat_history, - conversation, - compiled_references, - online_results, - code_results, - operator_results, - research_results, - user, - location, - user_name, - uploaded_images, - attached_file_context, - generated_files, - program_execution_context, - generated_asset_results, - is_subscribed, - tracer, - ) + inferred_queries = [] + diagram_description = "" - full_response = "" - async for item in llm_response: - # Should not happen with async generator. Skip. - if item is None or not isinstance(item, ResponseWithThought): - logger.warning(f"Unexpected item type in LLM response: {type(item)}. 
Skipping.") - continue - if cancellation_event.is_set(): - break - message = item.text - full_response += message if message else "" - if item.thought: - async for result in send_event(ChatEvent.THOUGHT, item.thought): - yield result - continue + async for result in generate_mermaidjs_diagram( + q=defiltered_query, + chat_history=chat_history, + location_data=location, + note_references=compiled_references, + online_results=online_results, + query_images=uploaded_images, + user=user, + agent=agent, + send_status_func=partial(send_event, ChatEvent.STATUS), + query_files=attached_file_context, + tracer=tracer, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + better_diagram_description_prompt, mermaidjs_diagram_description = result + if better_diagram_description_prompt and mermaidjs_diagram_description: + inferred_queries.append(better_diagram_description_prompt) + diagram_description = mermaidjs_diagram_description - # Start sending response - async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""): - yield result + generated_mermaidjs_diagram = diagram_description - try: - async for result in send_event(ChatEvent.MESSAGE, message): - yield result - except Exception as e: - if not cancellation_event.is_set(): - logger.warning(f"Error during streaming. 
Stopping send: {e}") - break + generated_asset_results["diagrams"] = { + "query": better_diagram_description_prompt, + } - # Save conversation once finish streaming - asyncio.create_task( - save_to_conversation_log( - q, - chat_response=full_response, - user=user, - chat_history=chat_history, - compiled_references=compiled_references, - online_results=online_results, - code_results=code_results, - operator_results=operator_results, - research_results=research_results, - inferred_queries=inferred_queries, - client_application=request.user.client_app, - conversation_id=str(conversation.id), - query_images=uploaded_images, - train_of_thought=train_of_thought, - raw_query_files=raw_query_files, - generated_images=generated_images, - raw_generated_files=generated_files, - generated_mermaidjs_diagram=generated_mermaidjs_diagram, - tracer=tracer, - ) - ) + async for result in send_event( + ChatEvent.GENERATED_ASSETS, + { + "mermaidjsDiagram": mermaidjs_diagram_description, + }, + ): + yield result + else: + error_message = "Failed to generate diagram. Please try again later." 
+ program_execution_context.append( + prompts.failed_diagram_generation.format(attempted_diagram=better_diagram_description_prompt) + ) - # Signal end of LLM response after the loop finishes - if not cancellation_event.is_set(): - async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""): - yield result - # Send Usage Metadata once llm interactions are complete - if tracer.get("usage"): - async for event in send_event(ChatEvent.USAGE, tracer.get("usage")): - yield event - async for result in send_event(ChatEvent.END_RESPONSE, ""): - yield result - logger.debug("Finished streaming response") + async for result in send_event(ChatEvent.STATUS, error_message): + yield result + # Check if the user has disconnected + if cancellation_event.is_set(): + logger.debug(f"Stopping LLM response to user {user} on {common.client} client.") # Cancel the disconnect monitor task if it is still running await cancel_disconnect_monitor() + return - ## Stream Text Response - if stream: - return StreamingResponse(event_generator(q, images=raw_images), media_type="text/plain") - ## Non-Streaming Text Response + ## Generate Text Output + async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"): + yield result + + llm_response, chat_metadata = await agenerate_chat_response( + defiltered_query, + chat_history, + conversation, + compiled_references, + online_results, + code_results, + operator_results, + research_results, + user, + location, + user_name, + uploaded_images, + attached_file_context, + generated_files, + program_execution_context, + generated_asset_results, + is_subscribed, + tracer, + ) + + full_response = "" + async for item in llm_response: + # Should not happen with async generator. Skip. + if item is None or not isinstance(item, ResponseWithThought): + logger.warning(f"Unexpected item type in LLM response: {type(item)}. 
Skipping.") + continue + if cancellation_event.is_set(): + break + message = item.text + full_response += message if message else "" + if item.thought: + async for result in send_event(ChatEvent.THOUGHT, item.thought): + yield result + continue + + # Start sending response + async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""): + yield result + + try: + async for result in send_event(ChatEvent.MESSAGE, message): + yield result + except Exception as e: + if not cancellation_event.is_set(): + logger.warning(f"Error during streaming. Stopping send: {e}") + break + + # Save conversation once finish streaming + asyncio.create_task( + save_to_conversation_log( + q, + chat_response=full_response, + user=user, + chat_history=chat_history, + compiled_references=compiled_references, + online_results=online_results, + code_results=code_results, + operator_results=operator_results, + research_results=research_results, + inferred_queries=inferred_queries, + client_application=user_scope.client_app, + conversation_id=str(conversation.id), + query_images=uploaded_images, + train_of_thought=train_of_thought, + raw_query_files=raw_query_files, + generated_images=generated_images, + raw_generated_files=generated_files, + generated_mermaidjs_diagram=generated_mermaidjs_diagram, + tracer=tracer, + ) + ) + + # Signal end of LLM response after the loop finishes + if not cancellation_event.is_set(): + async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""): + yield result + # Send Usage Metadata once llm interactions are complete + if tracer.get("usage"): + async for event in send_event(ChatEvent.USAGE, tracer.get("usage")): + yield event + async for result in send_event(ChatEvent.END_RESPONSE, ""): + yield result + logger.debug("Finished streaming response") + + # Cancel the disconnect monitor task if it is still running + await cancel_disconnect_monitor() + + +@api_chat.websocket("/ws") +@requires(["authenticated"]) +async def chat_ws( + websocket: WebSocket, + 
common: CommonQueryParams, +): + await websocket.accept() + + # Initialize rate limiters + rate_limiter_per_minute = ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute") + rate_limiter_per_day = ApiUserRateLimiter( + requests=100, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day" + ) + image_rate_limiter = ApiImageRateLimiter(max_images=10, max_combined_size_mb=20) + + current_task = None + + try: + while True: + data = await websocket.receive_json() + + # Handle regular chat messages + # Handle regular chat messages - ensure data has required fields + if "q" not in data: + await websocket.send_text(json.dumps({"error": "Missing required field 'q' in chat message"})) + continue + + body = ChatRequestBody(**data) + + # Apply rate limiting manually + try: + rate_limiter_per_minute.check_websocket(websocket) + rate_limiter_per_day.check_websocket(websocket) + image_rate_limiter.check_websocket(websocket, body) + except HTTPException as e: + await websocket.send_text(json.dumps({"error": e.detail})) + continue + + # Cancel any ongoing task before starting a new one + if current_task and not current_task.done(): + current_task.cancel() + try: + await current_task + except asyncio.CancelledError: + pass + + # Create a new task for processing the chat request + current_task = asyncio.create_task(process_chat_request(websocket, body, common)) + + except WebSocketDisconnect: + logger.info(f"WebSocket disconnected for user {websocket.scope['user'].object.id}") + if current_task and not current_task.done(): + current_task.cancel() + except Exception as e: + logger.error(f"Error in websocket chat: {e}", exc_info=True) + if current_task and not current_task.done(): + current_task.cancel() + await websocket.close(code=1011, reason="Internal Server Error") + + +async def process_chat_request( + websocket: WebSocket, + body: ChatRequestBody, + common: CommonQueryParams, +): + """Process a single chat request with interrupt support""" + 
try: + # Since we are using websockets, we can ignore the stream parameter and always stream + response_iterator = event_generator( + body, + websocket.scope["user"], + common, + websocket.headers, + websocket, + ) + async for event in response_iterator: + await websocket.send_text(event) + except asyncio.CancelledError: + logger.debug(f"Chat request cancelled for user {websocket.scope['user'].object.id}") + raise + except Exception as e: + logger.error(f"Error processing chat request: {e}", exc_info=True) + await websocket.send_text(json.dumps({"error": "Internal server error"})) + raise + + +@api_chat.post("") +@requires(["authenticated"]) +async def chat( + request: Request, + common: CommonQueryParams, + body: ChatRequestBody, + rate_limiter_per_minute=Depends( + ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute") + ), + rate_limiter_per_day=Depends( + ApiUserRateLimiter(requests=100, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day") + ), + image_rate_limiter=Depends(ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)), +): + response_iterator = event_generator( + body, + request.user, + common, + request.headers, + request, + ) + + # Stream Text Response + if body.stream: + return StreamingResponse(response_iterator, media_type="text/plain") + # Non-Streaming Text Response else: - response_iterator = event_generator(q, images=raw_images) response_data = await read_chat_stream(response_iterator) return Response(content=json.dumps(response_data), media_type="application/json", status_code=200) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 48fa4252..9bfdb155 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -33,7 +33,7 @@ from apscheduler.job import Job from apscheduler.triggers.cron import CronTrigger from asgiref.sync import sync_to_async from django.utils import timezone as django_timezone -from fastapi import Depends, Header, HTTPException, 
Request, UploadFile +from fastapi import Depends, Header, HTTPException, Request, UploadFile, WebSocket from pydantic import BaseModel, EmailStr, Field from starlette.authentication import has_required_scope from starlette.requests import URL @@ -1936,6 +1936,53 @@ class ApiUserRateLimiter: # Add the current request to the cache UserRequests.objects.create(user=user, slug=self.slug) + def check_websocket(self, websocket: WebSocket): + """WebSocket-specific rate limiting method""" + # Rate limiting disabled if billing is disabled + if state.billing_enabled is False: + return + + # Rate limiting is disabled if user unauthenticated. + if not websocket.scope.get("user") or not websocket.scope["user"].is_authenticated: + return + + user: KhojUser = websocket.scope["user"].object + subscribed = has_required_scope(websocket, ["premium"]) + + # Remove requests outside of the time window + cutoff = django_timezone.now() - timedelta(seconds=self.window) + count_requests = UserRequests.objects.filter(user=user, created_at__gte=cutoff, slug=self.slug).count() + + # Check if the user has exceeded the rate limit + if subscribed and count_requests >= self.subscribed_requests: + logger.info( + f"Rate limit: {count_requests}/{self.subscribed_requests} requests not allowed in {self.window} seconds for subscribed user: {user}." + ) + raise HTTPException( + status_code=429, + detail="I'm glad you're enjoying interacting with me! You've unfortunately exceeded your usage limit for today. But let's chat more tomorrow?", + ) + if not subscribed and count_requests >= self.requests: + if self.requests >= self.subscribed_requests: + logger.info( + f"Rate limit: {count_requests}/{self.subscribed_requests} requests not allowed in {self.window} seconds for user: {user}." + ) + raise HTTPException( + status_code=429, + detail="I'm glad you're enjoying interacting with me! You've unfortunately exceeded your usage limit for today. 
But let's chat more tomorrow?", + ) + + logger.info( + f"Rate limit: {count_requests}/{self.requests} requests not allowed in {self.window} seconds for user: {user}." + ) + raise HTTPException( + status_code=429, + detail="I'm glad you're enjoying interacting with me! You've unfortunately exceeded your usage limit for today. You can subscribe to increase your usage limit via [your settings](https://app.khoj.dev/settings) or we can continue our conversation tomorrow?", + ) + + # Add the current request to the cache + UserRequests.objects.create(user=user, slug=self.slug) + class ApiImageRateLimiter: def __init__(self, max_images: int = 10, max_combined_size_mb: float = 10): @@ -1983,6 +2030,47 @@ class ApiImageRateLimiter: detail=f"Those images are way too large for me! I can handle up to {self.max_combined_size_mb}MB of images per message.", ) + def check_websocket(self, websocket: WebSocket, body: ChatRequestBody): + """WebSocket-specific image rate limiting method""" + if state.billing_enabled is False: + return + + # Rate limiting is disabled if user unauthenticated. + if not websocket.scope.get("user") or not websocket.scope["user"].is_authenticated: + return + + if not body.images: + return + + # Check number of images + if len(body.images) > self.max_images: + logger.info(f"Rate limit: {len(body.images)}/{self.max_images} images not allowed per message.") + raise HTTPException( + status_code=429, + detail=f"Those are way too many images for me! 
I can handle up to {self.max_images} images per message.", + ) + + # Check total size of images + total_size_mb = 0.0 + for image in body.images: + # Unquote the image in case it's URL encoded + image = unquote(image) + # Assuming the image is a base64 encoded string + # Remove the data:image/jpeg;base64, part if present + if "," in image: + image = image.split(",", 1)[1] + + # Decode base64 to get the actual size + image_bytes = base64.b64decode(image) + total_size_mb += len(image_bytes) / (1024 * 1024) # Convert bytes to MB + + if total_size_mb > self.max_combined_size_mb: + logger.info(f"Data limit: {total_size_mb}MB/{self.max_combined_size_mb}MB size not allowed per message.") + raise HTTPException( + status_code=429, + detail=f"Those images are way too large for me! I can handle up to {self.max_combined_size_mb}MB of images per message.", + ) + class ConversationCommandRateLimiter: def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str): @@ -1991,7 +2079,7 @@ class ConversationCommandRateLimiter: self.subscribed_rate_limit = subscribed_rate_limit self.restricted_commands = [ConversationCommand.Research] - async def update_and_check_if_valid(self, request: Request, conversation_command: ConversationCommand): + async def update_and_check_if_valid(self, request: Request | WebSocket, conversation_command: ConversationCommand): if state.billing_enabled is False: return From 9f0eff65418443d0bbdcc4478ab7d0019be0bb1e Mon Sep 17 00:00:00 2001 From: Debanjum Date: Sat, 12 Jul 2025 11:19:41 -0700 Subject: [PATCH 2/6] Handle passing interrupt messages from api to chat actors on server --- src/khoj/processor/conversation/utils.py | 1 + src/khoj/processor/operator/__init__.py | 12 +++++- src/khoj/routers/api_chat.py | 49 +++++++++--------------- src/khoj/routers/helpers.py | 11 ++++++ src/khoj/routers/research.py | 22 ++++++++++- src/khoj/utils/rawconfig.py | 1 - 6 files changed, 62 insertions(+), 34 deletions(-) diff --git 
a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 011e8045..b4dd5e9c 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -384,6 +384,7 @@ class ChatEvent(Enum): METADATA = "metadata" USAGE = "usage" END_RESPONSE = "end_response" + INTERRUPT = "interrupt" def message_to_log( diff --git a/src/khoj/processor/operator/__init__.py b/src/khoj/processor/operator/__init__.py index 9b4ad80f..c07dec90 100644 --- a/src/khoj/processor/operator/__init__.py +++ b/src/khoj/processor/operator/__init__.py @@ -7,6 +7,7 @@ from typing import Callable, List, Optional from khoj.database.adapters import AgentAdapters, ConversationAdapters from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser from khoj.processor.conversation.utils import ( + AgentMessage, OperatorRun, construct_chat_history_for_operator, ) @@ -22,7 +23,7 @@ from khoj.processor.operator.operator_environment_base import ( ) from khoj.processor.operator.operator_environment_browser import BrowserEnvironment from khoj.processor.operator.operator_environment_computer import ComputerEnvironment -from khoj.routers.helpers import ChatEvent +from khoj.routers.helpers import ChatEvent, get_message_from_queue from khoj.utils.helpers import timer from khoj.utils.rawconfig import LocationData @@ -42,6 +43,7 @@ async def operate_environment( agent: Agent = None, query_files: str = None, # TODO: Handle query files cancellation_event: Optional[asyncio.Event] = None, + interrupt_queue: Optional[asyncio.Queue] = None, tracer: dict = {}, ): response, user_input_message = None, None @@ -140,6 +142,14 @@ async def operate_environment( logger.debug(f"{environment_type.value} operator cancelled by client disconnect") break + # Add interrupt query to current operator run + if interrupt_query := get_message_from_queue(interrupt_queue): + # Add the interrupt query as a new user message to the research conversation history + 
logger.info(f"Continuing operator run with the new instruction: {interrupt_query}") + operator_agent.messages.append(AgentMessage(role="user", content=interrupt_query)) + async for result in send_status_func(f"**Incorporate New Instruction**: {interrupt_query}"): + yield result + iterations += 1 # 1. Get current environment state diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 04b50bfd..e602ab2c 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -672,6 +672,7 @@ async def event_generator( common: CommonQueryParams, headers: Headers, request_obj: Request | WebSocket, + interrupt_queue: asyncio.Queue = None, ): # Access the parameters from the body q = body.q @@ -688,7 +689,6 @@ async def event_generator( timezone = body.timezone raw_images = body.images raw_query_files = body.files - interrupt_flag = body.interrupt start_time = time.perf_counter() ttft = None @@ -955,34 +955,6 @@ async def event_generator( user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") chat_history = conversation.messages - # If interrupt flag is set, wait for the previous turn to be saved before proceeding - if interrupt_flag: - max_wait_time = 20.0 # seconds - wait_interval = 0.3 # seconds - wait_start = wait_current = time.time() - while wait_current - wait_start < max_wait_time: - # Refresh conversation to check if interrupted message saved to DB - conversation = await ConversationAdapters.aget_conversation_by_user( - user, - client_application=user_scope.client_app, - conversation_id=conversation_id, - ) - if ( - conversation - and conversation.messages - and conversation.messages[-1].by == "khoj" - and not conversation.messages[-1].message - ): - logger.info(f"Detected interrupted message save to conversation {conversation_id}.") - break - await asyncio.sleep(wait_interval) - wait_current = time.time() - - if wait_current - wait_start >= max_wait_time: - logger.warning( - f"Timeout waiting to load interrupted context 
from conversation {conversation_id}. Proceed without previous context." - ) - # If interrupted message in DB if ( conversation @@ -1061,6 +1033,7 @@ async def event_generator( query_files=attached_file_context, tracer=tracer, cancellation_event=cancellation_event, + interrupt_queue=interrupt_queue, ): if isinstance(research_result, ResearchIteration): if research_result.summarizedResult: @@ -1491,13 +1464,25 @@ async def chat_ws( ) image_rate_limiter = ApiImageRateLimiter(max_images=10, max_combined_size_mb=20) + # Shared interrupt queue for communicating interrupts to ongoing research + interrupt_queue: asyncio.Queue = asyncio.Queue() current_task = None try: while True: data = await websocket.receive_json() - # Handle regular chat messages + # Check if this is an interrupt message + if data.get("type") == "interrupt": + if current_task and not current_task.done(): + # Send interrupt signal to the ongoing task + await interrupt_queue.put(data.get("query", "")) + logger.info(f"Interrupt signal sent to ongoing task for user {websocket.scope['user'].object.id}") + await websocket.send_text(json.dumps({"type": "interrupt_acknowledged"})) + else: + logger.info(f"No ongoing task to interrupt for user {websocket.scope['user'].object.id}") + continue + # Handle regular chat messages - ensure data has required fields if "q" not in data: await websocket.send_text(json.dumps({"error": "Missing required field 'q' in chat message"})) @@ -1523,7 +1508,7 @@ async def chat_ws( pass # Create a new task for processing the chat request - current_task = asyncio.create_task(process_chat_request(websocket, body, common)) + current_task = asyncio.create_task(process_chat_request(websocket, body, common, interrupt_queue)) except WebSocketDisconnect: logger.info(f"WebSocket disconnected for user {websocket.scope['user'].object.id}") @@ -1540,6 +1525,7 @@ async def process_chat_request( websocket: WebSocket, body: ChatRequestBody, common: CommonQueryParams, + interrupt_queue: 
asyncio.Queue, ): """Process a single chat request with interrupt support""" try: @@ -1550,6 +1536,7 @@ async def process_chat_request( common, websocket.headers, websocket, + interrupt_queue, ) async for event in response_iterator: await websocket.send_text(event) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 9bfdb155..41189740 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -2600,6 +2600,17 @@ async def read_chat_stream(response_iterator: AsyncGenerator[str, None]) -> Dict } +def get_message_from_queue(queue: asyncio.Queue) -> Optional[str]: + """Get any message in queue if available.""" + if not queue: + return None + try: + # Non-blocking check for message in the queue + return queue.get_nowait() + except asyncio.QueueEmpty: + return None + + def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False): user_picture = request.session.get("user", {}).get("picture") is_active = has_required_scope(request, ["premium"]) diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 5e476d4b..050fd3f8 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -15,6 +15,7 @@ from khoj.processor.conversation.utils import ( ResearchIteration, ToolCall, construct_iteration_history, + construct_structured_message, construct_tool_chat_history, load_complex_json, ) @@ -24,6 +25,7 @@ from khoj.processor.tools.run_code import run_code from khoj.routers.helpers import ( ChatEvent, generate_summary_from_files, + get_message_from_queue, grep_files, list_files, search_documents, @@ -74,7 +76,7 @@ async def apick_next_tool( ): previous_iteration = previous_iterations[-1] yield ResearchIteration( - query=query, + query=ToolCall(name=previous_iteration.query.name, args={"query": query}, id=previous_iteration.query.id), # type: ignore context=previous_iteration.context, onlineContext=previous_iteration.onlineContext, codeContext=previous_iteration.codeContext, @@ 
-221,6 +223,7 @@ async def research( tracer: dict = {}, query_files: str = None, cancellation_event: Optional[asyncio.Event] = None, + interrupt_queue: Optional[asyncio.Queue] = None, ): max_document_searches = 7 max_online_searches = 3 @@ -241,6 +244,22 @@ async def research( logger.debug(f"Research cancelled. User {user} disconnected client.") break + # Update the query for the current research iteration + if interrupt_query := get_message_from_queue(interrupt_queue): + # Add the interrupt query as a new user message to the research conversation history + logger.info( + f"Continuing research with the previous {len(previous_iterations)} iterations and new instruction: {interrupt_query}" + ) + previous_iterations_history = construct_iteration_history( + previous_iterations, query, query_images, query_files + ) + research_conversation_history += previous_iterations_history + query = interrupt_query + previous_iterations = [] + + async for result in send_status_func(f"**Incorporate New Instruction**: {interrupt_query}"): + yield result + online_results: Dict = dict() code_results: Dict = dict() document_results: List[Dict[str, str]] = [] @@ -428,6 +447,7 @@ async def research( agent=agent, query_files=query_files, cancellation_event=cancellation_event, + interrupt_queue=interrupt_queue, tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: diff --git a/src/khoj/utils/rawconfig.py b/src/khoj/utils/rawconfig.py index e3662db5..df7f3334 100644 --- a/src/khoj/utils/rawconfig.py +++ b/src/khoj/utils/rawconfig.py @@ -168,7 +168,6 @@ class ChatRequestBody(BaseModel): images: Optional[list[str]] = None files: Optional[list[FileAttachment]] = [] create_new: Optional[bool] = False - interrupt: Optional[bool] = False class Entry: From eaed0c839e88f8bfc81197d27e442ec5121addb6 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Wed, 18 Jun 2025 16:46:10 -0700 Subject: [PATCH 3/6] Use websocket chat api endpoint to communicate from web app - Use websocket library 
to handle setup, reconnection from web app Use the react-use-websocket library to handle websocket connection and reconnection logic. Previously the connection wasn't re-established on disconnects. - Send interrupt messages with ws to update research, operator trajectory Previously we were using the abort-and-send-new POST /api/chat mechanism. But now we can use the websocket's bi-directional messaging capability to send user messages in the middle of a research or operator run. This change should 1. Allow for a faster, more interactive interruption to shift the research direction without breaking the conversation flow. As previously we were using the DB to communicate interrupts across workers, this would take time and feel sluggish in the UX. 2. Be a more robust interrupt mechanism that'll work in multi-worker setups. As the same worker is interacted with to send interrupt messages, instead of a potentially new worker receiving the POST /api/chat with the interrupt user message. On the server we're using an asyncio Queue to pass messages down from the websocket api to the researcher via the event generator. This can be extended to pass to other iterative agents like the operator. 
--- src/interface/web/app/chat/page.tsx | 287 ++++++++++-------- .../chatInputArea/chatInputArea.tsx | 8 +- src/interface/web/package.json | 1 + src/interface/web/yarn.lock | 6 + 4 files changed, 172 insertions(+), 130 deletions(-) diff --git a/src/interface/web/app/chat/page.tsx b/src/interface/web/app/chat/page.tsx index 13315d4b..8ef97b8d 100644 --- a/src/interface/web/app/chat/page.tsx +++ b/src/interface/web/app/chat/page.tsx @@ -1,7 +1,8 @@ "use client"; import styles from "./chat.module.css"; -import React, { Suspense, useEffect, useRef, useState } from "react"; +import React, { Suspense, useCallback, useEffect, useRef, useState } from "react"; +import useWebSocket from "react-use-websocket"; import ChatHistory from "../components/chatHistory/chatHistory"; import { useSearchParams } from "next/navigation"; @@ -45,7 +46,7 @@ interface ChatBodyDataProps { isMobileWidth?: boolean; isLoggedIn: boolean; setImages: (images: string[]) => void; - setTriggeredAbort: (triggeredAbort: boolean) => void; + setTriggeredAbort: (triggeredAbort: boolean, newMessage?: string) => void; isChatSideBarOpen: boolean; setIsChatSideBarOpen: (open: boolean) => void; isActive?: boolean; @@ -205,10 +206,10 @@ export default function Chat() { const [uploadedFiles, setUploadedFiles] = useState(undefined); const [images, setImages] = useState([]); - const [abortMessageStreamController, setAbortMessageStreamController] = - useState(null); const [triggeredAbort, setTriggeredAbort] = useState(false); - const [shouldSendWithInterrupt, setShouldSendWithInterrupt] = useState(false); + const [interruptMessage, setInterruptMessage] = useState(""); + const bufferRef = useRef(""); + const idleTimerRef = useRef(null); const { locationData, locationDataError, locationDataLoading } = useIPLocationData() || { locationData: { @@ -222,6 +223,107 @@ export default function Chat() { } = useAuthenticatedData(); const isMobileWidth = useIsMobileWidth(); const [isChatSideBarOpen, setIsChatSideBarOpen] = 
useState(false); + const [socketUrl, setSocketUrl] = useState(null); + + const disconnectFromServer = useCallback(() => { + if (idleTimerRef.current) { + clearTimeout(idleTimerRef.current); + } + setSocketUrl(null); + console.log("WebSocket disconnected due to inactivity."); + }, []); + + const resetIdleTimer = useCallback(() => { + const idleTimeout = 10 * 60 * 1000; // 10 minutes + if (idleTimerRef.current) { + clearTimeout(idleTimerRef.current); + } + idleTimerRef.current = setTimeout(disconnectFromServer, idleTimeout); + }, [disconnectFromServer]); + + const { sendMessage, lastMessage } = useWebSocket(socketUrl, { + share: true, + shouldReconnect: (closeEvent) => true, + reconnectAttempts: 10, + // reconnect using exponential backoff with jitter + reconnectInterval: (attemptNumber) => { + const baseDelay = 1000 * Math.pow(2, attemptNumber); + const jitter = Math.random() * 1000; // Add jitter up to 1s + return Math.min(baseDelay + jitter, 20000); // Cap backoff at 20s + }, + onOpen: () => { + console.log("WebSocket connection established."); + resetIdleTimer(); + }, + onClose: () => { + console.log("WebSocket connection closed."); + if (idleTimerRef.current) { + clearTimeout(idleTimerRef.current); + } + }, + }); + + useEffect(() => { + if (lastMessage !== null) { + resetIdleTimer(); + // Check if this is a control message (JSON) rather than a streaming event + try { + const controlMessage = JSON.parse(lastMessage.data); + if (controlMessage.type === "interrupt_acknowledged") { + console.log("Interrupt acknowledged by server"); + setSocketUrl(null); + setProcessQuerySignal(false); + return; + } + if (controlMessage.error) { + console.error("WebSocket error:", controlMessage.error); + return; + } + } catch { + // Not a JSON control message, process as streaming event + } + + const eventDelimiter = "␃🔚␗"; + bufferRef.current += lastMessage.data; + + let newEventIndex; + while ((newEventIndex = bufferRef.current.indexOf(eventDelimiter)) !== -1) { + const eventChunk 
= bufferRef.current.slice(0, newEventIndex); + bufferRef.current = bufferRef.current.slice(newEventIndex + eventDelimiter.length); + if (eventChunk) { + setMessages((prevMessages) => { + const newMessages = [...prevMessages]; + const currentMessage = newMessages[newMessages.length - 1]; + if (!currentMessage || currentMessage.completed) { + return prevMessages; + } + + const { context, onlineContext, codeContext } = processMessageChunk( + eventChunk, + currentMessage, + currentMessage.context || [], + currentMessage.onlineContext || {}, + currentMessage.codeContext || {}, + ); + + // Update the current message with the new reference data + currentMessage.context = context; + currentMessage.onlineContext = onlineContext; + currentMessage.codeContext = codeContext; + + if (currentMessage.completed) { + setQueryToProcess(""); + setProcessQuerySignal(false); + setImages([]); + if (conversationId) generateNewTitle(conversationId, setTitle); + } + + return newMessages; + }); + } + } + } + }, [lastMessage, setMessages]); useEffect(() => { fetch("/api/chat/options") @@ -241,14 +343,41 @@ export default function Chat() { welcomeConsole(); }, []); + const handleTriggeredAbort = (value: boolean, newMessage?: string) => { + if (value) { + setInterruptMessage(newMessage || ""); + } + setTriggeredAbort(value); + }; + useEffect(() => { if (triggeredAbort) { - abortMessageStreamController?.abort(); - handleAbortedMessage(); - setShouldSendWithInterrupt(true); - setTriggeredAbort(false); + sendMessage( + JSON.stringify({ + type: "interrupt", + query: interruptMessage, + }), + ); + console.log("Sent interrupt message via WebSocket:", interruptMessage); + + // Update the current message with the new query but keep it in processing state + const messageToProcess = interruptMessage || queryToProcess; + setMessages((prevMessages) => { + const newMessages = [...prevMessages]; + const currentMessage = newMessages[newMessages.length - 1]; + if (currentMessage && !currentMessage.completed) 
{ + currentMessage.rawQuery = messageToProcess; + currentMessage.completed = !!interruptMessage; + } + return newMessages; + }); + + // Update the query being processed + setQueryToProcess(messageToProcess); + setTriggeredAbort(!!interruptMessage); + setInterruptMessage(""); } - }, [triggeredAbort]); + }, [triggeredAbort, interruptMessage, queryToProcess, sendMessage]); useEffect(() => { if (queryToProcess) { @@ -266,7 +395,6 @@ export default function Chat() { }; setMessages((prevMessages) => [...prevMessages, newStreamMessage]); setProcessQuerySignal(true); - setAbortMessageStreamController(new AbortController()); } }, [queryToProcess]); @@ -280,70 +408,19 @@ export default function Chat() { } }, [processQuerySignal, locationDataLoading]); - async function readChatStream(response: Response) { - if (!response.ok) throw new Error(response.statusText); - if (!response.body) throw new Error("Response body is null"); + useEffect(() => { + if (!conversationId) return; - const reader = response.body.getReader(); - const decoder = new TextDecoder(); - const eventDelimiter = "␃🔚␗"; - let buffer = ""; + const protocol = window.location.protocol === "https:" ? 
"wss:" : "ws:"; + const wsUrl = `${protocol}//${window.location.host}/api/chat/ws?client=web`; + setSocketUrl(wsUrl); - // Track context used for chat response - let context: Context[] = []; - let onlineContext: OnlineContext = {}; - let codeContext: CodeContext = {}; - - while (true) { - const { done, value } = await reader.read(); - if (done) { - setQueryToProcess(""); - setProcessQuerySignal(false); - setImages([]); - - if (conversationId) generateNewTitle(conversationId, setTitle); - - break; + return () => { + if (idleTimerRef.current) { + clearTimeout(idleTimerRef.current); } - - const chunk = decoder.decode(value, { stream: true }); - buffer += chunk; - - let newEventIndex; - while ((newEventIndex = buffer.indexOf(eventDelimiter)) !== -1) { - const event = buffer.slice(0, newEventIndex); - buffer = buffer.slice(newEventIndex + eventDelimiter.length); - if (event) { - const currentMessage = messages.find((message) => !message.completed); - - if (!currentMessage) { - console.error("No current message found"); - return; - } - - // Track context used for chat response. References are rendered at the end of the chat - ({ context, onlineContext, codeContext } = processMessageChunk( - event, - currentMessage, - context, - onlineContext, - codeContext, - )); - - setMessages([...messages]); - } - } - } - } - - function handleAbortedMessage() { - const currentMessage = messages.find((message) => !message.completed); - if (!currentMessage) return; - - currentMessage.completed = true; - setMessages([...messages]); - setProcessQuerySignal(false); - } + }; + }, [conversationId]); async function chat() { localStorage.removeItem("message"); @@ -351,12 +428,19 @@ export default function Chat() { setProcessQuerySignal(false); return; } - const chatAPI = "/api/chat?client=web"; + + // Re-establish WebSocket connection if disconnected + resetIdleTimer(); + if (!socketUrl) { + const protocol = window.location.protocol === "https:" ? 
"wss:" : "ws:"; + const wsUrl = `${protocol}//${window.location.host}/api/chat/ws?client=web`; + setSocketUrl(wsUrl); + } + const chatAPIBody = { q: queryToProcess, conversation_id: conversationId, stream: true, - interrupt: shouldSendWithInterrupt, ...(locationData && { city: locationData.city, region: locationData.region, @@ -368,58 +452,7 @@ export default function Chat() { ...(uploadedFiles && { files: uploadedFiles }), }; - // Reset the flag after using it - setShouldSendWithInterrupt(false); - - const response = await fetch(chatAPI, { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify(chatAPIBody), - signal: abortMessageStreamController?.signal, - }); - - try { - await readChatStream(response); - } catch (err) { - let apiError; - try { - apiError = await response.json(); - } catch (err) { - // Error reading API error response - apiError = { - streamError: "Error reading API error response stream. Expected JSON response.", - }; - } - console.error(apiError); - // Retrieve latest message being processed - const currentMessage = messages.find((message) => !message.completed); - if (!currentMessage) return; - - // Render error message as current message - const errorMessage = (err as Error).message; - const errorName = (err as Error).name; - if (errorMessage.includes("Error in input stream")) - currentMessage.rawResponse = `Woops! The connection broke while I was writing my thoughts down. Maybe try again in a bit or dislike this message if the issue persists?`; - else if (apiError.streamError) { - currentMessage.rawResponse = `Umm, not sure what just happened but I lost my train of thought. Could you try again or ask my developers to look into this if the issue persists? They can be contacted at the Khoj Github, Discord or team@khoj.dev.`; - } else if (response.status === 429) { - "detail" in apiError - ? 
(currentMessage.rawResponse = `${apiError.detail}`) - : (currentMessage.rawResponse = `I'm a bit overwhelmed at the moment. Could you try again in a bit or dislike this message if the issue persists?`); - } else if (errorName === "AbortError") { - currentMessage.rawResponse = `I've stopped processing this message. If you'd like to continue, please send the message again.`; - } else { - currentMessage.rawResponse = `Umm, not sure what just happened. I see this error message: ${errorMessage}. Could you try again or dislike this message if the issue persists?`; - } - - // Complete message streaming teardown properly - currentMessage.completed = true; - setMessages([...messages]); - setQueryToProcess(""); - setProcessQuerySignal(false); - } + sendMessage(JSON.stringify(chatAPIBody)); } const handleConversationIdChange = (newConversationId: string) => { @@ -522,7 +555,7 @@ export default function Chat() { isMobileWidth={isMobileWidth} onConversationIdChange={handleConversationIdChange} setImages={setImages} - setTriggeredAbort={setTriggeredAbort} + setTriggeredAbort={handleTriggeredAbort} isChatSideBarOpen={isChatSideBarOpen} setIsChatSideBarOpen={setIsChatSideBarOpen} isActive={authenticatedData?.is_active} diff --git a/src/interface/web/app/components/chatInputArea/chatInputArea.tsx b/src/interface/web/app/components/chatInputArea/chatInputArea.tsx index e2234701..408eea1a 100644 --- a/src/interface/web/app/components/chatInputArea/chatInputArea.tsx +++ b/src/interface/web/app/components/chatInputArea/chatInputArea.tsx @@ -82,7 +82,7 @@ interface ChatInputProps { isLoggedIn: boolean; agentColor?: string; isResearchModeEnabled?: boolean; - setTriggeredAbort: (value: boolean) => void; + setTriggeredAbort: (value: boolean, newMessage?: string) => void; prefillMessage?: string; focus?: ChatInputFocus; } @@ -189,9 +189,11 @@ export const ChatInputArea = forwardRef((pr return; } - // If currently processing, trigger abort first + // If currently processing, handle interrupt 
first if (props.sendDisabled) { - props.setTriggeredAbort(true); + props.setTriggeredAbort(true, message.trim()); + setMessage(""); // Clear the input + return; // Don't continue with regular message sending } if (imageUploaded) { diff --git a/src/interface/web/package.json b/src/interface/web/package.json index 72da0782..31d8504b 100644 --- a/src/interface/web/package.json +++ b/src/interface/web/package.json @@ -71,6 +71,7 @@ "react": "^18", "react-dom": "^18", "react-hook-form": "^7.52.1", + "react-use-websocket": "^4.13.0", "shadcn-ui": "^0.9.0", "swr": "^2.2.5", "tailwind-merge": "^2.3.0", diff --git a/src/interface/web/yarn.lock b/src/interface/web/yarn.lock index 46fc5051..537c50a8 100644 --- a/src/interface/web/yarn.lock +++ b/src/interface/web/yarn.lock @@ -4542,6 +4542,11 @@ react-style-singleton@^2.2.2, react-style-singleton@^2.2.3: get-nonce "^1.0.0" tslib "^2.0.0" +react-use-websocket@^4.13.0: + version "4.13.0" + resolved "https://registry.yarnpkg.com/react-use-websocket/-/react-use-websocket-4.13.0.tgz#9db1dbac6dc8ba2fdc02a5bba06205fbf6406736" + integrity sha512-anMuVoV//g2N76Wxqvqjjo1X48r9Np3y1/gMl7arX84tAPXdy5R7sB5lO5hvCzQRYjqXwV8XMAiEBOUbyrZFrw== + react@^18: version "18.3.1" resolved "https://registry.yarnpkg.com/react/-/react-18.3.1.tgz#49ab892009c53933625bd16b2533fc754cab2891" @@ -4894,6 +4899,7 @@ string-argv@^0.3.2: integrity sha512-aqD2Q0144Z+/RqG52NeHEkZauTAUWJO8c6yTftGJKO3Tja5tUgIfmIl6kExvhtxSDP7fXB6DvzkfMpCd/F3G+Q== "string-width-cjs@npm:string-width@^4.2.0", string-width@^4.1.0: + name string-width-cjs version "4.2.3" resolved "https://registry.yarnpkg.com/string-width/-/string-width-4.2.3.tgz#269c7117d27b05ad2e536830a8ec895ef9c6d010" integrity sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g== From 7b7b1830b7fc79911711f6affd1bb15fe4671a9c Mon Sep 17 00:00:00 2001 From: Debanjum Date: Fri, 27 Jun 2025 00:27:47 -0700 Subject: [PATCH 4/6] Make callers only share new messages to append to chat 
logs - Chat history is retrieved and updated with new messages just before write. This is to reduce chance of message loss due to conflicting writes making last to save to conversation win conflict. - This was problematic artifact of old code. Removing it should reduce conflict surface area. - Interrupts and live chat could hit this issue due to different reasons --- src/khoj/database/adapters/__init__.py | 5 +++-- src/khoj/database/models/__init__.py | 28 ++++++++++++++++++++++++ src/khoj/processor/conversation/utils.py | 11 +++++----- src/khoj/routers/api_chat.py | 17 ++++---------- 4 files changed, 40 insertions(+), 21 deletions(-) diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index fe53a73c..14eaadfe 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -1465,7 +1465,7 @@ class ConversationAdapters: @require_valid_user async def save_conversation( user: KhojUser, - chat_history: List[ChatMessageModel], + new_messages: List[ChatMessageModel], client_application: ClientApplication = None, conversation_id: str = None, user_message: str = None, @@ -1480,7 +1480,8 @@ class ConversationAdapters: await Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst() ) - conversation_log = {"chat": [msg.model_dump() for msg in chat_history]} + existing_messages = conversation.messages if conversation else [] + conversation_log = {"chat": [msg.model_dump() for msg in existing_messages + new_messages]} cleaned_conversation_log = clean_object_for_db(conversation_log) if conversation: conversation.conversation_log = cleaned_conversation_log diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 8d903338..7f80459a 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -677,6 +677,34 @@ class Conversation(DbBaseModel): continue return validated_messages + async def 
pop_message(self, interrupted: bool = False) -> Optional[ChatMessageModel]: + """ + Remove and return the last message from the conversation log, persisting the change to the database. + When interrupted is True, we only drop the last message if it was an interrupted message by khoj. + """ + chat_log = self.conversation_log.get("chat", []) + + if not chat_log: + return None + + last_message = chat_log[-1] + is_interrupted_msg = last_message.get("by") == "khoj" and not last_message.get("message") + # When handling an interruption, only pop if the last message is an empty one by khoj. + if interrupted and not is_interrupted_msg: + return None + + # Pop the last message, save the conversation, and then return the message. + popped_message_dict = chat_log.pop() + await self.asave() + + # Try to validate and return the popped message as a Pydantic model + try: + return ChatMessageModel.model_validate(popped_message_dict) + except ValidationError as e: + logger.warning(f"Popped an invalid message from conversation. The removal has been saved. Error: {e}") + # The invalid message was removed and saved, but we can't return a valid model. 
+ return None + class PublicConversation(DbBaseModel): source_owner = models.ForeignKey(KhojUser, on_delete=models.CASCADE) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index b4dd5e9c..6edd84b2 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -435,7 +435,6 @@ async def save_to_conversation_log( q: str, chat_response: str, user: KhojUser, - chat_history: List[ChatMessageModel], user_message_time: str = None, compiled_references: List[Dict[str, Any]] = [], online_results: Dict[str, Any] = {}, @@ -481,22 +480,22 @@ async def save_to_conversation_log( khoj_message_metadata["mermaidjsDiagram"] = generated_mermaidjs_diagram try: - updated_conversation = message_to_log( + new_messages = message_to_log( user_message=q, chat_response=chat_response, user_message_metadata=user_message_metadata, khoj_message_metadata=khoj_message_metadata, - chat_history=chat_history, + chat_history=[], ) except ValidationError as e: - updated_conversation = None + new_messages = None logger.error(f"Error constructing chat history: {e}") db_conversation = None - if updated_conversation: + if new_messages: db_conversation = await ConversationAdapters.save_conversation( user, - updated_conversation, + new_messages, client_application=client_application, conversation_id=conversation_id, user_message=q, diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index e602ab2c..6d444b92 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -738,6 +738,7 @@ async def event_generator( generated_mermaidjs_diagram: str = None generated_asset_results: Dict = dict() program_execution_context: List[str] = [] + user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") # Create a task to monitor for disconnections disconnect_monitor_task = None @@ -757,7 +758,6 @@ async def event_generator( q, chat_response="", user=user, - chat_history=chat_history, 
compiled_references=compiled_references, online_results=online_results, code_results=code_results, @@ -772,6 +772,7 @@ async def event_generator( generated_images=generated_images, raw_generated_files=generated_asset_results, generated_mermaidjs_diagram=generated_mermaidjs_diagram, + user_message_time=user_message_time, tracer=tracer, ) ) @@ -789,7 +790,6 @@ async def event_generator( q, chat_response="", user=user, - chat_history=chat_history, compiled_references=compiled_references, online_results=online_results, code_results=code_results, @@ -804,6 +804,7 @@ async def event_generator( generated_images=generated_images, raw_generated_files=generated_asset_results, generated_mermaidjs_diagram=generated_mermaidjs_diagram, + user_message_time=user_message_time, tracer=tracer, ) ) @@ -952,18 +953,11 @@ async def event_generator( location = None if city or region or country or country_code: location = LocationData(city=city, region=region, country=country, country_code=country_code) - user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") chat_history = conversation.messages # If interrupted message in DB - if ( - conversation - and conversation.messages - and conversation.messages[-1].by == "khoj" - and not conversation.messages[-1].message - ): + if last_message := await conversation.pop_message(interrupted=True): # Populate context from interrupted message - last_message = conversation.messages[-1] online_results = {key: val.model_dump() for key, val in last_message.onlineContext.items() or []} code_results = {key: val.model_dump() for key, val in last_message.codeContext.items() or []} compiled_references = [ref.model_dump() for ref in last_message.context or []] @@ -974,8 +968,6 @@ async def event_generator( ] operator_results = [OperatorRun(**iter_dict) for iter_dict in last_message.operatorContext or []] train_of_thought = [thought.model_dump() for thought in last_message.trainOfThought or []] - # Drop the interrupted message from conversation history 
- chat_history.pop() logger.info(f"Loaded interrupted partial context from conversation {conversation_id}.") if conversation_commands == [ConversationCommand.Default]: @@ -1414,7 +1406,6 @@ async def event_generator( q, chat_response=full_response, user=user, - chat_history=chat_history, compiled_references=compiled_references, online_results=online_results, code_results=code_results, From 0ecd5f497dcd89d7bf67bf1f1b07ba1b3d7ed43f Mon Sep 17 00:00:00 2001 From: Debanjum Date: Thu, 3 Jul 2025 16:05:42 -0700 Subject: [PATCH 5/6] Show more informative title for semantic search train of thought --- src/khoj/routers/api_chat.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 6d444b92..c8c23c8f 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -1093,10 +1093,14 @@ async def event_generator( yield result if not is_none_or_empty(compiled_references): - headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references])) + distinct_headings = set([d.get("compiled").split("\n")[0] for d in compiled_references if "compiled" in d]) + distinct_files = set([d["file"] for d in compiled_references]) # Strip only leading # from headings - headings = headings.replace("#", "") - async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"): + headings_str = "\n- " + "\n- ".join(distinct_headings).replace("#", "") + async for result in send_event( + ChatEvent.STATUS, + f"**Found {len(distinct_headings)} Notes Across {len(distinct_files)} Files**: {headings_str}", + ): yield result if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user): From b90e2367d5494e6ea718d2638dde93008395d4d0 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Thu, 17 Jul 2025 15:36:26 -0700 Subject: [PATCH 6/6] Fix interrupt UX and research when using websocket via web app --- 
src/interface/web/app/chat/page.tsx | 24 ++++++------- src/khoj/processor/operator/__init__.py | 5 +++ src/khoj/routers/api_chat.py | 45 +++++++++++++++++++------ src/khoj/routers/research.py | 5 +++ 4 files changed, 56 insertions(+), 23 deletions(-) diff --git a/src/interface/web/app/chat/page.tsx b/src/interface/web/app/chat/page.tsx index 8ef97b8d..796290cf 100644 --- a/src/interface/web/app/chat/page.tsx +++ b/src/interface/web/app/chat/page.tsx @@ -271,11 +271,13 @@ export default function Chat() { const controlMessage = JSON.parse(lastMessage.data); if (controlMessage.type === "interrupt_acknowledged") { console.log("Interrupt acknowledged by server"); - setSocketUrl(null); setProcessQuerySignal(false); return; - } - if (controlMessage.error) { + } else if (controlMessage.type === "interrupt_message_acknowledged") { + console.log("Interrupt message acknowledged by server"); + setProcessQuerySignal(false); + return; + } else if (controlMessage.error) { console.error("WebSocket error:", controlMessage.error); return; } @@ -360,24 +362,20 @@ export default function Chat() { ); console.log("Sent interrupt message via WebSocket:", interruptMessage); - // Update the current message with the new query but keep it in processing state - const messageToProcess = interruptMessage || queryToProcess; + // Mark the last message as completed setMessages((prevMessages) => { const newMessages = [...prevMessages]; const currentMessage = newMessages[newMessages.length - 1]; - if (currentMessage && !currentMessage.completed) { - currentMessage.rawQuery = messageToProcess; - currentMessage.completed = !!interruptMessage; - } + if (currentMessage) currentMessage.completed = true; return newMessages; }); - // Update the query being processed - setQueryToProcess(messageToProcess); - setTriggeredAbort(!!interruptMessage); + // Set the interrupt message as the new query being processed + setQueryToProcess(interruptMessage); + setTriggeredAbort(false); // Always set to false after 
processing setInterruptMessage(""); } - }, [triggeredAbort, interruptMessage, queryToProcess, sendMessage]); + }, [triggeredAbort, sendMessage]); useEffect(() => { if (queryToProcess) { diff --git a/src/khoj/processor/operator/__init__.py b/src/khoj/processor/operator/__init__.py index c07dec90..d7068bff 100644 --- a/src/khoj/processor/operator/__init__.py +++ b/src/khoj/processor/operator/__init__.py @@ -44,6 +44,7 @@ async def operate_environment( query_files: str = None, # TODO: Handle query files cancellation_event: Optional[asyncio.Event] = None, interrupt_queue: Optional[asyncio.Queue] = None, + abort_message: Optional[str] = "␃🔚␗", tracer: dict = {}, ): response, user_input_message = None, None @@ -144,6 +145,10 @@ async def operate_environment( # Add interrupt query to current operator run if interrupt_query := get_message_from_queue(interrupt_queue): + if interrupt_query == abort_message: + cancellation_event.set() + logger.debug(f"Operator run cancelled by user {user} via interrupt queue.") + break # Add the interrupt query as a new user message to the research conversation history logger.info(f"Continuing operator run with the new instruction: {interrupt_query}") operator_agent.messages.append(AgentMessage(role="user", content=interrupt_query)) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index c8c23c8f..09c7651a 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -69,6 +69,7 @@ from khoj.routers.helpers import ( generate_mermaidjs_diagram, generate_summary_from_files, get_conversation_command, + get_message_from_queue, is_query_empty, is_ready_to_chat, read_chat_stream, @@ -672,7 +673,7 @@ async def event_generator( common: CommonQueryParams, headers: Headers, request_obj: Request | WebSocket, - interrupt_queue: asyncio.Queue = None, + parent_interrupt_queue: asyncio.Queue = None, ): # Access the parameters from the body q = body.q @@ -697,8 +698,11 @@ async def event_generator( user: KhojUser = 
user_scope.object is_subscribed = has_required_scope(request_obj, ["premium"]) q = unquote(q) + defiltered_query = defilter_query(q) train_of_thought = [] cancellation_event = asyncio.Event() + child_interrupt_queue: asyncio.Queue = asyncio.Queue() + event_delimiter = "␃🔚␗" tracer: dict = { "mid": turn_id, @@ -744,6 +748,7 @@ async def event_generator( disconnect_monitor_task = None async def monitor_disconnection(): + nonlocal q, defiltered_query if isinstance(request_obj, Request): try: msg = await request_obj.receive() @@ -779,12 +784,23 @@ async def event_generator( except Exception as e: logger.error(f"Error in disconnect monitor: {e}") elif isinstance(request_obj, WebSocket): - while request_obj.client_state == WebSocketState.CONNECTED: + while request_obj.client_state == WebSocketState.CONNECTED and not cancellation_event.is_set(): await asyncio.sleep(1) - logger.debug(f"WebSocket disconnected. User {user} from {common.client} client.") - cancellation_event.set() - if conversation: + # Check if any interrupt query is received + if interrupt_query := get_message_from_queue(parent_interrupt_queue): + if interrupt_query == event_delimiter: + cancellation_event.set() + logger.debug(f"Chat cancelled by user {user} via interrupt queue.") + else: + # Pass the interrupt query to child tasks + logger.info(f"Continuing chat with the new instruction: {interrupt_query}") + await child_interrupt_queue.put(interrupt_query) + q += f"\n\n{interrupt_query}" + defiltered_query += f"\n\n{defilter_query(interrupt_query)}" + + logger.debug(f"WebSocket disconnected or chat cancelled by user {user} from {common.client} client.") + if conversation and cancellation_event.is_set(): await asyncio.shield( save_to_conversation_log( q, @@ -821,7 +837,6 @@ async def event_generator( async def send_event(event_type: ChatEvent, data: str | dict): nonlocal ttft, train_of_thought - event_delimiter = "␃🔚␗" if cancellation_event.is_set(): return try: @@ -1025,7 +1040,8 @@ async def 
event_generator( query_files=attached_file_context, tracer=tracer, cancellation_event=cancellation_event, - interrupt_queue=interrupt_queue, + interrupt_queue=child_interrupt_queue, + abort_message=event_delimiter, ): if isinstance(research_result, ResearchIteration): if research_result.summarizedResult: @@ -1218,6 +1234,7 @@ async def event_generator( send_status_func=partial(send_event, ChatEvent.STATUS), agent=agent, cancellation_event=cancellation_event, + interrupt_queue=child_interrupt_queue, tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: @@ -1471,9 +1488,17 @@ async def chat_ws( if data.get("type") == "interrupt": if current_task and not current_task.done(): # Send interrupt signal to the ongoing task - await interrupt_queue.put(data.get("query", "")) - logger.info(f"Interrupt signal sent to ongoing task for user {websocket.scope['user'].object.id}") - await websocket.send_text(json.dumps({"type": "interrupt_acknowledged"})) + abort_message = "␃🔚␗" + await interrupt_queue.put(data.get("query") or abort_message) + logger.info( + f"Interrupt signal sent to ongoing task for user {websocket.scope['user'].object.id} with query: {data.get('query')}" + ) + if data.get("query"): + ack_type = "interrupt_message_acknowledged" + await websocket.send_text(json.dumps({"type": ack_type})) + else: + ack_type = "interrupt_acknowledged" + await websocket.send_text(json.dumps({"type": ack_type})) else: logger.info(f"No ongoing task to interrupt for user {websocket.scope['user'].object.id}") continue diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 050fd3f8..616acae2 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -224,6 +224,7 @@ async def research( query_files: str = None, cancellation_event: Optional[asyncio.Event] = None, interrupt_queue: Optional[asyncio.Queue] = None, + abort_message: str = "␃🔚␗", ): max_document_searches = 7 max_online_searches = 3 @@ -246,6 +247,10 @@ async def 
research( # Update the query for the current research iteration if interrupt_query := get_message_from_queue(interrupt_queue): + if interrupt_query == abort_message: + cancellation_event.set() + logger.debug(f"Research cancelled by user {user} via interrupt queue.") + break # Add the interrupt query as a new user message to the research conversation history logger.info( f"Continuing research with the previous {len(previous_iterations)} iterations and new instruction: {interrupt_query}"