Save and restore research from partial state

This commit is contained in:
Debanjum
2025-05-20 16:00:40 -07:00
parent a83c36fa05
commit 3cd6e1a9a6
4 changed files with 97 additions and 42 deletions

View File

@@ -110,9 +110,12 @@ class InformationCollectionIteration:
def construct_iteration_history(
query: str, previous_iterations: List[InformationCollectionIteration], previous_iteration_prompt: str
previous_iterations: List[InformationCollectionIteration],
previous_iteration_prompt: str,
query: str = None,
) -> list[dict]:
previous_iterations_history = []
iteration_history: list[dict] = []
previous_iteration_messages: list[dict] = []
for idx, iteration in enumerate(previous_iterations):
iteration_data = previous_iteration_prompt.format(
tool=iteration.tool,
@@ -121,23 +124,19 @@ def construct_iteration_history(
index=idx + 1,
)
previous_iterations_history.append({"type": "text", "text": iteration_data})
previous_iteration_messages.append({"type": "text", "text": iteration_data})
return (
[
{
"by": "you",
"message": query,
},
if previous_iteration_messages:
if query:
iteration_history.append({"by": "you", "message": query})
iteration_history.append(
{
"by": "khoj",
"intent": {"type": "remember", "query": query},
"message": previous_iterations_history,
},
]
if previous_iterations_history
else []
)
"message": previous_iteration_messages,
}
)
return iteration_history
def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
@@ -285,6 +284,7 @@ async def save_to_conversation_log(
generated_images: List[str] = [],
raw_generated_files: List[FileAttachment] = [],
generated_mermaidjs_diagram: str = None,
research_results: Optional[List[InformationCollectionIteration]] = None,
train_of_thought: List[Any] = [],
tracer: Dict[str, Any] = {},
):
@@ -302,6 +302,7 @@ async def save_to_conversation_log(
"onlineContext": online_results,
"codeContext": code_results,
"operatorContext": operator_results,
"researchContext": [vars(r) for r in research_results] if research_results and not chat_response else None,
"automationId": automation_id,
"trainOfThought": train_of_thought,
"turnId": turn_id,

View File

@@ -687,6 +687,7 @@ async def chat(
start_time = time.perf_counter()
ttft = None
chat_metadata: dict = {}
conversation = None
user: KhojUser = request.user.object
is_subscribed = has_required_scope(request, ["premium"])
q = unquote(q)
@@ -720,6 +721,20 @@ async def chat(
for file in raw_query_files:
query_files[file.name] = file.content
research_results: List[InformationCollectionIteration] = []
online_results: Dict = dict()
code_results: Dict = dict()
operator_results: Dict[str, str] = {}
compiled_references: List[Any] = []
inferred_queries: List[Any] = []
attached_file_context = gather_raw_query_files(query_files)
generated_images: List[str] = []
generated_files: List[FileAttachment] = []
generated_mermaidjs_diagram: str = None
generated_asset_results: Dict = dict()
program_execution_context: List[str] = []
# Create a task to monitor for disconnections
disconnect_monitor_task = None
@@ -727,8 +742,34 @@ async def chat(
try:
msg = await request.receive()
if msg["type"] == "http.disconnect":
logger.debug(f"User {user} disconnected from {common.client} client.")
logger.debug(f"Request cancelled. User {user} disconnected from {common.client} client.")
cancellation_event.set()
# ensure partial chat state saved on interrupt
# shield the save against task cancellation
if conversation:
await asyncio.shield(
save_to_conversation_log(
q,
chat_response="",
user=user,
meta_log=meta_log,
compiled_references=compiled_references,
online_results=online_results,
code_results=code_results,
operator_results=operator_results,
research_results=research_results,
inferred_queries=inferred_queries,
client_application=request.user.client_app,
conversation_id=conversation_id,
query_images=uploaded_images,
train_of_thought=train_of_thought,
raw_query_files=raw_query_files,
generated_images=generated_images,
raw_generated_files=generated_files,
generated_mermaidjs_diagram=generated_mermaidjs_diagram,
tracer=tracer,
)
)
except Exception as e:
logger.error(f"Error in disconnect monitor: {e}")
@@ -746,7 +787,6 @@ async def chat(
nonlocal ttft, train_of_thought
event_delimiter = "␃🔚␗"
if cancellation_event.is_set():
logger.debug(f"User {user} disconnected from {common.client} client. Setting cancellation event.")
return
try:
if event_type == ChatEvent.END_LLM_RESPONSE:
@@ -770,9 +810,6 @@ async def chat(
yield data
elif event_type == ChatEvent.REFERENCES or ChatEvent.METADATA or stream:
yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False)
except asyncio.CancelledError as e:
if cancellation_event.is_set():
logger.debug(f"Request cancelled. User {user} disconnected from {common.client} client: {e}.")
except Exception as e:
if not cancellation_event.is_set():
logger.error(
@@ -883,21 +920,25 @@ async def chat(
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
meta_log = conversation.conversation_log
research_results: List[InformationCollectionIteration] = []
online_results: Dict = dict()
code_results: Dict = dict()
operator_results: Dict[str, str] = {}
generated_asset_results: Dict = dict()
## Extract Document References
compiled_references: List[Any] = []
inferred_queries: List[Any] = []
file_filters = conversation.file_filters if conversation and conversation.file_filters else []
attached_file_context = gather_raw_query_files(query_files)
generated_images: List[str] = []
generated_files: List[FileAttachment] = []
generated_mermaidjs_diagram: str = None
program_execution_context: List[str] = []
# If interrupted message in DB
if (
conversation
and conversation.messages
and conversation.messages[-1].by == "khoj"
and not conversation.messages[-1].message
):
# Populate context from interrupted message
last_message = conversation.messages[-1]
online_results = {key: val.model_dump() for key, val in (last_message.onlineContext or {}).items()}
code_results = {key: val.model_dump() for key, val in (last_message.codeContext or {}).items()}
operator_results = last_message.operatorContext or {}
compiled_references = [ref.model_dump() for ref in last_message.context or []]
research_results = [
InformationCollectionIteration(**iter_dict) for iter_dict in last_message.researchContext or []
]
# Drop the interrupted message from conversation history
meta_log["chat"].pop()
logger.info(f"Loaded interrupted partial context from conversation {conversation_id}.")
if conversation_commands == [ConversationCommand.Default]:
try:
@@ -936,6 +977,7 @@ async def chat(
return
defiltered_query = defilter_query(q)
file_filters = conversation.file_filters if conversation and conversation.file_filters else []
if conversation_commands == [ConversationCommand.Research]:
async for research_result in execute_information_collection(
@@ -943,12 +985,13 @@ async def chat(
query=defiltered_query,
conversation_id=conversation_id,
conversation_history=meta_log,
previous_iterations=research_results,
query_images=uploaded_images,
agent=agent,
send_status_func=partial(send_event, ChatEvent.STATUS),
user_name=user_name,
location=location,
file_filters=conversation.file_filters if conversation else [],
file_filters=file_filters,
query_files=attached_file_context,
tracer=tracer,
cancellation_event=cancellation_event,
@@ -973,7 +1016,6 @@ async def chat(
logger.debug(f'Researched Results: {"".join(r.summarizedResult for r in research_results)}')
used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
file_filters = conversation.file_filters if conversation else []
# Skip trying to summarize if
if (
# summarization intent was inferred
@@ -1362,7 +1404,7 @@ async def chat(
# Check if the user has disconnected
if cancellation_event.is_set():
logger.debug(f"User {user} disconnected from {common.client} client. Stopping LLM response.")
logger.debug(f"Stopping LLM response to user {user} on {common.client} client.")
# Cancel the disconnect monitor task if it is still running
await cancel_disconnect_monitor()
return

View File

@@ -1392,6 +1392,7 @@ async def agenerate_chat_response(
online_results=online_results,
code_results=code_results,
operator_results=operator_results,
research_results=research_results,
inferred_queries=inferred_queries,
client_application=client_application,
conversation_id=str(conversation.id),

View File

@@ -1,6 +1,7 @@
import asyncio
import logging
import os
from copy import deepcopy
from datetime import datetime
from enum import Enum
from typing import Callable, Dict, List, Optional, Type
@@ -141,7 +142,7 @@ async def apick_next_tool(
query = f"[placeholder for user attached images]\n{query}"
# Construct chat history with user and iteration history with researcher agent for context
previous_iterations_history = construct_iteration_history(query, previous_iterations, prompts.previous_iteration)
previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration, query)
iteration_chat_log = {"chat": conversation_history.get("chat", []) + previous_iterations_history}
# Plan function execution for the next tool
@@ -212,6 +213,7 @@ async def execute_information_collection(
query: str,
conversation_id: str,
conversation_history: dict,
previous_iterations: List[InformationCollectionIteration],
query_images: List[str],
agent: Agent = None,
send_status_func: Optional[Callable] = None,
@@ -227,11 +229,20 @@ async def execute_information_collection(
max_webpages_to_read = 1
current_iteration = 0
MAX_ITERATIONS = int(os.getenv("KHOJ_RESEARCH_ITERATIONS", 5))
previous_iterations: List[InformationCollectionIteration] = []
# Incorporate previous partial research into current research chat history
research_conversation_history = deepcopy(conversation_history)
if (current_iteration := len(previous_iterations)) > 0:
logger.info(f"Continuing research with the previous {len(previous_iterations)} iteration results.")
previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)
research_conversation_history["chat"] = (
research_conversation_history.get("chat", []) + previous_iterations_history
)
while current_iteration < MAX_ITERATIONS:
# Check for cancellation at the start of each iteration
if cancellation_event and cancellation_event.is_set():
logger.debug(f"User {user} disconnected client. Research cancelled.")
logger.debug(f"Research cancelled. User {user} disconnected client.")
break
online_results: Dict = dict()
@@ -243,7 +254,7 @@ async def execute_information_collection(
async for result in apick_next_tool(
query,
conversation_history,
research_conversation_history,
user,
location,
user_name,