diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index bf7be40d..36aa001d 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -110,9 +110,12 @@ class InformationCollectionIteration: def construct_iteration_history( - query: str, previous_iterations: List[InformationCollectionIteration], previous_iteration_prompt: str + previous_iterations: List[InformationCollectionIteration], + previous_iteration_prompt: str, + query: str = None, ) -> list[dict]: - previous_iterations_history = [] + iteration_history: list[dict] = [] + previous_iteration_messages: list[dict] = [] for idx, iteration in enumerate(previous_iterations): iteration_data = previous_iteration_prompt.format( tool=iteration.tool, @@ -121,23 +124,19 @@ def construct_iteration_history( index=idx + 1, ) - previous_iterations_history.append({"type": "text", "text": iteration_data}) + previous_iteration_messages.append({"type": "text", "text": iteration_data}) - return ( - [ - { - "by": "you", - "message": query, - }, + if previous_iteration_messages: + if query: + iteration_history.append({"by": "you", "message": query}) + iteration_history.append( { "by": "khoj", "intent": {"type": "remember", "query": query}, - "message": previous_iterations_history, - }, - ] - if previous_iterations_history - else [] - ) + "message": previous_iteration_messages, + } + ) + return iteration_history def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str: @@ -285,6 +284,7 @@ async def save_to_conversation_log( generated_images: List[str] = [], raw_generated_files: List[FileAttachment] = [], generated_mermaidjs_diagram: str = None, + research_results: Optional[List[InformationCollectionIteration]] = None, train_of_thought: List[Any] = [], tracer: Dict[str, Any] = {}, ): @@ -302,6 +302,7 @@ async def save_to_conversation_log( "onlineContext": online_results, "codeContext": code_results, 
"operatorContext": operator_results, + "researchContext": [vars(r) for r in research_results] if research_results and not chat_response else None, "automationId": automation_id, "trainOfThought": train_of_thought, "turnId": turn_id, diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 7b78063d..bf72121c 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -687,6 +687,7 @@ async def chat( start_time = time.perf_counter() ttft = None chat_metadata: dict = {} + conversation = None user: KhojUser = request.user.object is_subscribed = has_required_scope(request, ["premium"]) q = unquote(q) @@ -720,6 +721,20 @@ async def chat( for file in raw_query_files: query_files[file.name] = file.content + research_results: List[InformationCollectionIteration] = [] + online_results: Dict = dict() + code_results: Dict = dict() + operator_results: Dict[str, str] = {} + compiled_references: List[Any] = [] + inferred_queries: List[Any] = [] + attached_file_context = gather_raw_query_files(query_files) + + generated_images: List[str] = [] + generated_files: List[FileAttachment] = [] + generated_mermaidjs_diagram: str = None + generated_asset_results: Dict = dict() + program_execution_context: List[str] = [] + # Create a task to monitor for disconnections disconnect_monitor_task = None @@ -727,8 +742,34 @@ async def chat( try: msg = await request.receive() if msg["type"] == "http.disconnect": - logger.debug(f"User {user} disconnected from {common.client} client.") + logger.debug(f"Request cancelled. 
User {user} disconnected from {common.client} client.") cancellation_event.set() + # ensure partial chat state saved on interrupt + # shield the save against task cancellation + if conversation: + await asyncio.shield( + save_to_conversation_log( + q, + chat_response="", + user=user, + meta_log=meta_log, + compiled_references=compiled_references, + online_results=online_results, + code_results=code_results, + operator_results=operator_results, + research_results=research_results, + inferred_queries=inferred_queries, + client_application=request.user.client_app, + conversation_id=conversation_id, + query_images=uploaded_images, + train_of_thought=train_of_thought, + raw_query_files=raw_query_files, + generated_images=generated_images, + raw_generated_files=generated_files, + generated_mermaidjs_diagram=generated_mermaidjs_diagram, + tracer=tracer, + ) + ) except Exception as e: logger.error(f"Error in disconnect monitor: {e}") @@ -746,7 +787,6 @@ async def chat( nonlocal ttft, train_of_thought event_delimiter = "␃🔚␗" if cancellation_event.is_set(): - logger.debug(f"User {user} disconnected from {common.client} client. Setting cancellation event.") return try: if event_type == ChatEvent.END_LLM_RESPONSE: @@ -770,9 +810,6 @@ async def chat( yield data elif event_type == ChatEvent.REFERENCES or ChatEvent.METADATA or stream: yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False) - except asyncio.CancelledError as e: - if cancellation_event.is_set(): - logger.debug(f"Request cancelled. 
User {user} disconnected from {common.client} client: {e}.") except Exception as e: if not cancellation_event.is_set(): logger.error( @@ -883,21 +920,25 @@ async def chat( user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") meta_log = conversation.conversation_log - research_results: List[InformationCollectionIteration] = [] - online_results: Dict = dict() - code_results: Dict = dict() - operator_results: Dict[str, str] = {} - generated_asset_results: Dict = dict() - ## Extract Document References - compiled_references: List[Any] = [] - inferred_queries: List[Any] = [] - file_filters = conversation.file_filters if conversation and conversation.file_filters else [] - attached_file_context = gather_raw_query_files(query_files) - - generated_images: List[str] = [] - generated_files: List[FileAttachment] = [] - generated_mermaidjs_diagram: str = None - program_execution_context: List[str] = [] + # If interrupted message in DB + if ( + conversation + and conversation.messages + and conversation.messages[-1].by == "khoj" + and not conversation.messages[-1].message + ): + # Populate context from interrupted message + last_message = conversation.messages[-1] + online_results = {key: val.model_dump() for key, val in (last_message.onlineContext or {}).items()} + code_results = {key: val.model_dump() for key, val in (last_message.codeContext or {}).items()} + operator_results = last_message.operatorContext or {} + compiled_references = [ref.model_dump() for ref in last_message.context or []] + research_results = [ + InformationCollectionIteration(**iter_dict) for iter_dict in last_message.researchContext or [] + ] + # Drop the interrupted message from conversation history + meta_log["chat"].pop() + logger.info(f"Loaded interrupted partial context from conversation {conversation_id}.") if conversation_commands == [ConversationCommand.Default]: try: @@ -936,6 +977,7 @@ async def chat( return defiltered_query = defilter_query(q) + file_filters = 
conversation.file_filters if conversation and conversation.file_filters else [] if conversation_commands == [ConversationCommand.Research]: async for research_result in execute_information_collection( @@ -943,12 +985,13 @@ async def chat( query=defiltered_query, conversation_id=conversation_id, conversation_history=meta_log, + previous_iterations=research_results, query_images=uploaded_images, agent=agent, send_status_func=partial(send_event, ChatEvent.STATUS), user_name=user_name, location=location, - file_filters=conversation.file_filters if conversation else [], + file_filters=file_filters, query_files=attached_file_context, tracer=tracer, cancellation_event=cancellation_event, @@ -973,7 +1016,6 @@ async def chat( logger.debug(f'Researched Results: {"".join(r.summarizedResult for r in research_results)}') used_slash_summarize = conversation_commands == [ConversationCommand.Summarize] - file_filters = conversation.file_filters if conversation else [] # Skip trying to summarize if if ( # summarization intent was inferred @@ -1362,7 +1404,7 @@ async def chat( # Check if the user has disconnected if cancellation_event.is_set(): - logger.debug(f"User {user} disconnected from {common.client} client. 
Stopping LLM response.") + logger.debug(f"Stopping LLM response to user {user} on {common.client} client.") # Cancel the disconnect monitor task if it is still running await cancel_disconnect_monitor() return diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index a91f51a0..c1ddb82d 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1392,6 +1392,7 @@ async def agenerate_chat_response( online_results=online_results, code_results=code_results, operator_results=operator_results, + research_results=research_results, inferred_queries=inferred_queries, client_application=client_application, conversation_id=str(conversation.id), diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 2f8157b4..93efee1f 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -1,6 +1,7 @@ import asyncio import logging import os +from copy import deepcopy from datetime import datetime from enum import Enum from typing import Callable, Dict, List, Optional, Type @@ -141,7 +142,7 @@ async def apick_next_tool( query = f"[placeholder for user attached images]\n{query}" # Construct chat history with user and iteration history with researcher agent for context - previous_iterations_history = construct_iteration_history(query, previous_iterations, prompts.previous_iteration) + previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration, query) iteration_chat_log = {"chat": conversation_history.get("chat", []) + previous_iterations_history} # Plan function execution for the next tool @@ -212,6 +213,7 @@ async def execute_information_collection( query: str, conversation_id: str, conversation_history: dict, + previous_iterations: List[InformationCollectionIteration], query_images: List[str], agent: Agent = None, send_status_func: Optional[Callable] = None, @@ -227,11 +229,20 @@ async def execute_information_collection( max_webpages_to_read = 1 current_iteration 
= 0 MAX_ITERATIONS = int(os.getenv("KHOJ_RESEARCH_ITERATIONS", 5)) - previous_iterations: List[InformationCollectionIteration] = [] + + # Incorporate previous partial research into current research chat history + research_conversation_history = deepcopy(conversation_history) + if (current_iteration := len(previous_iterations)) > 0: + logger.info(f"Continuing research with the previous {len(previous_iterations)} iteration results.") + previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration) + research_conversation_history["chat"] = ( + research_conversation_history.get("chat", []) + previous_iterations_history + ) + while current_iteration < MAX_ITERATIONS: # Check for cancellation at the start of each iteration if cancellation_event and cancellation_event.is_set(): - logger.debug(f"User {user} disconnected client. Research cancelled.") + logger.debug(f"Research cancelled. User {user} disconnected client.") break online_results: Dict = dict() @@ -243,7 +254,7 @@ async def execute_information_collection( async for result in apick_next_tool( query, - conversation_history, + research_conversation_history, user, location, user_name,