Easily interrupt and redirect khoj's research direction via chat

- Khoj can now save and restore research from partial state
  This triggers an interrupt that saves the partial research. When
  a new query is sent, the previous partial research is loaded as
  context, and research continues, using the new user query to
  orient its future direction
- Support natural interrupt and send query behavior from web app
  This triggers an abort and send when a user sends a chat message
  while khoj is in the middle of some previous research.

This interrupt mechanism enables a more natural, interactive
research flow.
This commit is contained in:
Debanjum
2025-05-27 17:57:21 -07:00
12 changed files with 209 additions and 90 deletions

View File

@@ -49,6 +49,7 @@ interface ChatBodyDataProps {
isChatSideBarOpen: boolean;
setIsChatSideBarOpen: (open: boolean) => void;
isActive?: boolean;
isParentProcessing?: boolean;
}
function ChatBodyData(props: ChatBodyDataProps) {
@@ -166,7 +167,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
isLoggedIn={props.isLoggedIn}
sendMessage={(message) => setMessage(message)}
sendImage={(image) => setImages((prevImages) => [...prevImages, image])}
sendDisabled={processingMessage}
sendDisabled={props.isParentProcessing || false}
chatOptionsData={props.chatOptionsData}
conversationId={conversationId}
isMobileWidth={props.isMobileWidth}
@@ -203,6 +204,7 @@ export default function Chat() {
const [abortMessageStreamController, setAbortMessageStreamController] =
useState<AbortController | null>(null);
const [triggeredAbort, setTriggeredAbort] = useState(false);
const [shouldSendWithInterrupt, setShouldSendWithInterrupt] = useState(false);
const { locationData, locationDataError, locationDataLoading } = useIPLocationData() || {
locationData: {
@@ -239,6 +241,7 @@ export default function Chat() {
if (triggeredAbort) {
abortMessageStreamController?.abort();
handleAbortedMessage();
setShouldSendWithInterrupt(true);
setTriggeredAbort(false);
}
}, [triggeredAbort]);
@@ -335,18 +338,21 @@ export default function Chat() {
currentMessage.completed = true;
setMessages([...messages]);
setQueryToProcess("");
setProcessQuerySignal(false);
}
async function chat() {
localStorage.removeItem("message");
if (!queryToProcess || !conversationId) return;
if (!queryToProcess || !conversationId) {
setProcessQuerySignal(false);
return;
}
const chatAPI = "/api/chat?client=web";
const chatAPIBody = {
q: queryToProcess,
conversation_id: conversationId,
stream: true,
interrupt: shouldSendWithInterrupt,
...(locationData && {
city: locationData.city,
region: locationData.region,
@@ -358,6 +364,9 @@ export default function Chat() {
...(uploadedFiles && { files: uploadedFiles }),
};
// Reset the flag after using it
setShouldSendWithInterrupt(false);
const response = await fetch(chatAPI, {
method: "POST",
headers: {
@@ -481,6 +490,7 @@ export default function Chat() {
isChatSideBarOpen={isChatSideBarOpen}
setIsChatSideBarOpen={setIsChatSideBarOpen}
isActive={authenticatedData?.is_active}
isParentProcessing={processQuerySignal}
/>
</Suspense>
</div>

View File

@@ -180,13 +180,7 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
}, [props.isResearchModeEnabled]);
function onSendMessage() {
if (imageUploaded) {
setImageUploaded(false);
setImagePaths([]);
imageData.forEach((data) => props.sendImage(data));
}
if (!message.trim()) return;
if (!message.trim() && imageData.length === 0) return;
if (!props.isLoggedIn) {
setLoginRedirectMessage(
"Hey there, you need to be signed in to send messages to Khoj AI",
@@ -195,6 +189,17 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
return;
}
// If currently processing, trigger abort first
if (props.sendDisabled) {
props.setTriggeredAbort(true);
}
if (imageUploaded) {
setImageUploaded(false);
setImagePaths([]);
imageData.forEach((data) => props.sendImage(data));
}
let messageToSend = message.trim();
// Check if message starts with an explicit slash command
const startsWithSlashCommand =
@@ -657,7 +662,7 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
<Button
variant={"ghost"}
className="!bg-none p-0 m-2 h-auto text-3xl rounded-full text-gray-300 hover:text-gray-500"
disabled={props.sendDisabled || !props.isLoggedIn}
disabled={!props.isLoggedIn}
onClick={handleFileButtonClick}
ref={fileInputButtonRef}
>
@@ -686,7 +691,8 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
e.key === "Enter" &&
!e.shiftKey &&
!props.isMobileWidth &&
!props.sendDisabled
!recording &&
message
) {
setImageUploaded(false);
setImagePaths([]);
@@ -725,7 +731,7 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
{props.sendDisabled ? (
{props.sendDisabled && !message ? (
<Button
variant="default"
className={`${props.agentColor ? convertToBGClass(props.agentColor) : "bg-orange-300 hover:bg-orange-500"} rounded-full p-1 m-2 h-auto text-3xl transition transform md:hover:-translate-y-1`}
@@ -758,8 +764,8 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
</TooltipProvider>
)}
<Button
className={`${(!message || recording || props.sendDisabled) && "hidden"} ${props.agentColor ? convertToBGClass(props.agentColor) : "bg-orange-300 hover:bg-orange-500"} rounded-full p-1 m-2 h-auto text-3xl transition transform md:hover:-translate-y-1`}
disabled={props.sendDisabled || !props.isLoggedIn}
className={`${(!message || recording) && "hidden"} ${props.agentColor ? convertToBGClass(props.agentColor) : "bg-orange-300 hover:bg-orange-500"} rounded-full p-1 m-2 h-auto text-3xl transition transform md:hover:-translate-y-1`}
disabled={!message || recording || !props.isLoggedIn}
onClick={onSendMessage}
>
<ArrowUp className="w-6 h-6" weight="bold" />

View File

@@ -23,6 +23,7 @@ logger = logging.getLogger(__name__)
class Context(PydanticBaseModel):
compiled: str
file: str
query: str
class CodeContextFile(PydanticBaseModel):
@@ -105,6 +106,8 @@ class ChatMessage(PydanticBaseModel):
context: List[Context] = []
onlineContext: Dict[str, OnlineContext] = {}
codeContext: Dict[str, CodeContextData] = {}
researchContext: Optional[List] = None
operatorContext: Optional[Dict[str, str]] = None
created: str
images: Optional[List[str]] = None
queryFiles: Optional[List[Dict]] = None

View File

@@ -164,7 +164,7 @@ async def converse_anthropic(
generated_asset_results: Dict[str, Dict] = {},
deepthought: Optional[bool] = False,
tracer: dict = {},
) -> AsyncGenerator[ResponseWithThought, None]:
) -> AsyncGenerator[str | ResponseWithThought, None]:
"""
Converse with user using Anthropic's Claude
"""

View File

@@ -190,7 +190,7 @@ async def converse_openai(
program_execution_context: List[str] = None,
deepthought: Optional[bool] = False,
tracer: dict = {},
) -> AsyncGenerator[ResponseWithThought, None]:
) -> AsyncGenerator[str | ResponseWithThought, None]:
"""
Converse with user using OpenAI's ChatGPT
"""

View File

@@ -110,9 +110,12 @@ class InformationCollectionIteration:
def construct_iteration_history(
query: str, previous_iterations: List[InformationCollectionIteration], previous_iteration_prompt: str
previous_iterations: List[InformationCollectionIteration],
previous_iteration_prompt: str,
query: str = None,
) -> list[dict]:
previous_iterations_history = []
iteration_history: list[dict] = []
previous_iteration_messages: list[dict] = []
for idx, iteration in enumerate(previous_iterations):
iteration_data = previous_iteration_prompt.format(
tool=iteration.tool,
@@ -121,23 +124,19 @@ def construct_iteration_history(
index=idx + 1,
)
previous_iterations_history.append(iteration_data)
previous_iteration_messages.append({"type": "text", "text": iteration_data})
return (
[
{
"by": "you",
"message": query,
},
if previous_iteration_messages:
if query:
iteration_history.append({"by": "you", "message": query})
iteration_history.append(
{
"by": "khoj",
"intent": {"type": "remember", "query": query},
"message": previous_iterations_history,
},
]
if previous_iterations_history
else []
)
"message": previous_iteration_messages,
}
)
return iteration_history
def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
@@ -285,6 +284,7 @@ async def save_to_conversation_log(
generated_images: List[str] = [],
raw_generated_files: List[FileAttachment] = [],
generated_mermaidjs_diagram: str = None,
research_results: Optional[List[InformationCollectionIteration]] = None,
train_of_thought: List[Any] = [],
tracer: Dict[str, Any] = {},
):
@@ -302,6 +302,7 @@ async def save_to_conversation_log(
"onlineContext": online_results,
"codeContext": code_results,
"operatorContext": operator_results,
"researchContext": [vars(r) for r in research_results] if research_results and not chat_response else None,
"automationId": automation_id,
"trainOfThought": train_of_thought,
"turnId": turn_id,
@@ -341,7 +342,7 @@ Khoj: "{chat_response}"
def construct_structured_message(
message: list[str] | str,
message: list[dict] | str,
images: list[str],
model_type: str,
vision_enabled: bool,
@@ -355,11 +356,9 @@ def construct_structured_message(
ChatModel.ModelType.GOOGLE,
ChatModel.ModelType.ANTHROPIC,
]:
message = [message] if isinstance(message, str) else message
constructed_messages: List[dict[str, Any]] = [
{"type": "text", "text": message_part} for message_part in message
]
constructed_messages: List[dict[str, Any]] = (
[{"type": "text", "text": message}] if isinstance(message, str) else message
)
if not is_none_or_empty(attached_file_context):
constructed_messages.append({"type": "text", "text": attached_file_context})
@@ -368,6 +367,7 @@ def construct_structured_message(
constructed_messages.append({"type": "image_url", "image_url": {"url": image}})
return constructed_messages
message = message if isinstance(message, str) else "\n\n".join(m["text"] for m in message)
if not is_none_or_empty(attached_file_context):
return f"{attached_file_context}\n\n{message}"
@@ -421,7 +421,7 @@ def generate_chatml_messages_with_context(
# Extract Chat History for Context
chatml_messages: List[ChatMessage] = []
for chat in conversation_log.get("chat", []):
message_context = ""
message_context = []
message_attached_files = ""
generated_assets = {}
@@ -433,16 +433,6 @@ def generate_chatml_messages_with_context(
if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""):
chat_message = chat["intent"].get("inferred-queries")[0]
if not is_none_or_empty(chat.get("context")):
references = "\n\n".join(
{
f"# File: {item['file']}\n## {item['compiled']}\n"
for item in chat.get("context") or []
if isinstance(item, dict)
}
)
message_context += f"{prompts.notes_conversation.format(references=references)}\n\n"
if chat.get("queryFiles"):
raw_query_files = chat.get("queryFiles")
query_files_dict = dict()
@@ -453,15 +443,38 @@ def generate_chatml_messages_with_context(
chatml_messages.append(ChatMessage(content=message_attached_files, role=role))
if not is_none_or_empty(chat.get("onlineContext")):
message_context += f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}"
message_context += [
{
"type": "text",
"text": f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}",
}
]
if not is_none_or_empty(chat.get("codeContext")):
message_context += f"{prompts.code_executed_context.format(code_results=chat.get('codeContext'))}"
message_context += [
{
"type": "text",
"text": f"{prompts.code_executed_context.format(code_results=chat.get('codeContext'))}",
}
]
if not is_none_or_empty(chat.get("operatorContext")):
message_context += (
f"{prompts.operator_execution_context.format(operator_results=chat.get('operatorContext'))}"
message_context += [
{
"type": "text",
"text": f"{prompts.operator_execution_context.format(operator_results=chat.get('operatorContext'))}",
}
]
if not is_none_or_empty(chat.get("context")):
references = "\n\n".join(
{
f"# File: {item['file']}\n## {item['compiled']}\n"
for item in chat.get("context") or []
if isinstance(item, dict)
}
)
message_context += [{"type": "text", "text": f"{prompts.notes_conversation.format(references=references)}"}]
if not is_none_or_empty(message_context):
reconstructed_context_message = ChatMessage(content=message_context, role="user")

View File

@@ -13,7 +13,7 @@ from io import BytesIO
from typing import Any, List
import numpy as np
from openai import AzureOpenAI, OpenAI
from openai import AsyncAzureOpenAI, AsyncOpenAI
from openai.types.chat import ChatCompletion
from PIL import Image
@@ -72,7 +72,7 @@ class GroundingAgentUitars:
def __init__(
self,
model_name: str,
client: OpenAI | AzureOpenAI,
client: AsyncOpenAI | AsyncAzureOpenAI,
max_iterations=50,
environment_type: Literal["computer", "web"] = "computer",
runtime_conf: dict = {

View File

@@ -682,11 +682,13 @@ async def chat(
timezone = body.timezone
raw_images = body.images
raw_query_files = body.files
interrupt_flag = body.interrupt
async def event_generator(q: str, images: list[str]):
start_time = time.perf_counter()
ttft = None
chat_metadata: dict = {}
conversation = None
user: KhojUser = request.user.object
is_subscribed = has_required_scope(request, ["premium"])
q = unquote(q)
@@ -720,6 +722,20 @@ async def chat(
for file in raw_query_files:
query_files[file.name] = file.content
research_results: List[InformationCollectionIteration] = []
online_results: Dict = dict()
code_results: Dict = dict()
operator_results: Dict[str, str] = {}
compiled_references: List[Any] = []
inferred_queries: List[Any] = []
attached_file_context = gather_raw_query_files(query_files)
generated_images: List[str] = []
generated_files: List[FileAttachment] = []
generated_mermaidjs_diagram: str = None
generated_asset_results: Dict = dict()
program_execution_context: List[str] = []
# Create a task to monitor for disconnections
disconnect_monitor_task = None
@@ -727,8 +743,34 @@ async def chat(
try:
msg = await request.receive()
if msg["type"] == "http.disconnect":
logger.debug(f"User {user} disconnected from {common.client} client.")
logger.debug(f"Request cancelled. User {user} disconnected from {common.client} client.")
cancellation_event.set()
# ensure partial chat state saved on interrupt
# shield the save against task cancellation
if conversation:
await asyncio.shield(
save_to_conversation_log(
q,
chat_response="",
user=user,
meta_log=meta_log,
compiled_references=compiled_references,
online_results=online_results,
code_results=code_results,
operator_results=operator_results,
research_results=research_results,
inferred_queries=inferred_queries,
client_application=request.user.client_app,
conversation_id=conversation_id,
query_images=uploaded_images,
train_of_thought=train_of_thought,
raw_query_files=raw_query_files,
generated_images=generated_images,
raw_generated_files=generated_asset_results,
generated_mermaidjs_diagram=generated_mermaidjs_diagram,
tracer=tracer,
)
)
except Exception as e:
logger.error(f"Error in disconnect monitor: {e}")
@@ -746,7 +788,6 @@ async def chat(
nonlocal ttft, train_of_thought
event_delimiter = "␃🔚␗"
if cancellation_event.is_set():
logger.debug(f"User {user} disconnected from {common.client} client. Setting cancellation event.")
return
try:
if event_type == ChatEvent.END_LLM_RESPONSE:
@@ -770,9 +811,6 @@ async def chat(
yield data
elif event_type == ChatEvent.REFERENCES or ChatEvent.METADATA or stream:
yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False)
except asyncio.CancelledError as e:
if cancellation_event.is_set():
logger.debug(f"Request cancelled. User {user} disconnected from {common.client} client: {e}.")
except Exception as e:
if not cancellation_event.is_set():
logger.error(
@@ -883,21 +921,53 @@ async def chat(
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
meta_log = conversation.conversation_log
researched_results = ""
online_results: Dict = dict()
code_results: Dict = dict()
operator_results: Dict[str, str] = {}
generated_asset_results: Dict = dict()
## Extract Document References
compiled_references: List[Any] = []
inferred_queries: List[Any] = []
file_filters = conversation.file_filters if conversation and conversation.file_filters else []
attached_file_context = gather_raw_query_files(query_files)
# If interrupt flag is set, wait for the previous turn to be saved before proceeding
if interrupt_flag:
max_wait_time = 20.0 # seconds
wait_interval = 0.3 # seconds
wait_start = wait_current = time.time()
while wait_current - wait_start < max_wait_time:
# Refresh conversation to check if interrupted message saved to DB
conversation = await ConversationAdapters.aget_conversation_by_user(
user,
client_application=request.user.client_app,
conversation_id=conversation_id,
)
if (
conversation
and conversation.messages
and conversation.messages[-1].by == "khoj"
and not conversation.messages[-1].message
):
logger.info(f"Detected interrupted message save to conversation {conversation_id}.")
break
await asyncio.sleep(wait_interval)
wait_current = time.time()
generated_images: List[str] = []
generated_files: List[FileAttachment] = []
generated_mermaidjs_diagram: str = None
program_execution_context: List[str] = []
if wait_current - wait_start >= max_wait_time:
logger.warning(
f"Timeout waiting to load interrupted context from conversation {conversation_id}. Proceed without previous context."
)
# If interrupted message in DB
if (
conversation
and conversation.messages
and conversation.messages[-1].by == "khoj"
and not conversation.messages[-1].message
):
# Populate context from interrupted message
last_message = conversation.messages[-1]
online_results = {key: val.model_dump() for key, val in last_message.onlineContext.items() or []}
code_results = {key: val.model_dump() for key, val in last_message.codeContext.items() or []}
operator_results = last_message.operatorContext or {}
compiled_references = [ref.model_dump() for ref in last_message.context or []]
research_results = [
InformationCollectionIteration(**iter_dict) for iter_dict in last_message.researchContext or []
]
# Drop the interrupted message from conversation history
meta_log["chat"].pop()
logger.info(f"Loaded interrupted partial context from conversation {conversation_id}.")
if conversation_commands == [ConversationCommand.Default]:
try:
@@ -936,6 +1006,7 @@ async def chat(
return
defiltered_query = defilter_query(q)
file_filters = conversation.file_filters if conversation and conversation.file_filters else []
if conversation_commands == [ConversationCommand.Research]:
async for research_result in execute_information_collection(
@@ -943,12 +1014,13 @@ async def chat(
query=defiltered_query,
conversation_id=conversation_id,
conversation_history=meta_log,
previous_iterations=research_results,
query_images=uploaded_images,
agent=agent,
send_status_func=partial(send_event, ChatEvent.STATUS),
user_name=user_name,
location=location,
file_filters=conversation.file_filters if conversation else [],
file_filters=file_filters,
query_files=attached_file_context,
tracer=tracer,
cancellation_event=cancellation_event,
@@ -963,17 +1035,16 @@ async def chat(
compiled_references.extend(research_result.context)
if research_result.operatorContext:
operator_results.update(research_result.operatorContext)
researched_results += research_result.summarizedResult
research_results.append(research_result)
else:
yield research_result
# researched_results = await extract_relevant_info(q, researched_results, agent)
if state.verbose > 1:
logger.debug(f"Researched Results: {researched_results}")
logger.debug(f'Researched Results: {"".join(r.summarizedResult for r in research_results)}')
used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
file_filters = conversation.file_filters if conversation else []
# Skip trying to summarize if
if (
# summarization intent was inferred
@@ -1362,7 +1433,7 @@ async def chat(
# Check if the user has disconnected
if cancellation_event.is_set():
logger.debug(f"User {user} disconnected from {common.client} client. Stopping LLM response.")
logger.debug(f"Stopping LLM response to user {user} on {common.client} client.")
# Cancel the disconnect monitor task if it is still running
await cancel_disconnect_monitor()
return
@@ -1379,13 +1450,13 @@ async def chat(
online_results,
code_results,
operator_results,
research_results,
inferred_queries,
conversation_commands,
user,
request.user.client_app,
location,
user_name,
researched_results,
uploaded_images,
train_of_thought,
attached_file_context,

View File

@@ -72,7 +72,7 @@ async def update_chat_model(
if chat_model is None:
return Response(status_code=404, content=json.dumps({"status": "error", "message": "Chat model not found"}))
if not subscribed and chat_model.price_tier != PriceTier.FREE:
raise Response(
return Response(
status_code=403,
content=json.dumps({"status": "error", "message": "Subscribe to switch to this chat model"}),
)
@@ -108,7 +108,7 @@ async def update_voice_model(
if voice_model is None:
return Response(status_code=404, content=json.dumps({"status": "error", "message": "Voice model not found"}))
if not subscribed and voice_model.price_tier != PriceTier.FREE:
raise Response(
return Response(
status_code=403,
content=json.dumps({"status": "error", "message": "Subscribe to switch to this voice model"}),
)
@@ -143,7 +143,7 @@ async def update_paint_model(
if image_model is None:
return Response(status_code=404, content=json.dumps({"status": "error", "message": "Image model not found"}))
if not subscribed and image_model.price_tier != PriceTier.FREE:
raise Response(
return Response(
status_code=403,
content=json.dumps({"status": "error", "message": "Subscribe to switch to this image model"}),
)

View File

@@ -94,6 +94,7 @@ from khoj.processor.conversation.openai.gpt import (
)
from khoj.processor.conversation.utils import (
ChatEvent,
InformationCollectionIteration,
ResponseWithThought,
clean_json,
clean_mermaidjs,
@@ -1355,13 +1356,13 @@ async def agenerate_chat_response(
online_results: Dict[str, Dict] = {},
code_results: Dict[str, Dict] = {},
operator_results: Dict[str, str] = {},
research_results: List[InformationCollectionIteration] = [],
inferred_queries: List[str] = [],
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
user: KhojUser = None,
client_application: ClientApplication = None,
location_data: LocationData = None,
user_name: Optional[str] = None,
meta_research: str = "",
query_images: Optional[List[str]] = None,
train_of_thought: List[Any] = [],
query_files: str = None,
@@ -1391,6 +1392,7 @@ async def agenerate_chat_response(
online_results=online_results,
code_results=code_results,
operator_results=operator_results,
research_results=research_results,
inferred_queries=inferred_queries,
client_application=client_application,
conversation_id=str(conversation.id),
@@ -1405,8 +1407,10 @@ async def agenerate_chat_response(
query_to_run = q
deepthought = False
if meta_research:
query_to_run = f"<query>{q}</query>\n<collected_research>\n{meta_research}\n</collected_research>"
if research_results:
compiled_research = "".join([r.summarizedResult for r in research_results if r.summarizedResult])
if compiled_research:
query_to_run = f"<query>{q}</query>\n<collected_research>\n{compiled_research}\n</collected_research>"
compiled_references = []
online_results = {}
code_results = {}

View File

@@ -1,6 +1,7 @@
import asyncio
import logging
import os
from copy import deepcopy
from datetime import datetime
from enum import Enum
from typing import Callable, Dict, List, Optional, Type
@@ -141,7 +142,7 @@ async def apick_next_tool(
query = f"[placeholder for user attached images]\n{query}"
# Construct chat history with user and iteration history with researcher agent for context
previous_iterations_history = construct_iteration_history(query, previous_iterations, prompts.previous_iteration)
previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration, query)
iteration_chat_log = {"chat": conversation_history.get("chat", []) + previous_iterations_history}
# Plan function execution for the next tool
@@ -212,6 +213,7 @@ async def execute_information_collection(
query: str,
conversation_id: str,
conversation_history: dict,
previous_iterations: List[InformationCollectionIteration],
query_images: List[str],
agent: Agent = None,
send_status_func: Optional[Callable] = None,
@@ -227,11 +229,20 @@ async def execute_information_collection(
max_webpages_to_read = 1
current_iteration = 0
MAX_ITERATIONS = int(os.getenv("KHOJ_RESEARCH_ITERATIONS", 5))
previous_iterations: List[InformationCollectionIteration] = []
# Incorporate previous partial research into current research chat history
research_conversation_history = deepcopy(conversation_history)
if current_iteration := len(previous_iterations) > 0:
logger.info(f"Continuing research with the previous {len(previous_iterations)} iteration results.")
previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)
research_conversation_history["chat"] = (
research_conversation_history.get("chat", []) + previous_iterations_history
)
while current_iteration < MAX_ITERATIONS:
# Check for cancellation at the start of each iteration
if cancellation_event and cancellation_event.is_set():
logger.debug(f"User {user} disconnected client. Research cancelled.")
logger.debug(f"Research cancelled. User {user} disconnected client.")
break
online_results: Dict = dict()
@@ -243,7 +254,7 @@ async def execute_information_collection(
async for result in apick_next_tool(
query,
conversation_history,
research_conversation_history,
user,
location,
user_name,

View File

@@ -168,6 +168,7 @@ class ChatRequestBody(BaseModel):
images: Optional[list[str]] = None
files: Optional[list[FileAttachment]] = []
create_new: Optional[bool] = False
interrupt: Optional[bool] = False
class Entry: