Easily interrupt and redirect khoj's research direction via chat

- Khoj can now save and restore research from partial state
  This triggers an interrupt that saves the partial research, then
  when a new query is sent it loads the previous partial research as
  context and continues, using the new user query to orient
  its future research
- Support natural interrupt and send query behavior from web app
  This triggers an abort and send when a user sends a chat message
  while khoj is in the middle of some previous research.

This interrupt mechanism enables a more natural, interactive
research flow.
This commit is contained in:
Debanjum
2025-05-27 17:57:21 -07:00
12 changed files with 209 additions and 90 deletions

View File

@@ -49,6 +49,7 @@ interface ChatBodyDataProps {
isChatSideBarOpen: boolean; isChatSideBarOpen: boolean;
setIsChatSideBarOpen: (open: boolean) => void; setIsChatSideBarOpen: (open: boolean) => void;
isActive?: boolean; isActive?: boolean;
isParentProcessing?: boolean;
} }
function ChatBodyData(props: ChatBodyDataProps) { function ChatBodyData(props: ChatBodyDataProps) {
@@ -166,7 +167,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
isLoggedIn={props.isLoggedIn} isLoggedIn={props.isLoggedIn}
sendMessage={(message) => setMessage(message)} sendMessage={(message) => setMessage(message)}
sendImage={(image) => setImages((prevImages) => [...prevImages, image])} sendImage={(image) => setImages((prevImages) => [...prevImages, image])}
sendDisabled={processingMessage} sendDisabled={props.isParentProcessing || false}
chatOptionsData={props.chatOptionsData} chatOptionsData={props.chatOptionsData}
conversationId={conversationId} conversationId={conversationId}
isMobileWidth={props.isMobileWidth} isMobileWidth={props.isMobileWidth}
@@ -203,6 +204,7 @@ export default function Chat() {
const [abortMessageStreamController, setAbortMessageStreamController] = const [abortMessageStreamController, setAbortMessageStreamController] =
useState<AbortController | null>(null); useState<AbortController | null>(null);
const [triggeredAbort, setTriggeredAbort] = useState(false); const [triggeredAbort, setTriggeredAbort] = useState(false);
const [shouldSendWithInterrupt, setShouldSendWithInterrupt] = useState(false);
const { locationData, locationDataError, locationDataLoading } = useIPLocationData() || { const { locationData, locationDataError, locationDataLoading } = useIPLocationData() || {
locationData: { locationData: {
@@ -239,6 +241,7 @@ export default function Chat() {
if (triggeredAbort) { if (triggeredAbort) {
abortMessageStreamController?.abort(); abortMessageStreamController?.abort();
handleAbortedMessage(); handleAbortedMessage();
setShouldSendWithInterrupt(true);
setTriggeredAbort(false); setTriggeredAbort(false);
} }
}, [triggeredAbort]); }, [triggeredAbort]);
@@ -335,18 +338,21 @@ export default function Chat() {
currentMessage.completed = true; currentMessage.completed = true;
setMessages([...messages]); setMessages([...messages]);
setQueryToProcess("");
setProcessQuerySignal(false); setProcessQuerySignal(false);
} }
async function chat() { async function chat() {
localStorage.removeItem("message"); localStorage.removeItem("message");
if (!queryToProcess || !conversationId) return; if (!queryToProcess || !conversationId) {
setProcessQuerySignal(false);
return;
}
const chatAPI = "/api/chat?client=web"; const chatAPI = "/api/chat?client=web";
const chatAPIBody = { const chatAPIBody = {
q: queryToProcess, q: queryToProcess,
conversation_id: conversationId, conversation_id: conversationId,
stream: true, stream: true,
interrupt: shouldSendWithInterrupt,
...(locationData && { ...(locationData && {
city: locationData.city, city: locationData.city,
region: locationData.region, region: locationData.region,
@@ -358,6 +364,9 @@ export default function Chat() {
...(uploadedFiles && { files: uploadedFiles }), ...(uploadedFiles && { files: uploadedFiles }),
}; };
// Reset the flag after using it
setShouldSendWithInterrupt(false);
const response = await fetch(chatAPI, { const response = await fetch(chatAPI, {
method: "POST", method: "POST",
headers: { headers: {
@@ -481,6 +490,7 @@ export default function Chat() {
isChatSideBarOpen={isChatSideBarOpen} isChatSideBarOpen={isChatSideBarOpen}
setIsChatSideBarOpen={setIsChatSideBarOpen} setIsChatSideBarOpen={setIsChatSideBarOpen}
isActive={authenticatedData?.is_active} isActive={authenticatedData?.is_active}
isParentProcessing={processQuerySignal}
/> />
</Suspense> </Suspense>
</div> </div>

View File

@@ -180,13 +180,7 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
}, [props.isResearchModeEnabled]); }, [props.isResearchModeEnabled]);
function onSendMessage() { function onSendMessage() {
if (imageUploaded) { if (!message.trim() && imageData.length === 0) return;
setImageUploaded(false);
setImagePaths([]);
imageData.forEach((data) => props.sendImage(data));
}
if (!message.trim()) return;
if (!props.isLoggedIn) { if (!props.isLoggedIn) {
setLoginRedirectMessage( setLoginRedirectMessage(
"Hey there, you need to be signed in to send messages to Khoj AI", "Hey there, you need to be signed in to send messages to Khoj AI",
@@ -195,6 +189,17 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
return; return;
} }
// If currently processing, trigger abort first
if (props.sendDisabled) {
props.setTriggeredAbort(true);
}
if (imageUploaded) {
setImageUploaded(false);
setImagePaths([]);
imageData.forEach((data) => props.sendImage(data));
}
let messageToSend = message.trim(); let messageToSend = message.trim();
// Check if message starts with an explicit slash command // Check if message starts with an explicit slash command
const startsWithSlashCommand = const startsWithSlashCommand =
@@ -657,7 +662,7 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
<Button <Button
variant={"ghost"} variant={"ghost"}
className="!bg-none p-0 m-2 h-auto text-3xl rounded-full text-gray-300 hover:text-gray-500" className="!bg-none p-0 m-2 h-auto text-3xl rounded-full text-gray-300 hover:text-gray-500"
disabled={props.sendDisabled || !props.isLoggedIn} disabled={!props.isLoggedIn}
onClick={handleFileButtonClick} onClick={handleFileButtonClick}
ref={fileInputButtonRef} ref={fileInputButtonRef}
> >
@@ -686,7 +691,8 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
e.key === "Enter" && e.key === "Enter" &&
!e.shiftKey && !e.shiftKey &&
!props.isMobileWidth && !props.isMobileWidth &&
!props.sendDisabled !recording &&
message
) { ) {
setImageUploaded(false); setImageUploaded(false);
setImagePaths([]); setImagePaths([]);
@@ -725,7 +731,7 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
<TooltipProvider> <TooltipProvider>
<Tooltip> <Tooltip>
<TooltipTrigger asChild> <TooltipTrigger asChild>
{props.sendDisabled ? ( {props.sendDisabled && !message ? (
<Button <Button
variant="default" variant="default"
className={`${props.agentColor ? convertToBGClass(props.agentColor) : "bg-orange-300 hover:bg-orange-500"} rounded-full p-1 m-2 h-auto text-3xl transition transform md:hover:-translate-y-1`} className={`${props.agentColor ? convertToBGClass(props.agentColor) : "bg-orange-300 hover:bg-orange-500"} rounded-full p-1 m-2 h-auto text-3xl transition transform md:hover:-translate-y-1`}
@@ -758,8 +764,8 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
</TooltipProvider> </TooltipProvider>
)} )}
<Button <Button
className={`${(!message || recording || props.sendDisabled) && "hidden"} ${props.agentColor ? convertToBGClass(props.agentColor) : "bg-orange-300 hover:bg-orange-500"} rounded-full p-1 m-2 h-auto text-3xl transition transform md:hover:-translate-y-1`} className={`${(!message || recording) && "hidden"} ${props.agentColor ? convertToBGClass(props.agentColor) : "bg-orange-300 hover:bg-orange-500"} rounded-full p-1 m-2 h-auto text-3xl transition transform md:hover:-translate-y-1`}
disabled={props.sendDisabled || !props.isLoggedIn} disabled={!message || recording || !props.isLoggedIn}
onClick={onSendMessage} onClick={onSendMessage}
> >
<ArrowUp className="w-6 h-6" weight="bold" /> <ArrowUp className="w-6 h-6" weight="bold" />

View File

@@ -23,6 +23,7 @@ logger = logging.getLogger(__name__)
class Context(PydanticBaseModel): class Context(PydanticBaseModel):
compiled: str compiled: str
file: str file: str
query: str
class CodeContextFile(PydanticBaseModel): class CodeContextFile(PydanticBaseModel):
@@ -105,6 +106,8 @@ class ChatMessage(PydanticBaseModel):
context: List[Context] = [] context: List[Context] = []
onlineContext: Dict[str, OnlineContext] = {} onlineContext: Dict[str, OnlineContext] = {}
codeContext: Dict[str, CodeContextData] = {} codeContext: Dict[str, CodeContextData] = {}
researchContext: Optional[List] = None
operatorContext: Optional[Dict[str, str]] = None
created: str created: str
images: Optional[List[str]] = None images: Optional[List[str]] = None
queryFiles: Optional[List[Dict]] = None queryFiles: Optional[List[Dict]] = None

View File

@@ -164,7 +164,7 @@ async def converse_anthropic(
generated_asset_results: Dict[str, Dict] = {}, generated_asset_results: Dict[str, Dict] = {},
deepthought: Optional[bool] = False, deepthought: Optional[bool] = False,
tracer: dict = {}, tracer: dict = {},
) -> AsyncGenerator[ResponseWithThought, None]: ) -> AsyncGenerator[str | ResponseWithThought, None]:
""" """
Converse with user using Anthropic's Claude Converse with user using Anthropic's Claude
""" """

View File

@@ -190,7 +190,7 @@ async def converse_openai(
program_execution_context: List[str] = None, program_execution_context: List[str] = None,
deepthought: Optional[bool] = False, deepthought: Optional[bool] = False,
tracer: dict = {}, tracer: dict = {},
) -> AsyncGenerator[ResponseWithThought, None]: ) -> AsyncGenerator[str | ResponseWithThought, None]:
""" """
Converse with user using OpenAI's ChatGPT Converse with user using OpenAI's ChatGPT
""" """

View File

@@ -110,9 +110,12 @@ class InformationCollectionIteration:
def construct_iteration_history( def construct_iteration_history(
query: str, previous_iterations: List[InformationCollectionIteration], previous_iteration_prompt: str previous_iterations: List[InformationCollectionIteration],
previous_iteration_prompt: str,
query: str = None,
) -> list[dict]: ) -> list[dict]:
previous_iterations_history = [] iteration_history: list[dict] = []
previous_iteration_messages: list[dict] = []
for idx, iteration in enumerate(previous_iterations): for idx, iteration in enumerate(previous_iterations):
iteration_data = previous_iteration_prompt.format( iteration_data = previous_iteration_prompt.format(
tool=iteration.tool, tool=iteration.tool,
@@ -121,23 +124,19 @@ def construct_iteration_history(
index=idx + 1, index=idx + 1,
) )
previous_iterations_history.append(iteration_data) previous_iteration_messages.append({"type": "text", "text": iteration_data})
return ( if previous_iteration_messages:
[ if query:
{ iteration_history.append({"by": "you", "message": query})
"by": "you", iteration_history.append(
"message": query,
},
{ {
"by": "khoj", "by": "khoj",
"intent": {"type": "remember", "query": query}, "intent": {"type": "remember", "query": query},
"message": previous_iterations_history, "message": previous_iteration_messages,
}, }
] )
if previous_iterations_history return iteration_history
else []
)
def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str: def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
@@ -285,6 +284,7 @@ async def save_to_conversation_log(
generated_images: List[str] = [], generated_images: List[str] = [],
raw_generated_files: List[FileAttachment] = [], raw_generated_files: List[FileAttachment] = [],
generated_mermaidjs_diagram: str = None, generated_mermaidjs_diagram: str = None,
research_results: Optional[List[InformationCollectionIteration]] = None,
train_of_thought: List[Any] = [], train_of_thought: List[Any] = [],
tracer: Dict[str, Any] = {}, tracer: Dict[str, Any] = {},
): ):
@@ -302,6 +302,7 @@ async def save_to_conversation_log(
"onlineContext": online_results, "onlineContext": online_results,
"codeContext": code_results, "codeContext": code_results,
"operatorContext": operator_results, "operatorContext": operator_results,
"researchContext": [vars(r) for r in research_results] if research_results and not chat_response else None,
"automationId": automation_id, "automationId": automation_id,
"trainOfThought": train_of_thought, "trainOfThought": train_of_thought,
"turnId": turn_id, "turnId": turn_id,
@@ -341,7 +342,7 @@ Khoj: "{chat_response}"
def construct_structured_message( def construct_structured_message(
message: list[str] | str, message: list[dict] | str,
images: list[str], images: list[str],
model_type: str, model_type: str,
vision_enabled: bool, vision_enabled: bool,
@@ -355,11 +356,9 @@ def construct_structured_message(
ChatModel.ModelType.GOOGLE, ChatModel.ModelType.GOOGLE,
ChatModel.ModelType.ANTHROPIC, ChatModel.ModelType.ANTHROPIC,
]: ]:
message = [message] if isinstance(message, str) else message constructed_messages: List[dict[str, Any]] = (
[{"type": "text", "text": message}] if isinstance(message, str) else message
constructed_messages: List[dict[str, Any]] = [ )
{"type": "text", "text": message_part} for message_part in message
]
if not is_none_or_empty(attached_file_context): if not is_none_or_empty(attached_file_context):
constructed_messages.append({"type": "text", "text": attached_file_context}) constructed_messages.append({"type": "text", "text": attached_file_context})
@@ -368,6 +367,7 @@ def construct_structured_message(
constructed_messages.append({"type": "image_url", "image_url": {"url": image}}) constructed_messages.append({"type": "image_url", "image_url": {"url": image}})
return constructed_messages return constructed_messages
message = message if isinstance(message, str) else "\n\n".join(m["text"] for m in message)
if not is_none_or_empty(attached_file_context): if not is_none_or_empty(attached_file_context):
return f"{attached_file_context}\n\n{message}" return f"{attached_file_context}\n\n{message}"
@@ -421,7 +421,7 @@ def generate_chatml_messages_with_context(
# Extract Chat History for Context # Extract Chat History for Context
chatml_messages: List[ChatMessage] = [] chatml_messages: List[ChatMessage] = []
for chat in conversation_log.get("chat", []): for chat in conversation_log.get("chat", []):
message_context = "" message_context = []
message_attached_files = "" message_attached_files = ""
generated_assets = {} generated_assets = {}
@@ -433,16 +433,6 @@ def generate_chatml_messages_with_context(
if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""): if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""):
chat_message = chat["intent"].get("inferred-queries")[0] chat_message = chat["intent"].get("inferred-queries")[0]
if not is_none_or_empty(chat.get("context")):
references = "\n\n".join(
{
f"# File: {item['file']}\n## {item['compiled']}\n"
for item in chat.get("context") or []
if isinstance(item, dict)
}
)
message_context += f"{prompts.notes_conversation.format(references=references)}\n\n"
if chat.get("queryFiles"): if chat.get("queryFiles"):
raw_query_files = chat.get("queryFiles") raw_query_files = chat.get("queryFiles")
query_files_dict = dict() query_files_dict = dict()
@@ -453,15 +443,38 @@ def generate_chatml_messages_with_context(
chatml_messages.append(ChatMessage(content=message_attached_files, role=role)) chatml_messages.append(ChatMessage(content=message_attached_files, role=role))
if not is_none_or_empty(chat.get("onlineContext")): if not is_none_or_empty(chat.get("onlineContext")):
message_context += f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}" message_context += [
{
"type": "text",
"text": f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}",
}
]
if not is_none_or_empty(chat.get("codeContext")): if not is_none_or_empty(chat.get("codeContext")):
message_context += f"{prompts.code_executed_context.format(code_results=chat.get('codeContext'))}" message_context += [
{
"type": "text",
"text": f"{prompts.code_executed_context.format(code_results=chat.get('codeContext'))}",
}
]
if not is_none_or_empty(chat.get("operatorContext")): if not is_none_or_empty(chat.get("operatorContext")):
message_context += ( message_context += [
f"{prompts.operator_execution_context.format(operator_results=chat.get('operatorContext'))}" {
"type": "text",
"text": f"{prompts.operator_execution_context.format(operator_results=chat.get('operatorContext'))}",
}
]
if not is_none_or_empty(chat.get("context")):
references = "\n\n".join(
{
f"# File: {item['file']}\n## {item['compiled']}\n"
for item in chat.get("context") or []
if isinstance(item, dict)
}
) )
message_context += [{"type": "text", "text": f"{prompts.notes_conversation.format(references=references)}"}]
if not is_none_or_empty(message_context): if not is_none_or_empty(message_context):
reconstructed_context_message = ChatMessage(content=message_context, role="user") reconstructed_context_message = ChatMessage(content=message_context, role="user")

View File

@@ -13,7 +13,7 @@ from io import BytesIO
from typing import Any, List from typing import Any, List
import numpy as np import numpy as np
from openai import AzureOpenAI, OpenAI from openai import AsyncAzureOpenAI, AsyncOpenAI
from openai.types.chat import ChatCompletion from openai.types.chat import ChatCompletion
from PIL import Image from PIL import Image
@@ -72,7 +72,7 @@ class GroundingAgentUitars:
def __init__( def __init__(
self, self,
model_name: str, model_name: str,
client: OpenAI | AzureOpenAI, client: AsyncOpenAI | AsyncAzureOpenAI,
max_iterations=50, max_iterations=50,
environment_type: Literal["computer", "web"] = "computer", environment_type: Literal["computer", "web"] = "computer",
runtime_conf: dict = { runtime_conf: dict = {

View File

@@ -682,11 +682,13 @@ async def chat(
timezone = body.timezone timezone = body.timezone
raw_images = body.images raw_images = body.images
raw_query_files = body.files raw_query_files = body.files
interrupt_flag = body.interrupt
async def event_generator(q: str, images: list[str]): async def event_generator(q: str, images: list[str]):
start_time = time.perf_counter() start_time = time.perf_counter()
ttft = None ttft = None
chat_metadata: dict = {} chat_metadata: dict = {}
conversation = None
user: KhojUser = request.user.object user: KhojUser = request.user.object
is_subscribed = has_required_scope(request, ["premium"]) is_subscribed = has_required_scope(request, ["premium"])
q = unquote(q) q = unquote(q)
@@ -720,6 +722,20 @@ async def chat(
for file in raw_query_files: for file in raw_query_files:
query_files[file.name] = file.content query_files[file.name] = file.content
research_results: List[InformationCollectionIteration] = []
online_results: Dict = dict()
code_results: Dict = dict()
operator_results: Dict[str, str] = {}
compiled_references: List[Any] = []
inferred_queries: List[Any] = []
attached_file_context = gather_raw_query_files(query_files)
generated_images: List[str] = []
generated_files: List[FileAttachment] = []
generated_mermaidjs_diagram: str = None
generated_asset_results: Dict = dict()
program_execution_context: List[str] = []
# Create a task to monitor for disconnections # Create a task to monitor for disconnections
disconnect_monitor_task = None disconnect_monitor_task = None
@@ -727,8 +743,34 @@ async def chat(
try: try:
msg = await request.receive() msg = await request.receive()
if msg["type"] == "http.disconnect": if msg["type"] == "http.disconnect":
logger.debug(f"User {user} disconnected from {common.client} client.") logger.debug(f"Request cancelled. User {user} disconnected from {common.client} client.")
cancellation_event.set() cancellation_event.set()
# ensure partial chat state saved on interrupt
# shield the save against task cancellation
if conversation:
await asyncio.shield(
save_to_conversation_log(
q,
chat_response="",
user=user,
meta_log=meta_log,
compiled_references=compiled_references,
online_results=online_results,
code_results=code_results,
operator_results=operator_results,
research_results=research_results,
inferred_queries=inferred_queries,
client_application=request.user.client_app,
conversation_id=conversation_id,
query_images=uploaded_images,
train_of_thought=train_of_thought,
raw_query_files=raw_query_files,
generated_images=generated_images,
raw_generated_files=generated_asset_results,
generated_mermaidjs_diagram=generated_mermaidjs_diagram,
tracer=tracer,
)
)
except Exception as e: except Exception as e:
logger.error(f"Error in disconnect monitor: {e}") logger.error(f"Error in disconnect monitor: {e}")
@@ -746,7 +788,6 @@ async def chat(
nonlocal ttft, train_of_thought nonlocal ttft, train_of_thought
event_delimiter = "␃🔚␗" event_delimiter = "␃🔚␗"
if cancellation_event.is_set(): if cancellation_event.is_set():
logger.debug(f"User {user} disconnected from {common.client} client. Setting cancellation event.")
return return
try: try:
if event_type == ChatEvent.END_LLM_RESPONSE: if event_type == ChatEvent.END_LLM_RESPONSE:
@@ -770,9 +811,6 @@ async def chat(
yield data yield data
elif event_type == ChatEvent.REFERENCES or ChatEvent.METADATA or stream: elif event_type == ChatEvent.REFERENCES or ChatEvent.METADATA or stream:
yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False) yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False)
except asyncio.CancelledError as e:
if cancellation_event.is_set():
logger.debug(f"Request cancelled. User {user} disconnected from {common.client} client: {e}.")
except Exception as e: except Exception as e:
if not cancellation_event.is_set(): if not cancellation_event.is_set():
logger.error( logger.error(
@@ -883,21 +921,53 @@ async def chat(
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
meta_log = conversation.conversation_log meta_log = conversation.conversation_log
researched_results = "" # If interrupt flag is set, wait for the previous turn to be saved before proceeding
online_results: Dict = dict() if interrupt_flag:
code_results: Dict = dict() max_wait_time = 20.0 # seconds
operator_results: Dict[str, str] = {} wait_interval = 0.3 # seconds
generated_asset_results: Dict = dict() wait_start = wait_current = time.time()
## Extract Document References while wait_current - wait_start < max_wait_time:
compiled_references: List[Any] = [] # Refresh conversation to check if interrupted message saved to DB
inferred_queries: List[Any] = [] conversation = await ConversationAdapters.aget_conversation_by_user(
file_filters = conversation.file_filters if conversation and conversation.file_filters else [] user,
attached_file_context = gather_raw_query_files(query_files) client_application=request.user.client_app,
conversation_id=conversation_id,
)
if (
conversation
and conversation.messages
and conversation.messages[-1].by == "khoj"
and not conversation.messages[-1].message
):
logger.info(f"Detected interrupted message save to conversation {conversation_id}.")
break
await asyncio.sleep(wait_interval)
wait_current = time.time()
generated_images: List[str] = [] if wait_current - wait_start >= max_wait_time:
generated_files: List[FileAttachment] = [] logger.warning(
generated_mermaidjs_diagram: str = None f"Timeout waiting to load interrupted context from conversation {conversation_id}. Proceed without previous context."
program_execution_context: List[str] = [] )
# If interrupted message in DB
if (
conversation
and conversation.messages
and conversation.messages[-1].by == "khoj"
and not conversation.messages[-1].message
):
# Populate context from interrupted message
last_message = conversation.messages[-1]
online_results = {key: val.model_dump() for key, val in last_message.onlineContext.items() or []}
code_results = {key: val.model_dump() for key, val in last_message.codeContext.items() or []}
operator_results = last_message.operatorContext or {}
compiled_references = [ref.model_dump() for ref in last_message.context or []]
research_results = [
InformationCollectionIteration(**iter_dict) for iter_dict in last_message.researchContext or []
]
# Drop the interrupted message from conversation history
meta_log["chat"].pop()
logger.info(f"Loaded interrupted partial context from conversation {conversation_id}.")
if conversation_commands == [ConversationCommand.Default]: if conversation_commands == [ConversationCommand.Default]:
try: try:
@@ -936,6 +1006,7 @@ async def chat(
return return
defiltered_query = defilter_query(q) defiltered_query = defilter_query(q)
file_filters = conversation.file_filters if conversation and conversation.file_filters else []
if conversation_commands == [ConversationCommand.Research]: if conversation_commands == [ConversationCommand.Research]:
async for research_result in execute_information_collection( async for research_result in execute_information_collection(
@@ -943,12 +1014,13 @@ async def chat(
query=defiltered_query, query=defiltered_query,
conversation_id=conversation_id, conversation_id=conversation_id,
conversation_history=meta_log, conversation_history=meta_log,
previous_iterations=research_results,
query_images=uploaded_images, query_images=uploaded_images,
agent=agent, agent=agent,
send_status_func=partial(send_event, ChatEvent.STATUS), send_status_func=partial(send_event, ChatEvent.STATUS),
user_name=user_name, user_name=user_name,
location=location, location=location,
file_filters=conversation.file_filters if conversation else [], file_filters=file_filters,
query_files=attached_file_context, query_files=attached_file_context,
tracer=tracer, tracer=tracer,
cancellation_event=cancellation_event, cancellation_event=cancellation_event,
@@ -963,17 +1035,16 @@ async def chat(
compiled_references.extend(research_result.context) compiled_references.extend(research_result.context)
if research_result.operatorContext: if research_result.operatorContext:
operator_results.update(research_result.operatorContext) operator_results.update(research_result.operatorContext)
researched_results += research_result.summarizedResult research_results.append(research_result)
else: else:
yield research_result yield research_result
# researched_results = await extract_relevant_info(q, researched_results, agent) # researched_results = await extract_relevant_info(q, researched_results, agent)
if state.verbose > 1: if state.verbose > 1:
logger.debug(f"Researched Results: {researched_results}") logger.debug(f'Researched Results: {"".join(r.summarizedResult for r in research_results)}')
used_slash_summarize = conversation_commands == [ConversationCommand.Summarize] used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
file_filters = conversation.file_filters if conversation else []
# Skip trying to summarize if # Skip trying to summarize if
if ( if (
# summarization intent was inferred # summarization intent was inferred
@@ -1362,7 +1433,7 @@ async def chat(
# Check if the user has disconnected # Check if the user has disconnected
if cancellation_event.is_set(): if cancellation_event.is_set():
logger.debug(f"User {user} disconnected from {common.client} client. Stopping LLM response.") logger.debug(f"Stopping LLM response to user {user} on {common.client} client.")
# Cancel the disconnect monitor task if it is still running # Cancel the disconnect monitor task if it is still running
await cancel_disconnect_monitor() await cancel_disconnect_monitor()
return return
@@ -1379,13 +1450,13 @@ async def chat(
online_results, online_results,
code_results, code_results,
operator_results, operator_results,
research_results,
inferred_queries, inferred_queries,
conversation_commands, conversation_commands,
user, user,
request.user.client_app, request.user.client_app,
location, location,
user_name, user_name,
researched_results,
uploaded_images, uploaded_images,
train_of_thought, train_of_thought,
attached_file_context, attached_file_context,

View File

@@ -72,7 +72,7 @@ async def update_chat_model(
if chat_model is None: if chat_model is None:
return Response(status_code=404, content=json.dumps({"status": "error", "message": "Chat model not found"})) return Response(status_code=404, content=json.dumps({"status": "error", "message": "Chat model not found"}))
if not subscribed and chat_model.price_tier != PriceTier.FREE: if not subscribed and chat_model.price_tier != PriceTier.FREE:
raise Response( return Response(
status_code=403, status_code=403,
content=json.dumps({"status": "error", "message": "Subscribe to switch to this chat model"}), content=json.dumps({"status": "error", "message": "Subscribe to switch to this chat model"}),
) )
@@ -108,7 +108,7 @@ async def update_voice_model(
if voice_model is None: if voice_model is None:
return Response(status_code=404, content=json.dumps({"status": "error", "message": "Voice model not found"})) return Response(status_code=404, content=json.dumps({"status": "error", "message": "Voice model not found"}))
if not subscribed and voice_model.price_tier != PriceTier.FREE: if not subscribed and voice_model.price_tier != PriceTier.FREE:
raise Response( return Response(
status_code=403, status_code=403,
content=json.dumps({"status": "error", "message": "Subscribe to switch to this voice model"}), content=json.dumps({"status": "error", "message": "Subscribe to switch to this voice model"}),
) )
@@ -143,7 +143,7 @@ async def update_paint_model(
if image_model is None: if image_model is None:
return Response(status_code=404, content=json.dumps({"status": "error", "message": "Image model not found"})) return Response(status_code=404, content=json.dumps({"status": "error", "message": "Image model not found"}))
if not subscribed and image_model.price_tier != PriceTier.FREE: if not subscribed and image_model.price_tier != PriceTier.FREE:
raise Response( return Response(
status_code=403, status_code=403,
content=json.dumps({"status": "error", "message": "Subscribe to switch to this image model"}), content=json.dumps({"status": "error", "message": "Subscribe to switch to this image model"}),
) )

View File

@@ -94,6 +94,7 @@ from khoj.processor.conversation.openai.gpt import (
) )
from khoj.processor.conversation.utils import ( from khoj.processor.conversation.utils import (
ChatEvent, ChatEvent,
InformationCollectionIteration,
ResponseWithThought, ResponseWithThought,
clean_json, clean_json,
clean_mermaidjs, clean_mermaidjs,
@@ -1355,13 +1356,13 @@ async def agenerate_chat_response(
online_results: Dict[str, Dict] = {}, online_results: Dict[str, Dict] = {},
code_results: Dict[str, Dict] = {}, code_results: Dict[str, Dict] = {},
operator_results: Dict[str, str] = {}, operator_results: Dict[str, str] = {},
research_results: List[InformationCollectionIteration] = [],
inferred_queries: List[str] = [], inferred_queries: List[str] = [],
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default], conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
user: KhojUser = None, user: KhojUser = None,
client_application: ClientApplication = None, client_application: ClientApplication = None,
location_data: LocationData = None, location_data: LocationData = None,
user_name: Optional[str] = None, user_name: Optional[str] = None,
meta_research: str = "",
query_images: Optional[List[str]] = None, query_images: Optional[List[str]] = None,
train_of_thought: List[Any] = [], train_of_thought: List[Any] = [],
query_files: str = None, query_files: str = None,
@@ -1391,6 +1392,7 @@ async def agenerate_chat_response(
online_results=online_results, online_results=online_results,
code_results=code_results, code_results=code_results,
operator_results=operator_results, operator_results=operator_results,
research_results=research_results,
inferred_queries=inferred_queries, inferred_queries=inferred_queries,
client_application=client_application, client_application=client_application,
conversation_id=str(conversation.id), conversation_id=str(conversation.id),
@@ -1405,8 +1407,10 @@ async def agenerate_chat_response(
query_to_run = q query_to_run = q
deepthought = False deepthought = False
if meta_research: if research_results:
query_to_run = f"<query>{q}</query>\n<collected_research>\n{meta_research}\n</collected_research>" compiled_research = "".join([r.summarizedResult for r in research_results if r.summarizedResult])
if compiled_research:
query_to_run = f"<query>{q}</query>\n<collected_research>\n{compiled_research}\n</collected_research>"
compiled_references = [] compiled_references = []
online_results = {} online_results = {}
code_results = {} code_results = {}

View File

@@ -1,6 +1,7 @@
import asyncio import asyncio
import logging import logging
import os import os
from copy import deepcopy
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import Callable, Dict, List, Optional, Type from typing import Callable, Dict, List, Optional, Type
@@ -141,7 +142,7 @@ async def apick_next_tool(
query = f"[placeholder for user attached images]\n{query}" query = f"[placeholder for user attached images]\n{query}"
# Construct chat history with user and iteration history with researcher agent for context # Construct chat history with user and iteration history with researcher agent for context
previous_iterations_history = construct_iteration_history(query, previous_iterations, prompts.previous_iteration) previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration, query)
iteration_chat_log = {"chat": conversation_history.get("chat", []) + previous_iterations_history} iteration_chat_log = {"chat": conversation_history.get("chat", []) + previous_iterations_history}
# Plan function execution for the next tool # Plan function execution for the next tool
@@ -212,6 +213,7 @@ async def execute_information_collection(
query: str, query: str,
conversation_id: str, conversation_id: str,
conversation_history: dict, conversation_history: dict,
previous_iterations: List[InformationCollectionIteration],
query_images: List[str], query_images: List[str],
agent: Agent = None, agent: Agent = None,
send_status_func: Optional[Callable] = None, send_status_func: Optional[Callable] = None,
@@ -227,11 +229,20 @@ async def execute_information_collection(
max_webpages_to_read = 1 max_webpages_to_read = 1
current_iteration = 0 current_iteration = 0
MAX_ITERATIONS = int(os.getenv("KHOJ_RESEARCH_ITERATIONS", 5)) MAX_ITERATIONS = int(os.getenv("KHOJ_RESEARCH_ITERATIONS", 5))
previous_iterations: List[InformationCollectionIteration] = []
# Incorporate previous partial research into current research chat history
research_conversation_history = deepcopy(conversation_history)
if current_iteration := len(previous_iterations) > 0:
logger.info(f"Continuing research with the previous {len(previous_iterations)} iteration results.")
previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)
research_conversation_history["chat"] = (
research_conversation_history.get("chat", []) + previous_iterations_history
)
while current_iteration < MAX_ITERATIONS: while current_iteration < MAX_ITERATIONS:
# Check for cancellation at the start of each iteration # Check for cancellation at the start of each iteration
if cancellation_event and cancellation_event.is_set(): if cancellation_event and cancellation_event.is_set():
logger.debug(f"User {user} disconnected client. Research cancelled.") logger.debug(f"Research cancelled. User {user} disconnected client.")
break break
online_results: Dict = dict() online_results: Dict = dict()
@@ -243,7 +254,7 @@ async def execute_information_collection(
async for result in apick_next_tool( async for result in apick_next_tool(
query, query,
conversation_history, research_conversation_history,
user, user,
location, location,
user_name, user_name,

View File

@@ -168,6 +168,7 @@ class ChatRequestBody(BaseModel):
images: Optional[list[str]] = None images: Optional[list[str]] = None
files: Optional[list[FileAttachment]] = [] files: Optional[list[FileAttachment]] = []
create_new: Optional[bool] = False create_new: Optional[bool] = False
interrupt: Optional[bool] = False
class Entry: class Entry: