From df9ab51fd0fc06bbd185e1623bf74c0b60784a27 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Tue, 20 May 2025 15:31:33 -0700 Subject: [PATCH 1/9] Track research results as iteration list instead of iteration summaries --- src/khoj/routers/api_chat.py | 8 ++++---- src/khoj/routers/helpers.py | 9 ++++++--- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index fed6a559..7b78063d 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -883,7 +883,7 @@ async def chat( user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") meta_log = conversation.conversation_log - researched_results = "" + research_results: List[InformationCollectionIteration] = [] online_results: Dict = dict() code_results: Dict = dict() operator_results: Dict[str, str] = {} @@ -963,14 +963,14 @@ async def chat( compiled_references.extend(research_result.context) if research_result.operatorContext: operator_results.update(research_result.operatorContext) - researched_results += research_result.summarizedResult + research_results.append(research_result) else: yield research_result # researched_results = await extract_relevant_info(q, researched_results, agent) if state.verbose > 1: - logger.debug(f"Researched Results: {researched_results}") + logger.debug(f'Researched Results: {"".join(r.summarizedResult for r in research_results)}') used_slash_summarize = conversation_commands == [ConversationCommand.Summarize] file_filters = conversation.file_filters if conversation else [] @@ -1379,13 +1379,13 @@ async def chat( online_results, code_results, operator_results, + research_results, inferred_queries, conversation_commands, user, request.user.client_app, location, user_name, - researched_results, uploaded_images, train_of_thought, attached_file_context, diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 38cf9174..a91f51a0 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -94,6 +94,7 @@ from khoj.processor.conversation.openai.gpt import ( ) from khoj.processor.conversation.utils import ( ChatEvent, + InformationCollectionIteration, ResponseWithThought, clean_json, clean_mermaidjs, @@ -1355,13 +1356,13 @@ async def agenerate_chat_response( online_results: Dict[str, Dict] = {}, code_results: Dict[str, Dict] = {}, operator_results: Dict[str, str] = {}, + research_results: List[InformationCollectionIteration] = [], inferred_queries: List[str] = [], conversation_commands: List[ConversationCommand] = [ConversationCommand.Default], user: KhojUser = None, client_application: ClientApplication = None, location_data: LocationData = None, user_name: Optional[str] = None, - meta_research: str = "", query_images: Optional[List[str]] = None, train_of_thought: List[Any] = [], query_files: str = None, @@ -1405,8 +1406,10 @@ async def agenerate_chat_response( query_to_run = q deepthought = False - if meta_research: - query_to_run = f"{q}\n\n{meta_research}\n" + if research_results: + compiled_research = "".join([r.summarizedResult for r in research_results if r.summarizedResult]) + if compiled_research: + query_to_run = f"{q}\n\n{compiled_research}\n" compiled_references = [] online_results = {} code_results = {} From 98b56316e488506616b0580a9afff900bdf3e125 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Wed, 21 May 2025 22:57:56 -0700 Subject: [PATCH 2/9] Support constructing chat message as a list of dictionaries Research mode recently started passing iteration as list of message content dicts. This change extends to storing it as is in DB. --- src/khoj/processor/conversation/utils.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index ba978429..45d8c6cf 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -121,7 +121,7 @@ def construct_iteration_history( index=idx + 1, ) - previous_iterations_history.append(iteration_data) + previous_iterations_history.append({"type": "text", "text": iteration_data}) return ( [ @@ -341,7 +341,7 @@ Khoj: "{chat_response}" def construct_structured_message( - message: list[str] | str, + message: list[dict] | str, images: list[str], model_type: str, vision_enabled: bool, @@ -355,11 +355,9 @@ def construct_structured_message( ChatModel.ModelType.GOOGLE, ChatModel.ModelType.ANTHROPIC, ]: - message = [message] if isinstance(message, str) else message - - constructed_messages: List[dict[str, Any]] = [ - {"type": "text", "text": message_part} for message_part in message - ] + constructed_messages: List[dict[str, Any]] = ( + [{"type": "text", "text": message}] if isinstance(message, str) else message + ) if not is_none_or_empty(attached_file_context): constructed_messages.append({"type": "text", "text": attached_file_context}) @@ -368,6 +366,7 @@ def construct_structured_message( constructed_messages.append({"type": "image_url", "image_url": {"url": image}}) return constructed_messages + message = message if isinstance(message, str) else "\n\n".join(m["text"] for m in message) if not is_none_or_empty(attached_file_context): return f"{attached_file_context}\n\n{message}" From 02ee4e90a25e9a1bdb9d6292d17d77a08f1ccf38 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Wed, 21 May 2025 23:04:37 -0700 Subject: [PATCH 3/9] Pass doc/web/code/operator context as list[dict] of message content --- src/khoj/processor/conversation/utils.py | 43 +++++++++++++++--------- 1 file changed, 28 insertions(+), 15 deletions(-) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 45d8c6cf..bf7be40d 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -420,7 +420,7 @@ def generate_chatml_messages_with_context( # Extract Chat History for Context chatml_messages: List[ChatMessage] = [] for chat in conversation_log.get("chat", []): - message_context = "" + message_context = [] message_attached_files = "" generated_assets = {} @@ -432,16 +432,6 @@ def generate_chatml_messages_with_context( if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""): chat_message = chat["intent"].get("inferred-queries")[0] - if not is_none_or_empty(chat.get("context")): - references = "\n\n".join( - { - f"# File: {item['file']}\n## {item['compiled']}\n" - for item in chat.get("context") or [] - if isinstance(item, dict) - } - ) - message_context += f"{prompts.notes_conversation.format(references=references)}\n\n" - if chat.get("queryFiles"): raw_query_files = chat.get("queryFiles") query_files_dict = dict() @@ -452,15 +442,38 @@ def generate_chatml_messages_with_context( chatml_messages.append(ChatMessage(content=message_attached_files, role=role)) if not is_none_or_empty(chat.get("onlineContext")): - message_context += f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}" + message_context += [ + { + "type": "text", + "text": f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}", + } + ] if not is_none_or_empty(chat.get("codeContext")): - message_context += f"{prompts.code_executed_context.format(code_results=chat.get('codeContext'))}" + message_context += [ + { + "type": "text", + "text": f"{prompts.code_executed_context.format(code_results=chat.get('codeContext'))}", + } + ] if not is_none_or_empty(chat.get("operatorContext")): - message_context += ( - f"{prompts.operator_execution_context.format(operator_results=chat.get('operatorContext'))}" + message_context += [ + { + "type": "text", + "text": f"{prompts.operator_execution_context.format(operator_results=chat.get('operatorContext'))}", + } + ] + + if not is_none_or_empty(chat.get("context")): + references = "\n\n".join( + { + f"# File: {item['file']}\n## {item['compiled']}\n" + for item in chat.get("context") or [] + if isinstance(item, dict) + } ) + message_context += [{"type": "text", "text": f"{prompts.notes_conversation.format(references=references)}"}] if not is_none_or_empty(message_context): reconstructed_context_message = ChatMessage(content=message_context, role="user") From a83c36fa0590f94570ea4f74277f6092846f2591 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Fri, 23 May 2025 02:36:58 -0700 Subject: [PATCH 4/9] Validate operator, research, context.query fields of ChatMessage - Track operator, research context in ChatMessage - Track query field in (document) context field of ChatMessage This allows validating chat message before inserting into DB --- src/khoj/database/models/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index bd49aa8c..34538df0 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -23,6 +23,7 @@ logger = logging.getLogger(__name__) class Context(PydanticBaseModel): compiled: str file: str + query: str class CodeContextFile(PydanticBaseModel): @@ -105,6 +106,8 @@ class ChatMessage(PydanticBaseModel): context: List[Context] = [] onlineContext: Dict[str, OnlineContext] = {} codeContext: Dict[str, CodeContextData] = {} + researchContext: Optional[List] = None + operatorContext: Optional[Dict[str, str]] = None created: str images: Optional[List[str]] = None queryFiles: Optional[List[Dict]] = None From 3cd6e1a9a60e7c08c6ec4064f656844e75c43f74 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Tue, 20 May 2025 16:00:40 -0700 Subject: [PATCH 5/9] Save and restore research from partial state --- src/khoj/processor/conversation/utils.py | 31 +++++---- src/khoj/routers/api_chat.py | 88 +++++++++++++++++------- src/khoj/routers/helpers.py | 1 + src/khoj/routers/research.py | 19 +++-- 4 files changed, 97 insertions(+), 42 deletions(-) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index bf7be40d..36aa001d 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -110,9 +110,12 @@ class InformationCollectionIteration: def construct_iteration_history( - query: str, previous_iterations: List[InformationCollectionIteration], previous_iteration_prompt: str + previous_iterations: List[InformationCollectionIteration], + previous_iteration_prompt: str, + query: str = None, ) -> list[dict]: - previous_iterations_history = [] + iteration_history: list[dict] = [] + previous_iteration_messages: list[dict] = [] for idx, iteration in enumerate(previous_iterations): iteration_data = previous_iteration_prompt.format( tool=iteration.tool, @@ -121,23 +124,19 @@ def construct_iteration_history( index=idx + 1, ) - previous_iterations_history.append({"type": "text", "text": iteration_data}) + previous_iteration_messages.append({"type": "text", "text": iteration_data}) - return ( - [ - { - "by": "you", - "message": query, - }, + if previous_iteration_messages: + if query: + iteration_history.append({"by": "you", "message": query}) + iteration_history.append( { "by": "khoj", "intent": {"type": "remember", "query": query}, - "message": previous_iterations_history, - }, - ] - if previous_iterations_history - else [] - ) + "message": previous_iteration_messages, + } + ) + return iteration_history def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str: @@ -285,6 +284,7 @@ async def save_to_conversation_log( generated_images: List[str] = [], raw_generated_files: List[FileAttachment] = [], generated_mermaidjs_diagram: str = None, + research_results: Optional[List[InformationCollectionIteration]] = None, train_of_thought: List[Any] = [], tracer: Dict[str, Any] = {}, ): @@ -302,6 +302,7 @@ async def save_to_conversation_log( "onlineContext": online_results, "codeContext": code_results, "operatorContext": operator_results, + "researchContext": [vars(r) for r in research_results] if research_results and not chat_response else None, "automationId": automation_id, "trainOfThought": train_of_thought, "turnId": turn_id, diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 7b78063d..bf72121c 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -687,6 +687,7 @@ async def chat( start_time = time.perf_counter() ttft = None chat_metadata: dict = {} + conversation = None user: KhojUser = request.user.object is_subscribed = has_required_scope(request, ["premium"]) q = unquote(q) @@ -720,6 +721,20 @@ async def chat( for file in raw_query_files: query_files[file.name] = file.content + research_results: List[InformationCollectionIteration] = [] + online_results: Dict = dict() + code_results: Dict = dict() + operator_results: Dict[str, str] = {} + compiled_references: List[Any] = [] + inferred_queries: List[Any] = [] + attached_file_context = gather_raw_query_files(query_files) + + generated_images: List[str] = [] + generated_files: List[FileAttachment] = [] + generated_mermaidjs_diagram: str = None + generated_asset_results: Dict = dict() + program_execution_context: List[str] = [] + # Create a task to monitor for disconnections disconnect_monitor_task = None @@ -727,8 +742,34 @@ async def chat( try: msg = await request.receive() if msg["type"] == "http.disconnect": - logger.debug(f"User {user} disconnected from {common.client} client.") + logger.debug(f"Request cancelled. User {user} disconnected from {common.client} client.") cancellation_event.set() + # ensure partial chat state saved on interrupt + # shield the save against task cancellation + if conversation: + await asyncio.shield( + save_to_conversation_log( + q, + chat_response="", + user=user, + meta_log=meta_log, + compiled_references=compiled_references, + online_results=online_results, + code_results=code_results, + operator_results=operator_results, + research_results=research_results, + inferred_queries=inferred_queries, + client_application=request.user.client_app, + conversation_id=conversation_id, + query_images=uploaded_images, + train_of_thought=train_of_thought, + raw_query_files=raw_query_files, + generated_images=generated_images, + raw_generated_files=generated_asset_results, + generated_mermaidjs_diagram=generated_mermaidjs_diagram, + tracer=tracer, + ) + ) except Exception as e: logger.error(f"Error in disconnect monitor: {e}") @@ -746,7 +787,6 @@ async def chat( nonlocal ttft, train_of_thought event_delimiter = "␃🔚␗" if cancellation_event.is_set(): - logger.debug(f"User {user} disconnected from {common.client} client. Setting cancellation event.") return try: if event_type == ChatEvent.END_LLM_RESPONSE: @@ -770,9 +810,6 @@ async def chat( yield data elif event_type == ChatEvent.REFERENCES or ChatEvent.METADATA or stream: yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False) - except asyncio.CancelledError as e: - if cancellation_event.is_set(): - logger.debug(f"Request cancelled. User {user} disconnected from {common.client} client: {e}.") except Exception as e: if not cancellation_event.is_set(): logger.error( @@ -883,21 +920,25 @@ async def chat( user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") meta_log = conversation.conversation_log - research_results: List[InformationCollectionIteration] = [] - online_results: Dict = dict() - code_results: Dict = dict() - operator_results: Dict[str, str] = {} - generated_asset_results: Dict = dict() - ## Extract Document References - compiled_references: List[Any] = [] - inferred_queries: List[Any] = [] - file_filters = conversation.file_filters if conversation and conversation.file_filters else [] - attached_file_context = gather_raw_query_files(query_files) - - generated_images: List[str] = [] - generated_files: List[FileAttachment] = [] - generated_mermaidjs_diagram: str = None - program_execution_context: List[str] = [] + # If interrupted message in DB + if ( + conversation + and conversation.messages + and conversation.messages[-1].by == "khoj" + and not conversation.messages[-1].message + ): + # Populate context from interrupted message + last_message = conversation.messages[-1] + online_results = {key: val.model_dump() for key, val in last_message.onlineContext.items() or []} + code_results = {key: val.model_dump() for key, val in last_message.codeContext.items() or []} + operator_results = last_message.operatorContext or {} + compiled_references = [ref.model_dump() for ref in last_message.context or []] + research_results = [ + InformationCollectionIteration(**iter_dict) for iter_dict in last_message.researchContext or [] + ] + # Drop the interrupted message from conversation history + meta_log["chat"].pop() + logger.info(f"Loaded interrupted partial context from conversation {conversation_id}.") if conversation_commands == [ConversationCommand.Default]: try: @@ -936,6 +977,7 @@ async def chat( return defiltered_query = defilter_query(q) + file_filters = conversation.file_filters if conversation and conversation.file_filters else [] if conversation_commands == [ConversationCommand.Research]: async for research_result in execute_information_collection( @@ -943,12 +985,13 @@ async def chat( query=defiltered_query, conversation_id=conversation_id, conversation_history=meta_log, + previous_iterations=research_results, query_images=uploaded_images, agent=agent, send_status_func=partial(send_event, ChatEvent.STATUS), user_name=user_name, location=location, - file_filters=conversation.file_filters if conversation else [], + file_filters=file_filters, query_files=attached_file_context, tracer=tracer, cancellation_event=cancellation_event, @@ -973,7 +1016,6 @@ async def chat( logger.debug(f'Researched Results: {"".join(r.summarizedResult for r in research_results)}') used_slash_summarize = conversation_commands == [ConversationCommand.Summarize] - file_filters = conversation.file_filters if conversation else [] # Skip trying to summarize if if ( # summarization intent was inferred @@ -1362,7 +1404,7 @@ async def chat( # Check if the user has disconnected if cancellation_event.is_set(): - logger.debug(f"User {user} disconnected from {common.client} client. Stopping LLM response.") + logger.debug(f"Stopping LLM response to user {user} on {common.client} client.") # Cancel the disconnect monitor task if it is still running await cancel_disconnect_monitor() return diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index a91f51a0..c1ddb82d 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1392,6 +1392,7 @@ async def agenerate_chat_response( online_results=online_results, code_results=code_results, operator_results=operator_results, + research_results=research_results, inferred_queries=inferred_queries, client_application=client_application, conversation_id=str(conversation.id), diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 2f8157b4..93efee1f 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -1,6 +1,7 @@ import asyncio import logging import os +from copy import deepcopy from datetime import datetime from enum import Enum from typing import Callable, Dict, List, Optional, Type @@ -141,7 +142,7 @@ async def apick_next_tool( query = f"[placeholder for user attached images]\n{query}" # Construct chat history with user and iteration history with researcher agent for context - previous_iterations_history = construct_iteration_history(query, previous_iterations, prompts.previous_iteration) + previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration, query) iteration_chat_log = {"chat": conversation_history.get("chat", []) + previous_iterations_history} # Plan function execution for the next tool @@ -212,6 +213,7 @@ async def execute_information_collection( query: str, conversation_id: str, conversation_history: dict, + previous_iterations: List[InformationCollectionIteration], query_images: List[str], agent: Agent = None, send_status_func: Optional[Callable] = None, @@ -227,11 +229,20 @@ async def execute_information_collection( max_webpages_to_read = 1 current_iteration = 0 MAX_ITERATIONS = int(os.getenv("KHOJ_RESEARCH_ITERATIONS", 5)) - previous_iterations: List[InformationCollectionIteration] = [] + + # Incorporate previous partial research into current research chat history + research_conversation_history = deepcopy(conversation_history) + if current_iteration := len(previous_iterations) > 0: + logger.info(f"Continuing research with the previous {len(previous_iterations)} iteration results.") + previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration) + research_conversation_history["chat"] = ( + research_conversation_history.get("chat", []) + previous_iterations_history + ) + while current_iteration < MAX_ITERATIONS: # Check for cancellation at the start of each iteration if cancellation_event and cancellation_event.is_set(): - logger.debug(f"User {user} disconnected client. Research cancelled.") + logger.debug(f"Research cancelled. User {user} disconnected client.") break online_results: Dict = dict() @@ -243,7 +254,7 @@ async def execute_information_collection( async for result in apick_next_tool( query, - conversation_history, + research_conversation_history, user, location, user_name, From 2b7dd7401bde772af524841f2c8b9043c73473d4 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Thu, 22 May 2025 16:26:07 -0700 Subject: [PATCH 6/9] Continue interrupt queries only after previous query written to DB --- src/khoj/routers/api_chat.py | 29 +++++++++++++++++++++++++++++ src/khoj/utils/rawconfig.py | 1 + 2 files changed, 30 insertions(+) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index bf72121c..15023388 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -682,6 +682,7 @@ async def chat( timezone = body.timezone raw_images = body.images raw_query_files = body.files + interrupt_flag = body.interrupt async def event_generator(q: str, images: list[str]): start_time = time.perf_counter() @@ -920,6 +921,34 @@ async def chat( user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") meta_log = conversation.conversation_log + # If interrupt flag is set, wait for the previous turn to be saved before proceeding + if interrupt_flag: + max_wait_time = 20.0 # seconds + wait_interval = 0.3 # seconds + wait_start = wait_current = time.time() + while wait_current - wait_start < max_wait_time: + # Refresh conversation to check if interrupted message saved to DB + conversation = await ConversationAdapters.aget_conversation_by_user( + user, + client_application=request.user.client_app, + conversation_id=conversation_id, + ) + if ( + conversation + and conversation.messages + and conversation.messages[-1].by == "khoj" + and not conversation.messages[-1].message + ): + logger.info(f"Detected interrupted message save to conversation {conversation_id}.") + break + await asyncio.sleep(wait_interval) + wait_current = time.time() + + if wait_current - wait_start >= max_wait_time: + logger.warning( + f"Timeout waiting to load interrupted context from conversation {conversation_id}. Proceed without previous context." + ) + # If interrupted message in DB if ( conversation diff --git a/src/khoj/utils/rawconfig.py b/src/khoj/utils/rawconfig.py index fc637ea5..e0248a66 100644 --- a/src/khoj/utils/rawconfig.py +++ b/src/khoj/utils/rawconfig.py @@ -168,6 +168,7 @@ class ChatRequestBody(BaseModel): images: Optional[list[str]] = None files: Optional[list[FileAttachment]] = [] create_new: Optional[bool] = False + interrupt: Optional[bool] = False class Entry: From 6cb512d9cfcaaad920da98b18cf9b48905bdbf70 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Wed, 21 May 2025 13:28:12 -0700 Subject: [PATCH 7/9] Support natural interrupt and send query behavior from web app - Just send your new query. If a query was running previously it'd be interrupted and new query would start processing. This improves on the previous 2 click interrupt and send ux. - Utilizes partial research for interrupted query, so you can now redirect khoj's research direction. This is useful if you need to share more details, change khoj's research direction in anyway or complete research. Khoj's train of thought can be helpful for this. --- src/interface/web/app/chat/page.tsx | 16 +++++++++++++--- .../components/chatInputArea/chatInputArea.tsx | 16 +++++++++++----- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/src/interface/web/app/chat/page.tsx b/src/interface/web/app/chat/page.tsx index b672091d..ecd10a4f 100644 --- a/src/interface/web/app/chat/page.tsx +++ b/src/interface/web/app/chat/page.tsx @@ -49,6 +49,7 @@ interface ChatBodyDataProps { isChatSideBarOpen: boolean; setIsChatSideBarOpen: (open: boolean) => void; isActive?: boolean; + isParentProcessing?: boolean; } function ChatBodyData(props: ChatBodyDataProps) { @@ -166,7 +167,7 @@ function ChatBodyData(props: ChatBodyDataProps) { isLoggedIn={props.isLoggedIn} sendMessage={(message) => setMessage(message)} sendImage={(image) => setImages((prevImages) => [...prevImages, image])} - sendDisabled={processingMessage} + sendDisabled={props.isParentProcessing || false} chatOptionsData={props.chatOptionsData} conversationId={conversationId} isMobileWidth={props.isMobileWidth} @@ -203,6 +204,7 @@ export default function Chat() { const [abortMessageStreamController, setAbortMessageStreamController] = useState(null); const [triggeredAbort, setTriggeredAbort] = useState(false); + const [shouldSendWithInterrupt, setShouldSendWithInterrupt] = useState(false); const { locationData, locationDataError, locationDataLoading } = useIPLocationData() || { locationData: { @@ -239,6 +241,7 @@ export default function Chat() { if (triggeredAbort) { abortMessageStreamController?.abort(); handleAbortedMessage(); + setShouldSendWithInterrupt(true); setTriggeredAbort(false); } }, [triggeredAbort]); @@ -335,18 +338,21 @@ export default function Chat() { currentMessage.completed = true; setMessages([...messages]); - setQueryToProcess(""); setProcessQuerySignal(false); } async function chat() { localStorage.removeItem("message"); - if (!queryToProcess || !conversationId) return; + if (!queryToProcess || !conversationId) { + setProcessQuerySignal(false); + return; + } const chatAPI = "/api/chat?client=web"; const chatAPIBody = { q: queryToProcess, conversation_id: conversationId, stream: true, + interrupt: shouldSendWithInterrupt, ...(locationData && { city: locationData.city, region: locationData.region, @@ -358,6 +364,9 @@ export default function Chat() { ...(uploadedFiles && { files: uploadedFiles }), }; + // Reset the flag after using it + setShouldSendWithInterrupt(false); + const response = await fetch(chatAPI, { method: "POST", headers: { @@ -481,6 +490,7 @@ export default function Chat() { isChatSideBarOpen={isChatSideBarOpen} setIsChatSideBarOpen={setIsChatSideBarOpen} isActive={authenticatedData?.is_active} + isParentProcessing={processQuerySignal} /> diff --git a/src/interface/web/app/components/chatInputArea/chatInputArea.tsx b/src/interface/web/app/components/chatInputArea/chatInputArea.tsx index 5448d8ce..9c9a6214 100644 --- a/src/interface/web/app/components/chatInputArea/chatInputArea.tsx +++ b/src/interface/web/app/components/chatInputArea/chatInputArea.tsx @@ -195,6 +195,11 @@ export const ChatInputArea = forwardRef((pr return; } + // If currently processing, trigger abort first + if (props.sendDisabled) { + props.setTriggeredAbort(true); + } + let messageToSend = message.trim(); // Check if message starts with an explicit slash command const startsWithSlashCommand = @@ -657,7 +662,7 @@ export const ChatInputArea = forwardRef((pr