Handle passing interrupt messages from api to chat actors on server

This commit is contained in:
Debanjum
2025-07-12 11:19:41 -07:00
parent 38dd85c91f
commit 9f0eff6541
6 changed files with 62 additions and 34 deletions

View File

@@ -384,6 +384,7 @@ class ChatEvent(Enum):
METADATA = "metadata"
USAGE = "usage"
END_RESPONSE = "end_response"
INTERRUPT = "interrupt"
def message_to_log(

View File

@@ -7,6 +7,7 @@ from typing import Callable, List, Optional
from khoj.database.adapters import AgentAdapters, ConversationAdapters
from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser
from khoj.processor.conversation.utils import (
AgentMessage,
OperatorRun,
construct_chat_history_for_operator,
)
@@ -22,7 +23,7 @@ from khoj.processor.operator.operator_environment_base import (
)
from khoj.processor.operator.operator_environment_browser import BrowserEnvironment
from khoj.processor.operator.operator_environment_computer import ComputerEnvironment
from khoj.routers.helpers import ChatEvent
from khoj.routers.helpers import ChatEvent, get_message_from_queue
from khoj.utils.helpers import timer
from khoj.utils.rawconfig import LocationData
@@ -42,6 +43,7 @@ async def operate_environment(
agent: Agent = None,
query_files: str = None, # TODO: Handle query files
cancellation_event: Optional[asyncio.Event] = None,
interrupt_queue: Optional[asyncio.Queue] = None,
tracer: dict = {},
):
response, user_input_message = None, None
@@ -140,6 +142,14 @@ async def operate_environment(
logger.debug(f"{environment_type.value} operator cancelled by client disconnect")
break
# Add interrupt query to current operator run
if interrupt_query := get_message_from_queue(interrupt_queue):
# Add the interrupt query as a new user message to the operator agent's message history
logger.info(f"Continuing operator run with the new instruction: {interrupt_query}")
operator_agent.messages.append(AgentMessage(role="user", content=interrupt_query))
async for result in send_status_func(f"**Incorporate New Instruction**: {interrupt_query}"):
yield result
iterations += 1
# 1. Get current environment state

View File

@@ -672,6 +672,7 @@ async def event_generator(
common: CommonQueryParams,
headers: Headers,
request_obj: Request | WebSocket,
interrupt_queue: asyncio.Queue = None,
):
# Access the parameters from the body
q = body.q
@@ -688,7 +689,6 @@ async def event_generator(
timezone = body.timezone
raw_images = body.images
raw_query_files = body.files
interrupt_flag = body.interrupt
start_time = time.perf_counter()
ttft = None
@@ -955,34 +955,6 @@ async def event_generator(
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
chat_history = conversation.messages
# If interrupt flag is set, wait for the previous turn to be saved before proceeding
if interrupt_flag:
max_wait_time = 20.0 # seconds
wait_interval = 0.3 # seconds
wait_start = wait_current = time.time()
while wait_current - wait_start < max_wait_time:
# Refresh conversation to check if interrupted message saved to DB
conversation = await ConversationAdapters.aget_conversation_by_user(
user,
client_application=user_scope.client_app,
conversation_id=conversation_id,
)
if (
conversation
and conversation.messages
and conversation.messages[-1].by == "khoj"
and not conversation.messages[-1].message
):
logger.info(f"Detected interrupted message save to conversation {conversation_id}.")
break
await asyncio.sleep(wait_interval)
wait_current = time.time()
if wait_current - wait_start >= max_wait_time:
logger.warning(
f"Timeout waiting to load interrupted context from conversation {conversation_id}. Proceed without previous context."
)
# If interrupted message in DB
if (
conversation
@@ -1061,6 +1033,7 @@ async def event_generator(
query_files=attached_file_context,
tracer=tracer,
cancellation_event=cancellation_event,
interrupt_queue=interrupt_queue,
):
if isinstance(research_result, ResearchIteration):
if research_result.summarizedResult:
@@ -1491,13 +1464,25 @@ async def chat_ws(
)
image_rate_limiter = ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)
# Shared interrupt queue for communicating interrupts to ongoing research
interrupt_queue: asyncio.Queue = asyncio.Queue()
current_task = None
try:
while True:
data = await websocket.receive_json()
# Handle regular chat messages
# Check if this is an interrupt message
if data.get("type") == "interrupt":
if current_task and not current_task.done():
# Send interrupt signal to the ongoing task
await interrupt_queue.put(data.get("query", ""))
logger.info(f"Interrupt signal sent to ongoing task for user {websocket.scope['user'].object.id}")
await websocket.send_text(json.dumps({"type": "interrupt_acknowledged"}))
else:
logger.info(f"No ongoing task to interrupt for user {websocket.scope['user'].object.id}")
continue
# Handle regular chat messages - ensure data has required fields
if "q" not in data:
await websocket.send_text(json.dumps({"error": "Missing required field 'q' in chat message"}))
@@ -1523,7 +1508,7 @@ async def chat_ws(
pass
# Create a new task for processing the chat request
current_task = asyncio.create_task(process_chat_request(websocket, body, common))
current_task = asyncio.create_task(process_chat_request(websocket, body, common, interrupt_queue))
except WebSocketDisconnect:
logger.info(f"WebSocket disconnected for user {websocket.scope['user'].object.id}")
@@ -1540,6 +1525,7 @@ async def process_chat_request(
websocket: WebSocket,
body: ChatRequestBody,
common: CommonQueryParams,
interrupt_queue: asyncio.Queue,
):
"""Process a single chat request with interrupt support"""
try:
@@ -1550,6 +1536,7 @@ async def process_chat_request(
common,
websocket.headers,
websocket,
interrupt_queue,
)
async for event in response_iterator:
await websocket.send_text(event)

View File

@@ -2600,6 +2600,17 @@ async def read_chat_stream(response_iterator: AsyncGenerator[str, None]) -> Dict
}
def get_message_from_queue(queue: Optional[asyncio.Queue]) -> Optional[str]:
    """Return the next message waiting in the queue without blocking.

    Args:
        queue: Queue to poll. May be None, in which case there is nothing
            to read and None is returned.

    Returns:
        The oldest item in the queue, or None when no queue was supplied
        or the queue is currently empty.
    """
    # Callers pass None when no queue was set up, so treat a missing queue
    # the same as an empty one.
    if queue is None:
        return None

    try:
        # Non-blocking read: raises QueueEmpty instead of awaiting a message.
        return queue.get_nowait()
    except asyncio.QueueEmpty:
        return None
def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False):
user_picture = request.session.get("user", {}).get("picture")
is_active = has_required_scope(request, ["premium"])

View File

@@ -15,6 +15,7 @@ from khoj.processor.conversation.utils import (
ResearchIteration,
ToolCall,
construct_iteration_history,
construct_structured_message,
construct_tool_chat_history,
load_complex_json,
)
@@ -24,6 +25,7 @@ from khoj.processor.tools.run_code import run_code
from khoj.routers.helpers import (
ChatEvent,
generate_summary_from_files,
get_message_from_queue,
grep_files,
list_files,
search_documents,
@@ -74,7 +76,7 @@ async def apick_next_tool(
):
previous_iteration = previous_iterations[-1]
yield ResearchIteration(
query=query,
query=ToolCall(name=previous_iteration.query.name, args={"query": query}, id=previous_iteration.query.id), # type: ignore
context=previous_iteration.context,
onlineContext=previous_iteration.onlineContext,
codeContext=previous_iteration.codeContext,
@@ -221,6 +223,7 @@ async def research(
tracer: dict = {},
query_files: str = None,
cancellation_event: Optional[asyncio.Event] = None,
interrupt_queue: Optional[asyncio.Queue] = None,
):
max_document_searches = 7
max_online_searches = 3
@@ -241,6 +244,22 @@ async def research(
logger.debug(f"Research cancelled. User {user} disconnected client.")
break
# Update the query for the current research iteration
if interrupt_query := get_message_from_queue(interrupt_queue):
# Add the interrupt query as a new user message to the research conversation history
logger.info(
f"Continuing research with the previous {len(previous_iterations)} iterations and new instruction: {interrupt_query}"
)
previous_iterations_history = construct_iteration_history(
previous_iterations, query, query_images, query_files
)
research_conversation_history += previous_iterations_history
query = interrupt_query
previous_iterations = []
async for result in send_status_func(f"**Incorporate New Instruction**: {interrupt_query}"):
yield result
online_results: Dict = dict()
code_results: Dict = dict()
document_results: List[Dict[str, str]] = []
@@ -428,6 +447,7 @@ async def research(
agent=agent,
query_files=query_files,
cancellation_event=cancellation_event,
interrupt_queue=interrupt_queue,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:

View File

@@ -168,7 +168,6 @@ class ChatRequestBody(BaseModel):
images: Optional[list[str]] = None
files: Optional[list[FileAttachment]] = []
create_new: Optional[bool] = False
interrupt: Optional[bool] = False
class Entry: