Save and restore research from partial state

This commit is contained in:
Debanjum
2025-05-20 16:00:40 -07:00
parent a83c36fa05
commit 3cd6e1a9a6
4 changed files with 97 additions and 42 deletions

View File

@@ -110,9 +110,12 @@ class InformationCollectionIteration:
def construct_iteration_history( def construct_iteration_history(
query: str, previous_iterations: List[InformationCollectionIteration], previous_iteration_prompt: str previous_iterations: List[InformationCollectionIteration],
previous_iteration_prompt: str,
query: str = None,
) -> list[dict]: ) -> list[dict]:
previous_iterations_history = [] iteration_history: list[dict] = []
previous_iteration_messages: list[dict] = []
for idx, iteration in enumerate(previous_iterations): for idx, iteration in enumerate(previous_iterations):
iteration_data = previous_iteration_prompt.format( iteration_data = previous_iteration_prompt.format(
tool=iteration.tool, tool=iteration.tool,
@@ -121,23 +124,19 @@ def construct_iteration_history(
index=idx + 1, index=idx + 1,
) )
previous_iterations_history.append({"type": "text", "text": iteration_data}) previous_iteration_messages.append({"type": "text", "text": iteration_data})
return ( if previous_iteration_messages:
[ if query:
{ iteration_history.append({"by": "you", "message": query})
"by": "you", iteration_history.append(
"message": query,
},
{ {
"by": "khoj", "by": "khoj",
"intent": {"type": "remember", "query": query}, "intent": {"type": "remember", "query": query},
"message": previous_iterations_history, "message": previous_iteration_messages,
}, }
] )
if previous_iterations_history return iteration_history
else []
)
def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str: def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
@@ -285,6 +284,7 @@ async def save_to_conversation_log(
generated_images: List[str] = [], generated_images: List[str] = [],
raw_generated_files: List[FileAttachment] = [], raw_generated_files: List[FileAttachment] = [],
generated_mermaidjs_diagram: str = None, generated_mermaidjs_diagram: str = None,
research_results: Optional[List[InformationCollectionIteration]] = None,
train_of_thought: List[Any] = [], train_of_thought: List[Any] = [],
tracer: Dict[str, Any] = {}, tracer: Dict[str, Any] = {},
): ):
@@ -302,6 +302,7 @@ async def save_to_conversation_log(
"onlineContext": online_results, "onlineContext": online_results,
"codeContext": code_results, "codeContext": code_results,
"operatorContext": operator_results, "operatorContext": operator_results,
"researchContext": [vars(r) for r in research_results] if research_results and not chat_response else None,
"automationId": automation_id, "automationId": automation_id,
"trainOfThought": train_of_thought, "trainOfThought": train_of_thought,
"turnId": turn_id, "turnId": turn_id,

View File

@@ -687,6 +687,7 @@ async def chat(
start_time = time.perf_counter() start_time = time.perf_counter()
ttft = None ttft = None
chat_metadata: dict = {} chat_metadata: dict = {}
conversation = None
user: KhojUser = request.user.object user: KhojUser = request.user.object
is_subscribed = has_required_scope(request, ["premium"]) is_subscribed = has_required_scope(request, ["premium"])
q = unquote(q) q = unquote(q)
@@ -720,6 +721,20 @@ async def chat(
for file in raw_query_files: for file in raw_query_files:
query_files[file.name] = file.content query_files[file.name] = file.content
research_results: List[InformationCollectionIteration] = []
online_results: Dict = dict()
code_results: Dict = dict()
operator_results: Dict[str, str] = {}
compiled_references: List[Any] = []
inferred_queries: List[Any] = []
attached_file_context = gather_raw_query_files(query_files)
generated_images: List[str] = []
generated_files: List[FileAttachment] = []
generated_mermaidjs_diagram: str = None
generated_asset_results: Dict = dict()
program_execution_context: List[str] = []
# Create a task to monitor for disconnections # Create a task to monitor for disconnections
disconnect_monitor_task = None disconnect_monitor_task = None
@@ -727,8 +742,34 @@ async def chat(
try: try:
msg = await request.receive() msg = await request.receive()
if msg["type"] == "http.disconnect": if msg["type"] == "http.disconnect":
logger.debug(f"User {user} disconnected from {common.client} client.") logger.debug(f"Request cancelled. User {user} disconnected from {common.client} client.")
cancellation_event.set() cancellation_event.set()
# ensure partial chat state saved on interrupt
# shield the save against task cancellation
if conversation:
await asyncio.shield(
save_to_conversation_log(
q,
chat_response="",
user=user,
meta_log=meta_log,
compiled_references=compiled_references,
online_results=online_results,
code_results=code_results,
operator_results=operator_results,
research_results=research_results,
inferred_queries=inferred_queries,
client_application=request.user.client_app,
conversation_id=conversation_id,
query_images=uploaded_images,
train_of_thought=train_of_thought,
raw_query_files=raw_query_files,
generated_images=generated_images,
raw_generated_files=generated_asset_results,
generated_mermaidjs_diagram=generated_mermaidjs_diagram,
tracer=tracer,
)
)
except Exception as e: except Exception as e:
logger.error(f"Error in disconnect monitor: {e}") logger.error(f"Error in disconnect monitor: {e}")
@@ -746,7 +787,6 @@ async def chat(
nonlocal ttft, train_of_thought nonlocal ttft, train_of_thought
event_delimiter = "␃🔚␗" event_delimiter = "␃🔚␗"
if cancellation_event.is_set(): if cancellation_event.is_set():
logger.debug(f"User {user} disconnected from {common.client} client. Setting cancellation event.")
return return
try: try:
if event_type == ChatEvent.END_LLM_RESPONSE: if event_type == ChatEvent.END_LLM_RESPONSE:
@@ -770,9 +810,6 @@ async def chat(
yield data yield data
elif event_type == ChatEvent.REFERENCES or ChatEvent.METADATA or stream: elif event_type == ChatEvent.REFERENCES or ChatEvent.METADATA or stream:
yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False) yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False)
except asyncio.CancelledError as e:
if cancellation_event.is_set():
logger.debug(f"Request cancelled. User {user} disconnected from {common.client} client: {e}.")
except Exception as e: except Exception as e:
if not cancellation_event.is_set(): if not cancellation_event.is_set():
logger.error( logger.error(
@@ -883,21 +920,25 @@ async def chat(
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
meta_log = conversation.conversation_log meta_log = conversation.conversation_log
research_results: List[InformationCollectionIteration] = [] # If interrupted message in DB
online_results: Dict = dict() if (
code_results: Dict = dict() conversation
operator_results: Dict[str, str] = {} and conversation.messages
generated_asset_results: Dict = dict() and conversation.messages[-1].by == "khoj"
## Extract Document References and not conversation.messages[-1].message
compiled_references: List[Any] = [] ):
inferred_queries: List[Any] = [] # Populate context from interrupted message
file_filters = conversation.file_filters if conversation and conversation.file_filters else [] last_message = conversation.messages[-1]
attached_file_context = gather_raw_query_files(query_files) online_results = {key: val.model_dump() for key, val in last_message.onlineContext.items() or []}
code_results = {key: val.model_dump() for key, val in last_message.codeContext.items() or []}
generated_images: List[str] = [] operator_results = last_message.operatorContext or {}
generated_files: List[FileAttachment] = [] compiled_references = [ref.model_dump() for ref in last_message.context or []]
generated_mermaidjs_diagram: str = None research_results = [
program_execution_context: List[str] = [] InformationCollectionIteration(**iter_dict) for iter_dict in last_message.researchContext or []
]
# Drop the interrupted message from conversation history
meta_log["chat"].pop()
logger.info(f"Loaded interrupted partial context from conversation {conversation_id}.")
if conversation_commands == [ConversationCommand.Default]: if conversation_commands == [ConversationCommand.Default]:
try: try:
@@ -936,6 +977,7 @@ async def chat(
return return
defiltered_query = defilter_query(q) defiltered_query = defilter_query(q)
file_filters = conversation.file_filters if conversation and conversation.file_filters else []
if conversation_commands == [ConversationCommand.Research]: if conversation_commands == [ConversationCommand.Research]:
async for research_result in execute_information_collection( async for research_result in execute_information_collection(
@@ -943,12 +985,13 @@ async def chat(
query=defiltered_query, query=defiltered_query,
conversation_id=conversation_id, conversation_id=conversation_id,
conversation_history=meta_log, conversation_history=meta_log,
previous_iterations=research_results,
query_images=uploaded_images, query_images=uploaded_images,
agent=agent, agent=agent,
send_status_func=partial(send_event, ChatEvent.STATUS), send_status_func=partial(send_event, ChatEvent.STATUS),
user_name=user_name, user_name=user_name,
location=location, location=location,
file_filters=conversation.file_filters if conversation else [], file_filters=file_filters,
query_files=attached_file_context, query_files=attached_file_context,
tracer=tracer, tracer=tracer,
cancellation_event=cancellation_event, cancellation_event=cancellation_event,
@@ -973,7 +1016,6 @@ async def chat(
logger.debug(f'Researched Results: {"".join(r.summarizedResult for r in research_results)}') logger.debug(f'Researched Results: {"".join(r.summarizedResult for r in research_results)}')
used_slash_summarize = conversation_commands == [ConversationCommand.Summarize] used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
file_filters = conversation.file_filters if conversation else []
# Skip trying to summarize if # Skip trying to summarize if
if ( if (
# summarization intent was inferred # summarization intent was inferred
@@ -1362,7 +1404,7 @@ async def chat(
# Check if the user has disconnected # Check if the user has disconnected
if cancellation_event.is_set(): if cancellation_event.is_set():
logger.debug(f"User {user} disconnected from {common.client} client. Stopping LLM response.") logger.debug(f"Stopping LLM response to user {user} on {common.client} client.")
# Cancel the disconnect monitor task if it is still running # Cancel the disconnect monitor task if it is still running
await cancel_disconnect_monitor() await cancel_disconnect_monitor()
return return

View File

@@ -1392,6 +1392,7 @@ async def agenerate_chat_response(
online_results=online_results, online_results=online_results,
code_results=code_results, code_results=code_results,
operator_results=operator_results, operator_results=operator_results,
research_results=research_results,
inferred_queries=inferred_queries, inferred_queries=inferred_queries,
client_application=client_application, client_application=client_application,
conversation_id=str(conversation.id), conversation_id=str(conversation.id),

View File

@@ -1,6 +1,7 @@
import asyncio import asyncio
import logging import logging
import os import os
from copy import deepcopy
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import Callable, Dict, List, Optional, Type from typing import Callable, Dict, List, Optional, Type
@@ -141,7 +142,7 @@ async def apick_next_tool(
query = f"[placeholder for user attached images]\n{query}" query = f"[placeholder for user attached images]\n{query}"
# Construct chat history with user and iteration history with researcher agent for context # Construct chat history with user and iteration history with researcher agent for context
previous_iterations_history = construct_iteration_history(query, previous_iterations, prompts.previous_iteration) previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration, query)
iteration_chat_log = {"chat": conversation_history.get("chat", []) + previous_iterations_history} iteration_chat_log = {"chat": conversation_history.get("chat", []) + previous_iterations_history}
# Plan function execution for the next tool # Plan function execution for the next tool
@@ -212,6 +213,7 @@ async def execute_information_collection(
query: str, query: str,
conversation_id: str, conversation_id: str,
conversation_history: dict, conversation_history: dict,
previous_iterations: List[InformationCollectionIteration],
query_images: List[str], query_images: List[str],
agent: Agent = None, agent: Agent = None,
send_status_func: Optional[Callable] = None, send_status_func: Optional[Callable] = None,
@@ -227,11 +229,20 @@ async def execute_information_collection(
max_webpages_to_read = 1 max_webpages_to_read = 1
current_iteration = 0 current_iteration = 0
MAX_ITERATIONS = int(os.getenv("KHOJ_RESEARCH_ITERATIONS", 5)) MAX_ITERATIONS = int(os.getenv("KHOJ_RESEARCH_ITERATIONS", 5))
previous_iterations: List[InformationCollectionIteration] = []
# Incorporate previous partial research into current research chat history.
# Work on a deep copy so the entries appended below are not added to the
# conversation_history object passed in by the caller.
research_conversation_history = deepcopy(conversation_history)
# The assignment expression must be parenthesized: `:=` binds looser than `>`,
# so without the parentheses current_iteration would receive the boolean result
# of the comparison instead of the number of iterations already completed.
if (current_iteration := len(previous_iterations)) > 0:
    logger.info(f"Continuing research with the previous {len(previous_iterations)} iteration results.")
    previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)
    research_conversation_history["chat"] = (
        research_conversation_history.get("chat", []) + previous_iterations_history
    )
while current_iteration < MAX_ITERATIONS: while current_iteration < MAX_ITERATIONS:
# Check for cancellation at the start of each iteration # Check for cancellation at the start of each iteration
if cancellation_event and cancellation_event.is_set(): if cancellation_event and cancellation_event.is_set():
logger.debug(f"User {user} disconnected client. Research cancelled.") logger.debug(f"Research cancelled. User {user} disconnected client.")
break break
online_results: Dict = dict() online_results: Dict = dict()
@@ -243,7 +254,7 @@ async def execute_information_collection(
async for result in apick_next_tool( async for result in apick_next_tool(
query, query,
conversation_history, research_conversation_history,
user, user,
location, location,
user_name, user_name,