diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 011e8045..b4dd5e9c 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -384,6 +384,7 @@ class ChatEvent(Enum): METADATA = "metadata" USAGE = "usage" END_RESPONSE = "end_response" + INTERRUPT = "interrupt" def message_to_log( diff --git a/src/khoj/processor/operator/__init__.py b/src/khoj/processor/operator/__init__.py index 9b4ad80f..c07dec90 100644 --- a/src/khoj/processor/operator/__init__.py +++ b/src/khoj/processor/operator/__init__.py @@ -7,6 +7,7 @@ from typing import Callable, List, Optional from khoj.database.adapters import AgentAdapters, ConversationAdapters from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser from khoj.processor.conversation.utils import ( + AgentMessage, OperatorRun, construct_chat_history_for_operator, ) @@ -22,7 +23,7 @@ from khoj.processor.operator.operator_environment_base import ( ) from khoj.processor.operator.operator_environment_browser import BrowserEnvironment from khoj.processor.operator.operator_environment_computer import ComputerEnvironment -from khoj.routers.helpers import ChatEvent +from khoj.routers.helpers import ChatEvent, get_message_from_queue from khoj.utils.helpers import timer from khoj.utils.rawconfig import LocationData @@ -42,6 +43,7 @@ async def operate_environment( agent: Agent = None, query_files: str = None, # TODO: Handle query files cancellation_event: Optional[asyncio.Event] = None, + interrupt_queue: Optional[asyncio.Queue] = None, tracer: dict = {}, ): response, user_input_message = None, None @@ -140,6 +142,14 @@ async def operate_environment( logger.debug(f"{environment_type.value} operator cancelled by client disconnect") break + # Add interrupt query to current operator run + if interrupt_query := get_message_from_queue(interrupt_queue): + # Add the interrupt query as a new user message to the research conversation history + logger.info(f"Continuing operator run with the new instruction: {interrupt_query}") + operator_agent.messages.append(AgentMessage(role="user", content=interrupt_query)) + async for result in send_status_func(f"**Incorporate New Instruction**: {interrupt_query}"): + yield result + iterations += 1 # 1. Get current environment state diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 04b50bfd..e602ab2c 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -672,6 +672,7 @@ async def event_generator( common: CommonQueryParams, headers: Headers, request_obj: Request | WebSocket, + interrupt_queue: asyncio.Queue = None, ): # Access the parameters from the body q = body.q @@ -688,7 +689,6 @@ async def event_generator( timezone = body.timezone raw_images = body.images raw_query_files = body.files - interrupt_flag = body.interrupt start_time = time.perf_counter() ttft = None @@ -955,34 +955,6 @@ async def event_generator( user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") chat_history = conversation.messages - # If interrupt flag is set, wait for the previous turn to be saved before proceeding - if interrupt_flag: - max_wait_time = 20.0 # seconds - wait_interval = 0.3 # seconds - wait_start = wait_current = time.time() - while wait_current - wait_start < max_wait_time: - # Refresh conversation to check if interrupted message saved to DB - conversation = await ConversationAdapters.aget_conversation_by_user( - user, - client_application=user_scope.client_app, - conversation_id=conversation_id, - ) - if ( - conversation - and conversation.messages - and conversation.messages[-1].by == "khoj" - and not conversation.messages[-1].message - ): - logger.info(f"Detected interrupted message save to conversation {conversation_id}.") - break - await asyncio.sleep(wait_interval) - wait_current = time.time() - - if wait_current - wait_start >= max_wait_time: - logger.warning( - f"Timeout waiting to load interrupted context from conversation {conversation_id}. Proceed without previous context." - ) - # If interrupted message in DB if ( conversation @@ -1061,6 +1033,7 @@ async def event_generator( query_files=attached_file_context, tracer=tracer, cancellation_event=cancellation_event, + interrupt_queue=interrupt_queue, ): if isinstance(research_result, ResearchIteration): if research_result.summarizedResult: @@ -1491,13 +1464,25 @@ async def chat_ws( ) image_rate_limiter = ApiImageRateLimiter(max_images=10, max_combined_size_mb=20) + # Shared interrupt queue for communicating interrupts to ongoing research + interrupt_queue: asyncio.Queue = asyncio.Queue() current_task = None try: while True: data = await websocket.receive_json() - # Handle regular chat messages + # Check if this is an interrupt message + if data.get("type") == "interrupt": + if current_task and not current_task.done(): + # Send interrupt signal to the ongoing task + await interrupt_queue.put(data.get("query", "")) + logger.info(f"Interrupt signal sent to ongoing task for user {websocket.scope['user'].object.id}") + await websocket.send_text(json.dumps({"type": "interrupt_acknowledged"})) + else: + logger.info(f"No ongoing task to interrupt for user {websocket.scope['user'].object.id}") + continue + # Handle regular chat messages - ensure data has required fields if "q" not in data: await websocket.send_text(json.dumps({"error": "Missing required field 'q' in chat message"})) @@ -1523,7 +1508,7 @@ async def chat_ws( pass # Create a new task for processing the chat request - current_task = asyncio.create_task(process_chat_request(websocket, body, common)) + current_task = asyncio.create_task(process_chat_request(websocket, body, common, interrupt_queue)) except WebSocketDisconnect: logger.info(f"WebSocket disconnected for user {websocket.scope['user'].object.id}") @@ -1540,6 +1525,7 @@ async def process_chat_request( websocket: WebSocket, body: ChatRequestBody, common: CommonQueryParams, + interrupt_queue: asyncio.Queue, ): """Process a single chat request with interrupt support""" try: @@ -1550,6 +1536,7 @@ async def process_chat_request( common, websocket.headers, websocket, + interrupt_queue, ) async for event in response_iterator: await websocket.send_text(event) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 9bfdb155..41189740 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -2600,6 +2600,17 @@ async def read_chat_stream(response_iterator: AsyncGenerator[str, None]) -> Dict } +def get_message_from_queue(queue: asyncio.Queue) -> Optional[str]: + """Get any message in queue if available.""" + if not queue: + return None + try: + # Non-blocking check for message in the queue + return queue.get_nowait() + except asyncio.QueueEmpty: + return None + + def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False): user_picture = request.session.get("user", {}).get("picture") is_active = has_required_scope(request, ["premium"]) diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 5e476d4b..050fd3f8 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -15,6 +15,7 @@ from khoj.processor.conversation.utils import ( ResearchIteration, ToolCall, construct_iteration_history, + construct_structured_message, construct_tool_chat_history, load_complex_json, ) @@ -24,6 +25,7 @@ from khoj.processor.tools.run_code import run_code from khoj.routers.helpers import ( ChatEvent, generate_summary_from_files, + get_message_from_queue, grep_files, list_files, search_documents, @@ -74,7 +76,7 @@ async def apick_next_tool( ): previous_iteration = previous_iterations[-1] yield ResearchIteration( - query=query, + query=ToolCall(name=previous_iteration.query.name, args={"query": query}, id=previous_iteration.query.id), # type: ignore context=previous_iteration.context, onlineContext=previous_iteration.onlineContext, codeContext=previous_iteration.codeContext, @@ -221,6 +223,7 @@ async def research( tracer: dict = {}, query_files: str = None, cancellation_event: Optional[asyncio.Event] = None, + interrupt_queue: Optional[asyncio.Queue] = None, ): max_document_searches = 7 max_online_searches = 3 @@ -241,6 +244,22 @@ async def research( logger.debug(f"Research cancelled. User {user} disconnected client.") break + # Update the query for the current research iteration + if interrupt_query := get_message_from_queue(interrupt_queue): + # Add the interrupt query as a new user message to the research conversation history + logger.info( + f"Continuing research with the previous {len(previous_iterations)} iterations and new instruction: {interrupt_query}" + ) + previous_iterations_history = construct_iteration_history( + previous_iterations, query, query_images, query_files + ) + research_conversation_history += previous_iterations_history + query = interrupt_query + previous_iterations = [] + + async for result in send_status_func(f"**Incorporate New Instruction**: {interrupt_query}"): + yield result + online_results: Dict = dict() code_results: Dict = dict() document_results: List[Dict[str, str]] = [] @@ -428,6 +447,7 @@ async def research( agent=agent, query_files=query_files, cancellation_event=cancellation_event, + interrupt_queue=interrupt_queue, tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: diff --git a/src/khoj/utils/rawconfig.py b/src/khoj/utils/rawconfig.py index e3662db5..df7f3334 100644 --- a/src/khoj/utils/rawconfig.py +++ b/src/khoj/utils/rawconfig.py @@ -168,7 +168,6 @@ class ChatRequestBody(BaseModel): images: Optional[list[str]] = None files: Optional[list[FileAttachment]] = [] create_new: Optional[bool] = False - interrupt: Optional[bool] = False class Entry: