diff --git a/src/khoj/main.py b/src/khoj/main.py index 76895891..50da2624 100644 --- a/src/khoj/main.py +++ b/src/khoj/main.py @@ -220,7 +220,16 @@ def set_state(args): def start_server(app, host=None, port=None, socket=None): logger.info("🌖 Khoj is ready to engage") if socket: - uvicorn.run(app, proxy_headers=True, uds=socket, log_level="debug", use_colors=True, log_config=None) + uvicorn.run( + app, + proxy_headers=True, + uds=socket, + log_level="debug" if state.verbose > 1 else "info", + use_colors=True, + log_config=None, + ws_ping_timeout=300, + timeout_keep_alive=60, + ) else: uvicorn.run( app, @@ -229,6 +238,7 @@ def start_server(app, host=None, port=None, socket=None): log_level="debug" if state.verbose > 1 else "info", use_colors=True, log_config=None, + ws_ping_timeout=300, timeout_keep_alive=60, **state.ssl_config if state.ssl_config else {}, ) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index b1ea3ece..04b50bfd 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -10,9 +10,18 @@ from typing import Any, Dict, List, Optional from urllib.parse import unquote from asgiref.sync import sync_to_async -from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi import ( + APIRouter, + Depends, + HTTPException, + Request, + WebSocket, + WebSocketDisconnect, +) from fastapi.responses import RedirectResponse, Response, StreamingResponse +from fastapi.websockets import WebSocketState from starlette.authentication import has_required_scope, requires +from starlette.requests import Headers from khoj.app.settings import ALLOWED_HOSTS from khoj.database.adapters import ( @@ -657,19 +666,12 @@ def delete_message(request: Request, delete_request: DeleteMessageRequestBody) - return Response(content=json.dumps({"status": "error", "message": "Message not found"}), status_code=404) -@api_chat.post("") -@requires(["authenticated"]) -async def chat( - request: Request, - common: CommonQueryParams, +async def event_generator( body: ChatRequestBody, - rate_limiter_per_minute=Depends( - ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute") - ), - rate_limiter_per_day=Depends( - ApiUserRateLimiter(requests=100, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day") - ), - image_rate_limiter=Depends(ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)), + user_scope: Any, + common: CommonQueryParams, + headers: Headers, + request_obj: Request | WebSocket, ): # Access the parameters from the body q = body.q @@ -688,65 +690,62 @@ async def chat( raw_query_files = body.files interrupt_flag = body.interrupt - async def event_generator(q: str, images: list[str]): - start_time = time.perf_counter() - ttft = None - chat_metadata: dict = {} - conversation = None - user: KhojUser = request.user.object - is_subscribed = has_required_scope(request, ["premium"]) - q = unquote(q) - train_of_thought = [] - nonlocal conversation_id - nonlocal raw_query_files - cancellation_event = asyncio.Event() + start_time = time.perf_counter() + ttft = None + chat_metadata: dict = {} + conversation = None + user: KhojUser = user_scope.object + is_subscribed = has_required_scope(request_obj, ["premium"]) + q = unquote(q) + train_of_thought = [] + cancellation_event = asyncio.Event() - tracer: dict = { - "mid": turn_id, - "cid": conversation_id, - "uid": user.id, - "khoj_version": state.khoj_version, - } + tracer: dict = { + "mid": turn_id, + "cid": conversation_id, + "uid": user.id, + "khoj_version": state.khoj_version, + } - uploaded_images: list[str] = [] - if images: - for image in images: - decoded_string = unquote(image) - base64_data = decoded_string.split(",", 1)[1] - image_bytes = base64.b64decode(base64_data) - webp_image_bytes = convert_image_to_webp(image_bytes) - uploaded_image = upload_user_image_to_bucket(webp_image_bytes, request.user.object.id) - if not uploaded_image: - base64_webp_image = base64.b64encode(webp_image_bytes).decode("utf-8") - uploaded_image = f"data:image/webp;base64,{base64_webp_image}" - uploaded_images.append(uploaded_image) + uploaded_images: list[str] = [] + if raw_images: + for image in raw_images: + decoded_string = unquote(image) + base64_data = decoded_string.split(",", 1)[1] + image_bytes = base64.b64decode(base64_data) + webp_image_bytes = convert_image_to_webp(image_bytes) + uploaded_image = upload_user_image_to_bucket(webp_image_bytes, user.id) + if not uploaded_image: + base64_webp_image = base64.b64encode(webp_image_bytes).decode("utf-8") + uploaded_image = f"data:image/webp;base64,{base64_webp_image}" + uploaded_images.append(uploaded_image) - query_files: Dict[str, str] = {} - if raw_query_files: - for file in raw_query_files: - query_files[file.name] = file.content + query_files: Dict[str, str] = {} + if raw_query_files: + for file in raw_query_files: + query_files[file.name] = file.content - research_results: List[ResearchIteration] = [] - online_results: Dict = dict() - code_results: Dict = dict() - operator_results: List[OperatorRun] = [] - compiled_references: List[Any] = [] - inferred_queries: List[Any] = [] - attached_file_context = gather_raw_query_files(query_files) + research_results: List[ResearchIteration] = [] + online_results: Dict = dict() + code_results: Dict = dict() + operator_results: List[OperatorRun] = [] + compiled_references: List[Any] = [] + inferred_queries: List[Any] = [] + attached_file_context = gather_raw_query_files(query_files) - generated_images: List[str] = [] - generated_files: List[FileAttachment] = [] - generated_mermaidjs_diagram: str = None - generated_asset_results: Dict = dict() - program_execution_context: List[str] = [] - chat_history: List[ChatMessageModel] = [] + generated_images: List[str] = [] + generated_files: List[FileAttachment] = [] + generated_mermaidjs_diagram: str = None + generated_asset_results: Dict = dict() + program_execution_context: List[str] = [] - # Create a task to monitor for disconnections - disconnect_monitor_task = None + # Create a task to monitor for disconnections + disconnect_monitor_task = None - async def monitor_disconnection(): + async def monitor_disconnection(): + if isinstance(request_obj, Request): try: - msg = await request.receive() + msg = await request_obj.receive() if msg["type"] == "http.disconnect": logger.debug(f"Request cancelled. User {user} disconnected from {common.client} client.") cancellation_event.set() @@ -765,7 +764,7 @@ async def chat( operator_results=operator_results, research_results=research_results, inferred_queries=inferred_queries, - client_application=request.user.client_app, + client_application=user_scope.client_app, conversation_id=conversation_id, query_images=uploaded_images, train_of_thought=train_of_thought, @@ -778,683 +777,817 @@ async def chat( ) except Exception as e: logger.error(f"Error in disconnect monitor: {e}") + elif isinstance(request_obj, WebSocket): + while request_obj.client_state == WebSocketState.CONNECTED: + await asyncio.sleep(1) - # Cancel the disconnect monitor task if it is still running - async def cancel_disconnect_monitor(): - if disconnect_monitor_task and not disconnect_monitor_task.done(): - logger.debug(f"Cancelling disconnect monitor task for user {user}") - disconnect_monitor_task.cancel() - try: - await disconnect_monitor_task - except asyncio.CancelledError: - pass - - async def send_event(event_type: ChatEvent, data: str | dict): - nonlocal ttft, train_of_thought - event_delimiter = "␃🔚␗" - if cancellation_event.is_set(): - return - try: - if event_type == ChatEvent.END_LLM_RESPONSE: - collect_telemetry() - elif event_type == ChatEvent.START_LLM_RESPONSE: - ttft = time.perf_counter() - start_time - elif event_type == ChatEvent.STATUS: - train_of_thought.append({"type": event_type.value, "data": data}) - elif event_type == ChatEvent.THOUGHT: - # Append the data to the last thought as thoughts are streamed - if ( - len(train_of_thought) > 0 - and train_of_thought[-1]["type"] == ChatEvent.THOUGHT.value - and type(train_of_thought[-1]["data"]) == type(data) == str - ): - train_of_thought[-1]["data"] += data - else: - train_of_thought.append({"type": event_type.value, "data": data}) - - if event_type == ChatEvent.MESSAGE: - yield data - elif event_type == ChatEvent.REFERENCES or ChatEvent.METADATA or stream: - yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False) - except Exception as e: - if not cancellation_event.is_set(): - logger.error( - f"Failed to stream chat API response to {user} on {common.client}: {e}.", - exc_info=True, + logger.debug(f"WebSocket disconnected. User {user} from {common.client} client.") + cancellation_event.set() + if conversation: + await asyncio.shield( + save_to_conversation_log( + q, + chat_response="", + user=user, + chat_history=chat_history, + compiled_references=compiled_references, + online_results=online_results, + code_results=code_results, + operator_results=operator_results, + research_results=research_results, + inferred_queries=inferred_queries, + client_application=user_scope.client_app, + conversation_id=conversation_id, + query_images=uploaded_images, + train_of_thought=train_of_thought, + raw_query_files=raw_query_files, + generated_images=generated_images, + raw_generated_files=generated_asset_results, + generated_mermaidjs_diagram=generated_mermaidjs_diagram, + tracer=tracer, ) - finally: - if not cancellation_event.is_set(): - yield event_delimiter - # Cancel the disconnect monitor task if it is still running - if cancellation_event.is_set() or event_type == ChatEvent.END_RESPONSE: - await cancel_disconnect_monitor() - - async def send_llm_response(response: str, usage: dict = None): - # Check if the client is still connected - if cancellation_event.is_set(): - return - # Send Chat Response - async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""): - yield result - async for result in send_event(ChatEvent.MESSAGE, response): - yield result - async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""): - yield result - # Send Usage Metadata once llm interactions are complete - if usage: - async for event in send_event(ChatEvent.USAGE, usage): - yield event - async for result in send_event(ChatEvent.END_RESPONSE, ""): - yield result - - def collect_telemetry(): - # Gather chat response telemetry - nonlocal chat_metadata - latency = time.perf_counter() - start_time - cmd_set = set([cmd.value for cmd in conversation_commands]) - cost = (tracer.get("usage", {}) or {}).get("cost", 0) - chat_metadata = chat_metadata or {} - chat_metadata["conversation_command"] = cmd_set - chat_metadata["agent"] = conversation.agent.slug if conversation and conversation.agent else None - chat_metadata["cost"] = f"{cost:.5f}" - chat_metadata["latency"] = f"{latency:.3f}" - if ttft: - chat_metadata["ttft_latency"] = f"{ttft:.3f}" - logger.info(f"Chat response time to first token: {ttft:.3f} seconds") - logger.info(f"Chat response total time: {latency:.3f} seconds") - logger.info(f"Chat response cost: ${cost:.5f}") - update_telemetry_state( - request=request, - telemetry_type="api", - api="chat", - client=common.client, - user_agent=request.headers.get("user-agent"), - host=request.headers.get("host"), - metadata=chat_metadata, - ) - - # Start the disconnect monitor in the background - disconnect_monitor_task = asyncio.create_task(monitor_disconnection()) - - if is_query_empty(q): - async for result in send_llm_response("Please ask your query to get started.", tracer.get("usage")): - yield result - return - - # Automated tasks are handled before to allow mixing them with other conversation commands - cmds_to_rate_limit = [] - is_automated_task = False - if q.startswith("/automated_task"): - is_automated_task = True - q = q.replace("/automated_task", "").lstrip() - cmds_to_rate_limit += [ConversationCommand.AutomatedTask] - - # Extract conversation command from query - conversation_commands = [get_conversation_command(query=q)] - - conversation = await ConversationAdapters.aget_conversation_by_user( - user, - client_application=request.user.client_app, - conversation_id=conversation_id, - title=title, - create_new=body.create_new, - ) - if not conversation: - async for result in send_llm_response(f"Conversation {conversation_id} not found", tracer.get("usage")): - yield result - return - conversation_id = str(conversation.id) - - async for event in send_event(ChatEvent.METADATA, {"conversationId": conversation_id, "turnId": turn_id}): - yield event - - agent: Agent | None = None - default_agent = await AgentAdapters.aget_default_agent() - if conversation.agent and conversation.agent != default_agent: - agent = conversation.agent - - if not conversation.agent: - conversation.agent = default_agent - await conversation.asave() - agent = default_agent - - await is_ready_to_chat(user) - user_name = await aget_user_name(user) - location = None - if city or region or country or country_code: - location = LocationData(city=city, region=region, country=country, country_code=country_code) - user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - chat_history = conversation.messages - - # If interrupt flag is set, wait for the previous turn to be saved before proceeding - if interrupt_flag: - max_wait_time = 20.0 # seconds - wait_interval = 0.3 # seconds - wait_start = wait_current = time.time() - while wait_current - wait_start < max_wait_time: - # Refresh conversation to check if interrupted message saved to DB - conversation = await ConversationAdapters.aget_conversation_by_user( - user, - client_application=request.user.client_app, - conversation_id=conversation_id, ) + + # Cancel the disconnect monitor task if it is still running + async def cancel_disconnect_monitor(): + if disconnect_monitor_task and not disconnect_monitor_task.done(): + logger.debug(f"Cancelling disconnect monitor task for user {user}") + disconnect_monitor_task.cancel() + try: + await disconnect_monitor_task + except asyncio.CancelledError: + pass + + async def send_event(event_type: ChatEvent, data: str | dict): + nonlocal ttft, train_of_thought + event_delimiter = "␃🔚␗" + if cancellation_event.is_set(): + return + try: + if event_type == ChatEvent.END_LLM_RESPONSE: + collect_telemetry() + elif event_type == ChatEvent.START_LLM_RESPONSE: + ttft = time.perf_counter() - start_time + elif event_type == ChatEvent.STATUS: + train_of_thought.append({"type": event_type.value, "data": data}) + elif event_type == ChatEvent.THOUGHT: + # Append the data to the last thought as thoughts are streamed if ( - conversation - and conversation.messages - and conversation.messages[-1].by == "khoj" - and not conversation.messages[-1].message + len(train_of_thought) > 0 + and train_of_thought[-1]["type"] == ChatEvent.THOUGHT.value + and type(train_of_thought[-1]["data"]) == type(data) == str ): - logger.info(f"Detected interrupted message save to conversation {conversation_id}.") - break - await asyncio.sleep(wait_interval) - wait_current = time.time() - - if wait_current - wait_start >= max_wait_time: - logger.warning( - f"Timeout waiting to load interrupted context from conversation {conversation_id}. Proceed without previous context." - ) - - # If interrupted message in DB - if ( - conversation - and conversation.messages - and conversation.messages[-1].by == "khoj" - and not conversation.messages[-1].message - ): - # Populate context from interrupted message - last_message = conversation.messages[-1] - online_results = {key: val.model_dump() for key, val in last_message.onlineContext.items() or []} - code_results = {key: val.model_dump() for key, val in last_message.codeContext.items() or []} - compiled_references = [ref.model_dump() for ref in last_message.context or []] - research_results = [ - ResearchIteration(**iter_dict) - for iter_dict in last_message.researchContext or [] - if iter_dict.get("summarizedResult") - ] - operator_results = [OperatorRun(**iter_dict) for iter_dict in last_message.operatorContext or []] - train_of_thought = [thought.model_dump() for thought in last_message.trainOfThought or []] - # Drop the interrupted message from conversation history - chat_history.pop() - logger.info(f"Loaded interrupted partial context from conversation {conversation_id}.") - - if conversation_commands == [ConversationCommand.Default]: - try: - chosen_io = await aget_data_sources_and_output_format( - q, - chat_history, - is_automated_task, - user=user, - query_images=uploaded_images, - agent=agent, - query_files=attached_file_context, - tracer=tracer, - ) - except ValueError as e: - logger.error(f"Error getting data sources and output format: {e}. Falling back to default.") - conversation_commands = [ConversationCommand.General] - - conversation_commands = chosen_io.get("sources") + [chosen_io.get("output")] - - # If we're doing research, we don't want to do anything else - if ConversationCommand.Research in conversation_commands: - conversation_commands = [ConversationCommand.Research] - - conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands]) - async for result in send_event(ChatEvent.STATUS, f"**Selected Tools:** {conversation_commands_str}"): - yield result - - cmds_to_rate_limit += conversation_commands - for cmd in cmds_to_rate_limit: - try: - await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd) - q = q.replace(f"/{cmd.value}", "").strip() - except HTTPException as e: - async for result in send_llm_response(str(e.detail), tracer.get("usage")): - yield result - return - - defiltered_query = defilter_query(q) - file_filters = conversation.file_filters if conversation and conversation.file_filters else [] - - if conversation_commands == [ConversationCommand.Research]: - async for research_result in research( - user=user, - query=defiltered_query, - conversation_id=conversation_id, - conversation_history=chat_history, - previous_iterations=list(research_results), - query_images=uploaded_images, - agent=agent, - send_status_func=partial(send_event, ChatEvent.STATUS), - user_name=user_name, - location=location, - file_filters=file_filters, - query_files=attached_file_context, - tracer=tracer, - cancellation_event=cancellation_event, - ): - if isinstance(research_result, ResearchIteration): - if research_result.summarizedResult: - if research_result.onlineContext: - online_results.update(research_result.onlineContext) - if research_result.codeContext: - code_results.update(research_result.codeContext) - if research_result.context: - compiled_references.extend(research_result.context) - if not research_results or research_results[-1] is not research_result: - research_results.append(research_result) + train_of_thought[-1]["data"] += data else: - yield research_result + train_of_thought.append({"type": event_type.value, "data": data}) - # Track operator results across research and operator iterations - # This relies on two conditions: - # 1. Check to append new (partial) operator results - # Relies on triggering this check on every status updates. - # Status updates cascade up from operator to research to chat api on every step. - # 2. Keep operator results in sync with each research operator step - # Relies on python object references to ensure operator results - # are implicitly kept in sync after the initial append - if ( - research_results - and research_results[-1].operatorContext - and (not operator_results or operator_results[-1] is not research_results[-1].operatorContext) - ): - operator_results.append(research_results[-1].operatorContext) - - # researched_results = await extract_relevant_info(q, researched_results, agent) - if state.verbose > 1: - logger.debug(f'Researched Results: {"".join(r.summarizedResult or "" for r in research_results)}') - - # Gather Context - ## Extract Document References - if not ConversationCommand.Research in conversation_commands: - try: - async for result in search_documents( - q, - (n or 7), - d, - user, - chat_history, - conversation_id, - conversation_commands, - location, - partial(send_event, ChatEvent.STATUS), - query_images=uploaded_images, - agent=agent, - query_files=attached_file_context, - tracer=tracer, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - compiled_references.extend(result[0]) - inferred_queries.extend(result[1]) - defiltered_query = result[2] - except Exception as e: - error_message = ( - f"Error searching knowledge base: {e}. Attempting to respond without document references." - ) - logger.error(error_message, exc_info=True) - async for result in send_event( - ChatEvent.STATUS, "Document search failed. I'll try respond without document references" - ): - yield result - - if not is_none_or_empty(compiled_references): - headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references])) - # Strip only leading # from headings - headings = headings.replace("#", "") - async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"): - yield result - - if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user): - async for result in send_llm_response(f"{no_entries_found.format()}", tracer.get("usage")): - yield result - return - - if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references): - conversation_commands.remove(ConversationCommand.Notes) - - ## Gather Online References - if ConversationCommand.Online in conversation_commands: - try: - async for result in search_online( - defiltered_query, - chat_history, - location, - user, - partial(send_event, ChatEvent.STATUS), - custom_filters=[], - max_online_searches=3, - query_images=uploaded_images, - query_files=attached_file_context, - agent=agent, - tracer=tracer, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - online_results = result - except Exception as e: - error_message = f"Error searching online: {e}. Attempting to respond without online results" - logger.warning(error_message) - async for result in send_event( - ChatEvent.STATUS, "Online search failed. I'll try respond without online references" - ): - yield result - - ## Gather Webpage References - if ConversationCommand.Webpage in conversation_commands: - try: - async for result in read_webpages( - defiltered_query, - chat_history, - location, - user, - partial(send_event, ChatEvent.STATUS), - max_webpages_to_read=1, - query_images=uploaded_images, - agent=agent, - query_files=attached_file_context, - tracer=tracer, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - direct_web_pages = result - webpages = [] - for query in direct_web_pages: - if online_results.get(query): - online_results[query]["webpages"] = direct_web_pages[query]["webpages"] - else: - online_results[query] = {"webpages": direct_web_pages[query]["webpages"]} - - for webpage in direct_web_pages[query]["webpages"]: - webpages.append(webpage["link"]) - async for result in send_event(ChatEvent.STATUS, f"**Read web pages**: {webpages}"): - yield result - except Exception as e: - logger.warning( - f"Error reading webpages: {e}. Attempting to respond without webpage results", + if event_type == ChatEvent.MESSAGE: + yield data + elif event_type == ChatEvent.REFERENCES or ChatEvent.METADATA or stream: + yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False) + except Exception as e: + if not cancellation_event.is_set(): + logger.error( + f"Failed to stream chat API response to {user} on {common.client}: {e}.", exc_info=True, ) - async for result in send_event( - ChatEvent.STATUS, "Webpage read failed. I'll try respond without webpage references" - ): - yield result + finally: + if not cancellation_event.is_set(): + yield event_delimiter + # Cancel the disconnect monitor task if it is still running + if cancellation_event.is_set() or event_type == ChatEvent.END_RESPONSE: + await cancel_disconnect_monitor() - ## Gather Code Results - if ConversationCommand.Code in conversation_commands: - try: - context = f"# Iteration 1:\n#---\nNotes:\n{compiled_references}\n\nOnline Results:{online_results}" - async for result in run_code( - defiltered_query, - chat_history, - context, - location, - user, - partial(send_event, ChatEvent.STATUS), - query_images=uploaded_images, - agent=agent, - query_files=attached_file_context, - tracer=tracer, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - code_results = result - except ValueError as e: - program_execution_context.append(f"Failed to run code") - logger.warning( - f"Failed to use code tool: {e}. Attempting to respond without code results", - exc_info=True, - ) - if ConversationCommand.Operator in conversation_commands: - try: - async for result in operate_environment( - defiltered_query, - user, - chat_history, - location, - list(operator_results)[-1] if operator_results else None, - query_images=uploaded_images, - query_files=attached_file_context, - send_status_func=partial(send_event, ChatEvent.STATUS), - agent=agent, - cancellation_event=cancellation_event, - tracer=tracer, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - elif isinstance(result, OperatorRun): - if not operator_results or operator_results[-1] is not result: - operator_results.append(result) - # Add webpages visited while operating browser to references - if result.webpages: - if not online_results.get(defiltered_query): - online_results[defiltered_query] = {"webpages": result.webpages} - elif not online_results[defiltered_query].get("webpages"): - online_results[defiltered_query]["webpages"] = result.webpages - else: - online_results[defiltered_query]["webpages"] += result.webpages - except ValueError as e: - program_execution_context.append(f"Browser operation error: {e}") - logger.warning(f"Failed to operate browser with {e}", exc_info=True) - async for result in send_event( - ChatEvent.STATUS, "Operating browser failed. I'll try respond appropriately" - ): - yield result - - ## Send Gathered References - unique_online_results = deduplicate_organic_results(online_results) - async for result in send_event( - ChatEvent.REFERENCES, - { - "inferredQueries": inferred_queries, - "context": compiled_references, - "onlineContext": unique_online_results, - "codeContext": code_results, - }, - ): + async def send_llm_response(response: str, usage: dict = None): + # Check if the client is still connected + if cancellation_event.is_set(): + return + # Send Chat Response + async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""): + yield result + async for result in send_event(ChatEvent.MESSAGE, response): + yield result + async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""): + yield result + # Send Usage Metadata once llm interactions are complete + if usage: + async for event in send_event(ChatEvent.USAGE, usage): + yield event + async for result in send_event(ChatEvent.END_RESPONSE, ""): yield result - # Generate Output - ## Generate Image Output - if ConversationCommand.Image in conversation_commands: - async for result in text_to_image( + def collect_telemetry(): + # Gather chat response telemetry + nonlocal chat_metadata + latency = time.perf_counter() - start_time + cmd_set = set([cmd.value for cmd in conversation_commands]) + cost = (tracer.get("usage", {}) or {}).get("cost", 0) + chat_metadata = chat_metadata or {} + chat_metadata["conversation_command"] = cmd_set + chat_metadata["agent"] = conversation.agent.slug if conversation and conversation.agent else None + chat_metadata["cost"] = f"{cost:.5f}" + chat_metadata["latency"] = f"{latency:.3f}" + if ttft: + chat_metadata["ttft_latency"] = f"{ttft:.3f}" + logger.info(f"Chat response time to first token: {ttft:.3f} seconds") + logger.info(f"Chat response total time: {latency:.3f} seconds") + logger.info(f"Chat response cost: ${cost:.5f}") + update_telemetry_state( + request=request_obj, + telemetry_type="api", + api="chat", + client=common.client, + user_agent=headers.get("user-agent"), + host=headers.get("host"), + metadata=chat_metadata, + ) + + # Start the disconnect monitor in the background + disconnect_monitor_task = asyncio.create_task(monitor_disconnection()) + + if is_query_empty(q): + async for result in send_llm_response("Please ask your query to get started.", tracer.get("usage")): + yield result + return + + # Automated tasks are handled before to allow mixing them with other conversation commands + cmds_to_rate_limit = [] + is_automated_task = False + if q.startswith("/automated_task"): + is_automated_task = True + q = q.replace("/automated_task", "").lstrip() + cmds_to_rate_limit += [ConversationCommand.AutomatedTask] + + # Extract conversation command from query + conversation_commands = [get_conversation_command(query=q)] + + conversation = await ConversationAdapters.aget_conversation_by_user( + user, + client_application=user_scope.client_app, + conversation_id=conversation_id, + title=title, + create_new=body.create_new, + ) + if not conversation: + async for result in send_llm_response(f"Conversation {conversation_id} not found", tracer.get("usage")): + yield result + return + conversation_id = str(conversation.id) + + async for event in send_event(ChatEvent.METADATA, {"conversationId": conversation_id, "turnId": turn_id}): + yield event + + agent: Agent | None = None + default_agent = await AgentAdapters.aget_default_agent() + if conversation.agent and conversation.agent != default_agent: + agent = conversation.agent + + if not conversation.agent: + conversation.agent = default_agent + await conversation.asave() + agent = default_agent + + await is_ready_to_chat(user) + user_name = await aget_user_name(user) + location = None + if city or region or country or country_code: + location = LocationData(city=city, region=region, country=country, country_code=country_code) + user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + chat_history = conversation.messages + + # If interrupt flag is set, wait for the previous turn to be saved before proceeding + if interrupt_flag: + max_wait_time = 20.0 # seconds + wait_interval = 0.3 # seconds + wait_start = wait_current = time.time() + while wait_current - wait_start < max_wait_time: + # Refresh conversation to check if interrupted message saved to DB + conversation = await ConversationAdapters.aget_conversation_by_user( + user, + client_application=user_scope.client_app, + conversation_id=conversation_id, + ) + if ( + conversation + and conversation.messages + and conversation.messages[-1].by == "khoj" + and not conversation.messages[-1].message + ): + logger.info(f"Detected interrupted message save to conversation {conversation_id}.") + break + await asyncio.sleep(wait_interval) + wait_current = time.time() + + if wait_current - wait_start >= max_wait_time: + logger.warning( + f"Timeout waiting to load interrupted context from conversation {conversation_id}. Proceed without previous context." + ) + + # If interrupted message in DB + if ( + conversation + and conversation.messages + and conversation.messages[-1].by == "khoj" + and not conversation.messages[-1].message + ): + # Populate context from interrupted message + last_message = conversation.messages[-1] + online_results = {key: val.model_dump() for key, val in last_message.onlineContext.items() or []} + code_results = {key: val.model_dump() for key, val in last_message.codeContext.items() or []} + compiled_references = [ref.model_dump() for ref in last_message.context or []] + research_results = [ + ResearchIteration(**iter_dict) + for iter_dict in last_message.researchContext or [] + if iter_dict.get("summarizedResult") + ] + operator_results = [OperatorRun(**iter_dict) for iter_dict in last_message.operatorContext or []] + train_of_thought = [thought.model_dump() for thought in last_message.trainOfThought or []] + # Drop the interrupted message from conversation history + chat_history.pop() + logger.info(f"Loaded interrupted partial context from conversation {conversation_id}.") + + if conversation_commands == [ConversationCommand.Default]: + try: + chosen_io = await aget_data_sources_and_output_format( + q, + chat_history, + is_automated_task, + user=user, + query_images=uploaded_images, + agent=agent, + query_files=attached_file_context, + tracer=tracer, + ) + except ValueError as e: + logger.error(f"Error getting data sources and output format: {e}. Falling back to default.") + conversation_commands = [ConversationCommand.General] + + conversation_commands = chosen_io.get("sources") + [chosen_io.get("output")] + + # If we're doing research, we don't want to do anything else + if ConversationCommand.Research in conversation_commands: + conversation_commands = [ConversationCommand.Research] + + conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands]) + async for result in send_event(ChatEvent.STATUS, f"**Selected Tools:** {conversation_commands_str}"): + yield result + + cmds_to_rate_limit += conversation_commands + for cmd in cmds_to_rate_limit: + try: + await conversation_command_rate_limiter.update_and_check_if_valid(request_obj, cmd) + q = q.replace(f"/{cmd.value}", "").strip() + except HTTPException as e: + async for result in send_llm_response(str(e.detail), tracer.get("usage")): + yield result + return + + defiltered_query = defilter_query(q) + file_filters = conversation.file_filters if conversation and conversation.file_filters else [] + + if conversation_commands == [ConversationCommand.Research]: + async for research_result in research( + user=user, + query=defiltered_query, + conversation_id=conversation_id, + conversation_history=chat_history, + previous_iterations=list(research_results), + query_images=uploaded_images, + agent=agent, + send_status_func=partial(send_event, ChatEvent.STATUS), + user_name=user_name, + location=location, + file_filters=file_filters, + query_files=attached_file_context, + tracer=tracer, + cancellation_event=cancellation_event, + ): + if isinstance(research_result, ResearchIteration): + if research_result.summarizedResult: + if research_result.onlineContext: + online_results.update(research_result.onlineContext) + if research_result.codeContext: + code_results.update(research_result.codeContext) + if research_result.context: + compiled_references.extend(research_result.context) + if not research_results or research_results[-1] is not research_result: + research_results.append(research_result) + else: + yield research_result + + # Track operator results across research and operator iterations + # This relies on two conditions: + # 1. Check to append new (partial) operator results + # Relies on triggering this check on every status updates. + # Status updates cascade up from operator to research to chat api on every step. + # 2. Keep operator results in sync with each research operator step + # Relies on python object references to ensure operator results + # are implicitly kept in sync after the initial append + if ( + research_results + and research_results[-1].operatorContext + and (not operator_results or operator_results[-1] is not research_results[-1].operatorContext) + ): + operator_results.append(research_results[-1].operatorContext) + + # researched_results = await extract_relevant_info(q, researched_results, agent) + if state.verbose > 1: + logger.debug(f'Researched Results: {"".join(r.summarizedResult or "" for r in research_results)}') + + # Gather Context + ## Extract Document References + if not ConversationCommand.Research in conversation_commands: + try: + async for result in search_documents( + q, + (n or 7), + d, + user, + chat_history, + conversation_id, + conversation_commands, + location, + partial(send_event, ChatEvent.STATUS), + query_images=uploaded_images, + agent=agent, + query_files=attached_file_context, + tracer=tracer, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + compiled_references.extend(result[0]) + inferred_queries.extend(result[1]) + defiltered_query = result[2] + except Exception as e: + error_message = f"Error searching knowledge base: {e}. Attempting to respond without document references." + logger.error(error_message, exc_info=True) + async for result in send_event( + ChatEvent.STATUS, "Document search failed. I'll try respond without document references" + ): + yield result + + if not is_none_or_empty(compiled_references): + headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references])) + # Strip only leading # from headings + headings = headings.replace("#", "") + async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"): + yield result + + if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user): + async for result in send_llm_response(f"{no_entries_found.format()}", tracer.get("usage")): + yield result + return + + if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references): + conversation_commands.remove(ConversationCommand.Notes) + + ## Gather Online References + if ConversationCommand.Online in conversation_commands: + try: + async for result in search_online( + defiltered_query, + chat_history, + location, + user, + partial(send_event, ChatEvent.STATUS), + custom_filters=[], + max_online_searches=3, + query_images=uploaded_images, + query_files=attached_file_context, + agent=agent, + tracer=tracer, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + online_results = result + except Exception as e: + error_message = f"Error searching online: {e}. Attempting to respond without online results" + logger.warning(error_message) + async for result in send_event( + ChatEvent.STATUS, "Online search failed. I'll try respond without online references" + ): + yield result + + ## Gather Webpage References + if ConversationCommand.Webpage in conversation_commands: + try: + async for result in read_webpages( + defiltered_query, + chat_history, + location, + user, + partial(send_event, ChatEvent.STATUS), + max_webpages_to_read=1, + query_images=uploaded_images, + agent=agent, + query_files=attached_file_context, + tracer=tracer, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + direct_web_pages = result + webpages = [] + for query in direct_web_pages: + if online_results.get(query): + online_results[query]["webpages"] = direct_web_pages[query]["webpages"] + else: + online_results[query] = {"webpages": direct_web_pages[query]["webpages"]} + + for webpage in direct_web_pages[query]["webpages"]: + webpages.append(webpage["link"]) + async for result in send_event(ChatEvent.STATUS, f"**Read web pages**: {webpages}"): + yield result + except Exception as e: + logger.warning( + f"Error reading webpages: {e}. Attempting to respond without webpage results", + exc_info=True, + ) + async for result in send_event( + ChatEvent.STATUS, "Webpage read failed. I'll try respond without webpage references" + ): + yield result + + ## Gather Code Results + if ConversationCommand.Code in conversation_commands: + try: + context = f"# Iteration 1:\n#---\nNotes:\n{compiled_references}\n\nOnline Results:{online_results}" + async for result in run_code( + defiltered_query, + chat_history, + context, + location, + user, + partial(send_event, ChatEvent.STATUS), + query_images=uploaded_images, + agent=agent, + query_files=attached_file_context, + tracer=tracer, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + code_results = result + except ValueError as e: + program_execution_context.append(f"Failed to run code") + logger.warning( + f"Failed to use code tool: {e}. Attempting to respond without code results", + exc_info=True, + ) + if ConversationCommand.Operator in conversation_commands: + try: + async for result in operate_environment( defiltered_query, user, chat_history, - location_data=location, - references=compiled_references, - online_results=online_results, - send_status_func=partial(send_event, ChatEvent.STATUS), + location, + list(operator_results)[-1] if operator_results else None, query_images=uploaded_images, - agent=agent, query_files=attached_file_context, + send_status_func=partial(send_event, ChatEvent.STATUS), + agent=agent, + cancellation_event=cancellation_event, tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] - else: - generated_image, status_code, improved_image_prompt = result - - inferred_queries.append(improved_image_prompt) - if generated_image is None or status_code != 200: - program_execution_context.append(f"Failed to generate image with {improved_image_prompt}") - async for result in send_event(ChatEvent.STATUS, f"Failed to generate image"): - yield result - else: - generated_images.append(generated_image) - - generated_asset_results["images"] = { - "query": improved_image_prompt, - } - - async for result in send_event( - ChatEvent.GENERATED_ASSETS, - { - "images": [generated_image], - }, - ): - yield result - - if ConversationCommand.Diagram in conversation_commands: - async for result in send_event(ChatEvent.STATUS, f"Creating diagram"): + elif isinstance(result, OperatorRun): + if not operator_results or operator_results[-1] is not result: + operator_results.append(result) + # Add webpages visited while operating browser to references + if result.webpages: + if not online_results.get(defiltered_query): + online_results[defiltered_query] = {"webpages": result.webpages} + elif not online_results[defiltered_query].get("webpages"): + online_results[defiltered_query]["webpages"] = result.webpages + else: + online_results[defiltered_query]["webpages"] += result.webpages + except ValueError as e: + program_execution_context.append(f"Browser operation error: {e}") + logger.warning(f"Failed to operate browser with {e}", exc_info=True) + async for result in send_event( + ChatEvent.STATUS, "Operating browser failed. I'll try respond appropriately" + ): yield result - inferred_queries = [] - diagram_description = "" + ## Send Gathered References + unique_online_results = deduplicate_organic_results(online_results) + async for result in send_event( + ChatEvent.REFERENCES, + { + "inferredQueries": inferred_queries, + "context": compiled_references, + "onlineContext": unique_online_results, + "codeContext": code_results, + }, + ): + yield result - async for result in generate_mermaidjs_diagram( - q=defiltered_query, - chat_history=chat_history, - location_data=location, - note_references=compiled_references, - online_results=online_results, - query_images=uploaded_images, - user=user, - agent=agent, - send_status_func=partial(send_event, ChatEvent.STATUS), - query_files=attached_file_context, - tracer=tracer, + # Generate Output + ## Generate Image Output + if ConversationCommand.Image in conversation_commands: + async for result in text_to_image( + defiltered_query, + user, + chat_history, + location_data=location, + references=compiled_references, + online_results=online_results, + send_status_func=partial(send_event, ChatEvent.STATUS), + query_images=uploaded_images, + agent=agent, + query_files=attached_file_context, + tracer=tracer, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + generated_image, status_code, improved_image_prompt = result + + inferred_queries.append(improved_image_prompt) + if generated_image is None or status_code != 200: + program_execution_context.append(f"Failed to generate image with {improved_image_prompt}") + async for result in send_event(ChatEvent.STATUS, f"Failed to generate image"): + yield result + else: + generated_images.append(generated_image) + + generated_asset_results["images"] = { + "query": improved_image_prompt, + } + + async for result in send_event( + ChatEvent.GENERATED_ASSETS, + { + "images": [generated_image], + }, ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - better_diagram_description_prompt, mermaidjs_diagram_description = result - if better_diagram_description_prompt and mermaidjs_diagram_description: - inferred_queries.append(better_diagram_description_prompt) - diagram_description = mermaidjs_diagram_description + yield result - generated_mermaidjs_diagram = diagram_description - - generated_asset_results["diagrams"] = { - "query": better_diagram_description_prompt, - } - - async for result in send_event( - ChatEvent.GENERATED_ASSETS, - { - "mermaidjsDiagram": mermaidjs_diagram_description, - }, - ): - yield result - else: - error_message = "Failed to generate diagram. Please try again later." - program_execution_context.append( - prompts.failed_diagram_generation.format( - attempted_diagram=better_diagram_description_prompt - ) - ) - - async for result in send_event(ChatEvent.STATUS, error_message): - yield result - - # Check if the user has disconnected - if cancellation_event.is_set(): - logger.debug(f"Stopping LLM response to user {user} on {common.client} client.") - # Cancel the disconnect monitor task if it is still running - await cancel_disconnect_monitor() - return - - ## Generate Text Output - async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"): + if ConversationCommand.Diagram in conversation_commands: + async for result in send_event(ChatEvent.STATUS, f"Creating diagram"): yield result - llm_response, chat_metadata = await agenerate_chat_response( - defiltered_query, - chat_history, - conversation, - compiled_references, - online_results, - code_results, - operator_results, - research_results, - user, - location, - user_name, - uploaded_images, - attached_file_context, - generated_files, - program_execution_context, - generated_asset_results, - is_subscribed, - tracer, - ) + inferred_queries = [] + diagram_description = "" - full_response = "" - async for item in llm_response: - # Should not happen with async generator. Skip. - if item is None or not isinstance(item, ResponseWithThought): - logger.warning(f"Unexpected item type in LLM response: {type(item)}. Skipping.") - continue - if cancellation_event.is_set(): - break - message = item.text - full_response += message if message else "" - if item.thought: - async for result in send_event(ChatEvent.THOUGHT, item.thought): - yield result - continue + async for result in generate_mermaidjs_diagram( + q=defiltered_query, + chat_history=chat_history, + location_data=location, + note_references=compiled_references, + online_results=online_results, + query_images=uploaded_images, + user=user, + agent=agent, + send_status_func=partial(send_event, ChatEvent.STATUS), + query_files=attached_file_context, + tracer=tracer, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + better_diagram_description_prompt, mermaidjs_diagram_description = result + if better_diagram_description_prompt and mermaidjs_diagram_description: + inferred_queries.append(better_diagram_description_prompt) + diagram_description = mermaidjs_diagram_description - # Start sending response - async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""): - yield result + generated_mermaidjs_diagram = diagram_description - try: - async for result in send_event(ChatEvent.MESSAGE, message): - yield result - except Exception as e: - if not cancellation_event.is_set(): - logger.warning(f"Error during streaming. Stopping send: {e}") - break + generated_asset_results["diagrams"] = { + "query": better_diagram_description_prompt, + } - # Save conversation once finish streaming - asyncio.create_task( - save_to_conversation_log( - q, - chat_response=full_response, - user=user, - chat_history=chat_history, - compiled_references=compiled_references, - online_results=online_results, - code_results=code_results, - operator_results=operator_results, - research_results=research_results, - inferred_queries=inferred_queries, - client_application=request.user.client_app, - conversation_id=str(conversation.id), - query_images=uploaded_images, - train_of_thought=train_of_thought, - raw_query_files=raw_query_files, - generated_images=generated_images, - raw_generated_files=generated_files, - generated_mermaidjs_diagram=generated_mermaidjs_diagram, - tracer=tracer, - ) - ) + async for result in send_event( + ChatEvent.GENERATED_ASSETS, + { + "mermaidjsDiagram": mermaidjs_diagram_description, + }, + ): + yield result + else: + error_message = "Failed to generate diagram. Please try again later." + program_execution_context.append( + prompts.failed_diagram_generation.format(attempted_diagram=better_diagram_description_prompt) + ) - # Signal end of LLM response after the loop finishes - if not cancellation_event.is_set(): - async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""): - yield result - # Send Usage Metadata once llm interactions are complete - if tracer.get("usage"): - async for event in send_event(ChatEvent.USAGE, tracer.get("usage")): - yield event - async for result in send_event(ChatEvent.END_RESPONSE, ""): - yield result - logger.debug("Finished streaming response") + async for result in send_event(ChatEvent.STATUS, error_message): + yield result + # Check if the user has disconnected + if cancellation_event.is_set(): + logger.debug(f"Stopping LLM response to user {user} on {common.client} client.") # Cancel the disconnect monitor task if it is still running await cancel_disconnect_monitor() + return - ## Stream Text Response - if stream: - return StreamingResponse(event_generator(q, images=raw_images), media_type="text/plain") - ## Non-Streaming Text Response + ## Generate Text Output + async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"): + yield result + + llm_response, chat_metadata = await agenerate_chat_response( + defiltered_query, + chat_history, + conversation, + compiled_references, + online_results, + code_results, + operator_results, + research_results, + user, + location, + user_name, + uploaded_images, + attached_file_context, + generated_files, + program_execution_context, + generated_asset_results, + is_subscribed, + tracer, + ) + + full_response = "" + async for item in llm_response: + # Should not happen with async generator. Skip. + if item is None or not isinstance(item, ResponseWithThought): + logger.warning(f"Unexpected item type in LLM response: {type(item)}. Skipping.") + continue + if cancellation_event.is_set(): + break + message = item.text + full_response += message if message else "" + if item.thought: + async for result in send_event(ChatEvent.THOUGHT, item.thought): + yield result + continue + + # Start sending response + async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""): + yield result + + try: + async for result in send_event(ChatEvent.MESSAGE, message): + yield result + except Exception as e: + if not cancellation_event.is_set(): + logger.warning(f"Error during streaming. Stopping send: {e}") + break + + # Save conversation once finish streaming + asyncio.create_task( + save_to_conversation_log( + q, + chat_response=full_response, + user=user, + chat_history=chat_history, + compiled_references=compiled_references, + online_results=online_results, + code_results=code_results, + operator_results=operator_results, + research_results=research_results, + inferred_queries=inferred_queries, + client_application=user_scope.client_app, + conversation_id=str(conversation.id), + query_images=uploaded_images, + train_of_thought=train_of_thought, + raw_query_files=raw_query_files, + generated_images=generated_images, + raw_generated_files=generated_files, + generated_mermaidjs_diagram=generated_mermaidjs_diagram, + tracer=tracer, + ) + ) + + # Signal end of LLM response after the loop finishes + if not cancellation_event.is_set(): + async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""): + yield result + # Send Usage Metadata once llm interactions are complete + if tracer.get("usage"): + async for event in send_event(ChatEvent.USAGE, tracer.get("usage")): + yield event + async for result in send_event(ChatEvent.END_RESPONSE, ""): + yield result + logger.debug("Finished streaming response") + + # Cancel the disconnect monitor task if it is still running + await cancel_disconnect_monitor() + + +@api_chat.websocket("/ws") +@requires(["authenticated"]) +async def chat_ws( + websocket: WebSocket, + common: CommonQueryParams, +): + await websocket.accept() + + # Initialize rate limiters + rate_limiter_per_minute = ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute") + rate_limiter_per_day = ApiUserRateLimiter( + requests=100, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day" + ) + image_rate_limiter = ApiImageRateLimiter(max_images=10, max_combined_size_mb=20) + + current_task = None + + try: + while True: + data = await websocket.receive_json() + + # Handle regular chat messages + # Handle regular chat messages - ensure data has required fields + if "q" not in data: + await websocket.send_text(json.dumps({"error": "Missing required field 'q' in chat message"})) + continue + + body = ChatRequestBody(**data) + + # Apply rate limiting manually + try: + rate_limiter_per_minute.check_websocket(websocket) + rate_limiter_per_day.check_websocket(websocket) + image_rate_limiter.check_websocket(websocket, body) + except HTTPException as e: + await websocket.send_text(json.dumps({"error": e.detail})) + continue + + # Cancel any ongoing task before starting a new one + if current_task and not current_task.done(): + current_task.cancel() + try: + await current_task + except asyncio.CancelledError: + pass + + # Create a new task for processing the chat request + current_task = asyncio.create_task(process_chat_request(websocket, body, common)) + + except WebSocketDisconnect: + logger.info(f"WebSocket disconnected for user {websocket.scope['user'].object.id}") + if current_task and not current_task.done(): + current_task.cancel() + except Exception as e: + logger.error(f"Error in websocket chat: {e}", exc_info=True) + if current_task and not current_task.done(): + current_task.cancel() + await websocket.close(code=1011, reason="Internal Server Error") + + +async def process_chat_request( + websocket: WebSocket, + body: ChatRequestBody, + common: CommonQueryParams, +): + """Process a single chat request with interrupt support""" + try: + # Since we are using websockets, we can ignore the stream parameter and always stream + response_iterator = event_generator( + body, + websocket.scope["user"], + common, + websocket.headers, + websocket, + ) + async for event in response_iterator: + await websocket.send_text(event) + except asyncio.CancelledError: + logger.debug(f"Chat request cancelled for user {websocket.scope['user'].object.id}") + raise + except Exception as e: + logger.error(f"Error processing chat request: {e}", exc_info=True) + await websocket.send_text(json.dumps({"error": "Internal server error"})) + raise + + +@api_chat.post("") +@requires(["authenticated"]) +async def chat( + request: Request, + common: CommonQueryParams, + body: ChatRequestBody, + rate_limiter_per_minute=Depends( + ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute") + ), + rate_limiter_per_day=Depends( + ApiUserRateLimiter(requests=100, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day") + ), + image_rate_limiter=Depends(ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)), +): + response_iterator = event_generator( + body, + request.user, + common, + request.headers, + request, + ) + + # Stream Text Response + if body.stream: + return StreamingResponse(response_iterator, media_type="text/plain") + # Non-Streaming Text Response else: - response_iterator = event_generator(q, images=raw_images) response_data = await read_chat_stream(response_iterator) return Response(content=json.dumps(response_data), media_type="application/json", status_code=200) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 48fa4252..9bfdb155 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -33,7 +33,7 @@ from apscheduler.job import Job from apscheduler.triggers.cron import CronTrigger from asgiref.sync import sync_to_async from django.utils import timezone as django_timezone -from fastapi import Depends, Header, HTTPException, Request, UploadFile +from fastapi import Depends, Header, HTTPException, Request, UploadFile, WebSocket from pydantic import BaseModel, EmailStr, Field from starlette.authentication import has_required_scope from starlette.requests import URL @@ -1936,6 +1936,53 @@ class ApiUserRateLimiter: # Add the current request to the cache UserRequests.objects.create(user=user, slug=self.slug) + def check_websocket(self, websocket: WebSocket): + """WebSocket-specific rate limiting method""" + # Rate limiting disabled if billing is disabled + if state.billing_enabled is False: + return + + # Rate limiting is disabled if user unauthenticated. + if not websocket.scope.get("user") or not websocket.scope["user"].is_authenticated: + return + + user: KhojUser = websocket.scope["user"].object + subscribed = has_required_scope(websocket, ["premium"]) + + # Remove requests outside of the time window + cutoff = django_timezone.now() - timedelta(seconds=self.window) + count_requests = UserRequests.objects.filter(user=user, created_at__gte=cutoff, slug=self.slug).count() + + # Check if the user has exceeded the rate limit + if subscribed and count_requests >= self.subscribed_requests: + logger.info( + f"Rate limit: {count_requests}/{self.subscribed_requests} requests not allowed in {self.window} seconds for subscribed user: {user}." + ) + raise HTTPException( + status_code=429, + detail="I'm glad you're enjoying interacting with me! You've unfortunately exceeded your usage limit for today. But let's chat more tomorrow?", + ) + if not subscribed and count_requests >= self.requests: + if self.requests >= self.subscribed_requests: + logger.info( + f"Rate limit: {count_requests}/{self.subscribed_requests} requests not allowed in {self.window} seconds for user: {user}." + ) + raise HTTPException( + status_code=429, + detail="I'm glad you're enjoying interacting with me! You've unfortunately exceeded your usage limit for today. But let's chat more tomorrow?", + ) + + logger.info( + f"Rate limit: {count_requests}/{self.requests} requests not allowed in {self.window} seconds for user: {user}." + ) + raise HTTPException( + status_code=429, + detail="I'm glad you're enjoying interacting with me! You've unfortunately exceeded your usage limit for today. You can subscribe to increase your usage limit via [your settings](https://app.khoj.dev/settings) or we can continue our conversation tomorrow?", + ) + + # Add the current request to the cache + UserRequests.objects.create(user=user, slug=self.slug) + class ApiImageRateLimiter: def __init__(self, max_images: int = 10, max_combined_size_mb: float = 10): @@ -1983,6 +2030,47 @@ class ApiImageRateLimiter: detail=f"Those images are way too large for me! I can handle up to {self.max_combined_size_mb}MB of images per message.", ) + def check_websocket(self, websocket: WebSocket, body: ChatRequestBody): + """WebSocket-specific image rate limiting method""" + if state.billing_enabled is False: + return + + # Rate limiting is disabled if user unauthenticated. + if not websocket.scope.get("user") or not websocket.scope["user"].is_authenticated: + return + + if not body.images: + return + + # Check number of images + if len(body.images) > self.max_images: + logger.info(f"Rate limit: {len(body.images)}/{self.max_images} images not allowed per message.") + raise HTTPException( + status_code=429, + detail=f"Those are way too many images for me! I can handle up to {self.max_images} images per message.", + ) + + # Check total size of images + total_size_mb = 0.0 + for image in body.images: + # Unquote the image in case it's URL encoded + image = unquote(image) + # Assuming the image is a base64 encoded string + # Remove the data:image/jpeg;base64, part if present + if "," in image: + image = image.split(",", 1)[1] + + # Decode base64 to get the actual size + image_bytes = base64.b64decode(image) + total_size_mb += len(image_bytes) / (1024 * 1024) # Convert bytes to MB + + if total_size_mb > self.max_combined_size_mb: + logger.info(f"Data limit: {total_size_mb}MB/{self.max_combined_size_mb}MB size not allowed per message.") + raise HTTPException( + status_code=429, + detail=f"Those images are way too large for me! I can handle up to {self.max_combined_size_mb}MB of images per message.", + ) + class ConversationCommandRateLimiter: def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str): @@ -1991,7 +2079,7 @@ class ConversationCommandRateLimiter: self.subscribed_rate_limit = subscribed_rate_limit self.restricted_commands = [ConversationCommand.Research] - async def update_and_check_if_valid(self, request: Request, conversation_command: ConversationCommand): + async def update_and_check_if_valid(self, request: Request | WebSocket, conversation_command: ConversationCommand): if state.billing_enabled is False: return