mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Save and restore research from partial state
This commit is contained in:
@@ -110,9 +110,12 @@ class InformationCollectionIteration:
|
||||
|
||||
|
||||
def construct_iteration_history(
|
||||
query: str, previous_iterations: List[InformationCollectionIteration], previous_iteration_prompt: str
|
||||
previous_iterations: List[InformationCollectionIteration],
|
||||
previous_iteration_prompt: str,
|
||||
query: str = None,
|
||||
) -> list[dict]:
|
||||
previous_iterations_history = []
|
||||
iteration_history: list[dict] = []
|
||||
previous_iteration_messages: list[dict] = []
|
||||
for idx, iteration in enumerate(previous_iterations):
|
||||
iteration_data = previous_iteration_prompt.format(
|
||||
tool=iteration.tool,
|
||||
@@ -121,23 +124,19 @@ def construct_iteration_history(
|
||||
index=idx + 1,
|
||||
)
|
||||
|
||||
previous_iterations_history.append({"type": "text", "text": iteration_data})
|
||||
previous_iteration_messages.append({"type": "text", "text": iteration_data})
|
||||
|
||||
return (
|
||||
[
|
||||
{
|
||||
"by": "you",
|
||||
"message": query,
|
||||
},
|
||||
if previous_iteration_messages:
|
||||
if query:
|
||||
iteration_history.append({"by": "you", "message": query})
|
||||
iteration_history.append(
|
||||
{
|
||||
"by": "khoj",
|
||||
"intent": {"type": "remember", "query": query},
|
||||
"message": previous_iterations_history,
|
||||
},
|
||||
]
|
||||
if previous_iterations_history
|
||||
else []
|
||||
)
|
||||
"message": previous_iteration_messages,
|
||||
}
|
||||
)
|
||||
return iteration_history
|
||||
|
||||
|
||||
def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
|
||||
@@ -285,6 +284,7 @@ async def save_to_conversation_log(
|
||||
generated_images: List[str] = [],
|
||||
raw_generated_files: List[FileAttachment] = [],
|
||||
generated_mermaidjs_diagram: str = None,
|
||||
research_results: Optional[List[InformationCollectionIteration]] = None,
|
||||
train_of_thought: List[Any] = [],
|
||||
tracer: Dict[str, Any] = {},
|
||||
):
|
||||
@@ -302,6 +302,7 @@ async def save_to_conversation_log(
|
||||
"onlineContext": online_results,
|
||||
"codeContext": code_results,
|
||||
"operatorContext": operator_results,
|
||||
"researchContext": [vars(r) for r in research_results] if research_results and not chat_response else None,
|
||||
"automationId": automation_id,
|
||||
"trainOfThought": train_of_thought,
|
||||
"turnId": turn_id,
|
||||
|
||||
@@ -687,6 +687,7 @@ async def chat(
|
||||
start_time = time.perf_counter()
|
||||
ttft = None
|
||||
chat_metadata: dict = {}
|
||||
conversation = None
|
||||
user: KhojUser = request.user.object
|
||||
is_subscribed = has_required_scope(request, ["premium"])
|
||||
q = unquote(q)
|
||||
@@ -720,6 +721,20 @@ async def chat(
|
||||
for file in raw_query_files:
|
||||
query_files[file.name] = file.content
|
||||
|
||||
research_results: List[InformationCollectionIteration] = []
|
||||
online_results: Dict = dict()
|
||||
code_results: Dict = dict()
|
||||
operator_results: Dict[str, str] = {}
|
||||
compiled_references: List[Any] = []
|
||||
inferred_queries: List[Any] = []
|
||||
attached_file_context = gather_raw_query_files(query_files)
|
||||
|
||||
generated_images: List[str] = []
|
||||
generated_files: List[FileAttachment] = []
|
||||
generated_mermaidjs_diagram: str = None
|
||||
generated_asset_results: Dict = dict()
|
||||
program_execution_context: List[str] = []
|
||||
|
||||
# Create a task to monitor for disconnections
|
||||
disconnect_monitor_task = None
|
||||
|
||||
@@ -727,8 +742,34 @@ async def chat(
|
||||
try:
|
||||
msg = await request.receive()
|
||||
if msg["type"] == "http.disconnect":
|
||||
logger.debug(f"User {user} disconnected from {common.client} client.")
|
||||
logger.debug(f"Request cancelled. User {user} disconnected from {common.client} client.")
|
||||
cancellation_event.set()
|
||||
# Ensure the partial chat state is saved on interrupt
|
||||
# shield the save against task cancellation
|
||||
if conversation:
|
||||
await asyncio.shield(
|
||||
save_to_conversation_log(
|
||||
q,
|
||||
chat_response="",
|
||||
user=user,
|
||||
meta_log=meta_log,
|
||||
compiled_references=compiled_references,
|
||||
online_results=online_results,
|
||||
code_results=code_results,
|
||||
operator_results=operator_results,
|
||||
research_results=research_results,
|
||||
inferred_queries=inferred_queries,
|
||||
client_application=request.user.client_app,
|
||||
conversation_id=conversation_id,
|
||||
query_images=uploaded_images,
|
||||
train_of_thought=train_of_thought,
|
||||
raw_query_files=raw_query_files,
|
||||
generated_images=generated_images,
|
||||
raw_generated_files=generated_asset_results,
|
||||
generated_mermaidjs_diagram=generated_mermaidjs_diagram,
|
||||
tracer=tracer,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in disconnect monitor: {e}")
|
||||
|
||||
@@ -746,7 +787,6 @@ async def chat(
|
||||
nonlocal ttft, train_of_thought
|
||||
event_delimiter = "␃🔚␗"
|
||||
if cancellation_event.is_set():
|
||||
logger.debug(f"User {user} disconnected from {common.client} client. Setting cancellation event.")
|
||||
return
|
||||
try:
|
||||
if event_type == ChatEvent.END_LLM_RESPONSE:
|
||||
@@ -770,9 +810,6 @@ async def chat(
|
||||
yield data
|
||||
elif event_type == ChatEvent.REFERENCES or ChatEvent.METADATA or stream:
|
||||
yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False)
|
||||
except asyncio.CancelledError as e:
|
||||
if cancellation_event.is_set():
|
||||
logger.debug(f"Request cancelled. User {user} disconnected from {common.client} client: {e}.")
|
||||
except Exception as e:
|
||||
if not cancellation_event.is_set():
|
||||
logger.error(
|
||||
@@ -883,21 +920,25 @@ async def chat(
|
||||
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
meta_log = conversation.conversation_log
|
||||
|
||||
research_results: List[InformationCollectionIteration] = []
|
||||
online_results: Dict = dict()
|
||||
code_results: Dict = dict()
|
||||
operator_results: Dict[str, str] = {}
|
||||
generated_asset_results: Dict = dict()
|
||||
## Extract Document References
|
||||
compiled_references: List[Any] = []
|
||||
inferred_queries: List[Any] = []
|
||||
file_filters = conversation.file_filters if conversation and conversation.file_filters else []
|
||||
attached_file_context = gather_raw_query_files(query_files)
|
||||
|
||||
generated_images: List[str] = []
|
||||
generated_files: List[FileAttachment] = []
|
||||
generated_mermaidjs_diagram: str = None
|
||||
program_execution_context: List[str] = []
|
||||
# If the last message in the DB is an interrupted khoj response (saved with an empty message)
|
||||
if (
|
||||
conversation
|
||||
and conversation.messages
|
||||
and conversation.messages[-1].by == "khoj"
|
||||
and not conversation.messages[-1].message
|
||||
):
|
||||
# Populate context from interrupted message
|
||||
last_message = conversation.messages[-1]
|
||||
online_results = {key: val.model_dump() for key, val in last_message.onlineContext.items() or []}
|
||||
code_results = {key: val.model_dump() for key, val in last_message.codeContext.items() or []}
|
||||
operator_results = last_message.operatorContext or {}
|
||||
compiled_references = [ref.model_dump() for ref in last_message.context or []]
|
||||
research_results = [
|
||||
InformationCollectionIteration(**iter_dict) for iter_dict in last_message.researchContext or []
|
||||
]
|
||||
# Drop the interrupted message from conversation history
|
||||
meta_log["chat"].pop()
|
||||
logger.info(f"Loaded interrupted partial context from conversation {conversation_id}.")
|
||||
|
||||
if conversation_commands == [ConversationCommand.Default]:
|
||||
try:
|
||||
@@ -936,6 +977,7 @@ async def chat(
|
||||
return
|
||||
|
||||
defiltered_query = defilter_query(q)
|
||||
file_filters = conversation.file_filters if conversation and conversation.file_filters else []
|
||||
|
||||
if conversation_commands == [ConversationCommand.Research]:
|
||||
async for research_result in execute_information_collection(
|
||||
@@ -943,12 +985,13 @@ async def chat(
|
||||
query=defiltered_query,
|
||||
conversation_id=conversation_id,
|
||||
conversation_history=meta_log,
|
||||
previous_iterations=research_results,
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||
user_name=user_name,
|
||||
location=location,
|
||||
file_filters=conversation.file_filters if conversation else [],
|
||||
file_filters=file_filters,
|
||||
query_files=attached_file_context,
|
||||
tracer=tracer,
|
||||
cancellation_event=cancellation_event,
|
||||
@@ -973,7 +1016,6 @@ async def chat(
|
||||
logger.debug(f'Researched Results: {"".join(r.summarizedResult for r in research_results)}')
|
||||
|
||||
used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
|
||||
file_filters = conversation.file_filters if conversation else []
|
||||
# Skip trying to summarize if
|
||||
if (
|
||||
# summarization intent was inferred
|
||||
@@ -1362,7 +1404,7 @@ async def chat(
|
||||
|
||||
# Check if the user has disconnected
|
||||
if cancellation_event.is_set():
|
||||
logger.debug(f"User {user} disconnected from {common.client} client. Stopping LLM response.")
|
||||
logger.debug(f"Stopping LLM response to user {user} on {common.client} client.")
|
||||
# Cancel the disconnect monitor task if it is still running
|
||||
await cancel_disconnect_monitor()
|
||||
return
|
||||
|
||||
@@ -1392,6 +1392,7 @@ async def agenerate_chat_response(
|
||||
online_results=online_results,
|
||||
code_results=code_results,
|
||||
operator_results=operator_results,
|
||||
research_results=research_results,
|
||||
inferred_queries=inferred_queries,
|
||||
client_application=client_application,
|
||||
conversation_id=str(conversation.id),
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Callable, Dict, List, Optional, Type
|
||||
@@ -141,7 +142,7 @@ async def apick_next_tool(
|
||||
query = f"[placeholder for user attached images]\n{query}"
|
||||
|
||||
# Construct chat history with user and iteration history with researcher agent for context
|
||||
previous_iterations_history = construct_iteration_history(query, previous_iterations, prompts.previous_iteration)
|
||||
previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration, query)
|
||||
iteration_chat_log = {"chat": conversation_history.get("chat", []) + previous_iterations_history}
|
||||
|
||||
# Plan function execution for the next tool
|
||||
@@ -212,6 +213,7 @@ async def execute_information_collection(
|
||||
query: str,
|
||||
conversation_id: str,
|
||||
conversation_history: dict,
|
||||
previous_iterations: List[InformationCollectionIteration],
|
||||
query_images: List[str],
|
||||
agent: Agent = None,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
@@ -227,11 +229,20 @@ async def execute_information_collection(
|
||||
max_webpages_to_read = 1
|
||||
current_iteration = 0
|
||||
MAX_ITERATIONS = int(os.getenv("KHOJ_RESEARCH_ITERATIONS", 5))
|
||||
previous_iterations: List[InformationCollectionIteration] = []
|
||||
|
||||
# Incorporate previous partial research into current research chat history
|
||||
research_conversation_history = deepcopy(conversation_history)
|
||||
if current_iteration := len(previous_iterations) > 0:
|
||||
logger.info(f"Continuing research with the previous {len(previous_iterations)} iteration results.")
|
||||
previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)
|
||||
research_conversation_history["chat"] = (
|
||||
research_conversation_history.get("chat", []) + previous_iterations_history
|
||||
)
|
||||
|
||||
while current_iteration < MAX_ITERATIONS:
|
||||
# Check for cancellation at the start of each iteration
|
||||
if cancellation_event and cancellation_event.is_set():
|
||||
logger.debug(f"User {user} disconnected client. Research cancelled.")
|
||||
logger.debug(f"Research cancelled. User {user} disconnected client.")
|
||||
break
|
||||
|
||||
online_results: Dict = dict()
|
||||
@@ -243,7 +254,7 @@ async def execute_information_collection(
|
||||
|
||||
async for result in apick_next_tool(
|
||||
query,
|
||||
conversation_history,
|
||||
research_conversation_history,
|
||||
user,
|
||||
location,
|
||||
user_name,
|
||||
|
||||
Reference in New Issue
Block a user