Handle passing interrupt messages from api to chat actors on server

This commit is contained in:
Debanjum
2025-07-12 11:19:41 -07:00
parent 38dd85c91f
commit 9f0eff6541
6 changed files with 62 additions and 34 deletions

View File

@@ -384,6 +384,7 @@ class ChatEvent(Enum):
METADATA = "metadata" METADATA = "metadata"
USAGE = "usage" USAGE = "usage"
END_RESPONSE = "end_response" END_RESPONSE = "end_response"
INTERRUPT = "interrupt"
def message_to_log( def message_to_log(

View File

@@ -7,6 +7,7 @@ from typing import Callable, List, Optional
from khoj.database.adapters import AgentAdapters, ConversationAdapters from khoj.database.adapters import AgentAdapters, ConversationAdapters
from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser
from khoj.processor.conversation.utils import ( from khoj.processor.conversation.utils import (
AgentMessage,
OperatorRun, OperatorRun,
construct_chat_history_for_operator, construct_chat_history_for_operator,
) )
@@ -22,7 +23,7 @@ from khoj.processor.operator.operator_environment_base import (
) )
from khoj.processor.operator.operator_environment_browser import BrowserEnvironment from khoj.processor.operator.operator_environment_browser import BrowserEnvironment
from khoj.processor.operator.operator_environment_computer import ComputerEnvironment from khoj.processor.operator.operator_environment_computer import ComputerEnvironment
from khoj.routers.helpers import ChatEvent from khoj.routers.helpers import ChatEvent, get_message_from_queue
from khoj.utils.helpers import timer from khoj.utils.helpers import timer
from khoj.utils.rawconfig import LocationData from khoj.utils.rawconfig import LocationData
@@ -42,6 +43,7 @@ async def operate_environment(
agent: Agent = None, agent: Agent = None,
query_files: str = None, # TODO: Handle query files query_files: str = None, # TODO: Handle query files
cancellation_event: Optional[asyncio.Event] = None, cancellation_event: Optional[asyncio.Event] = None,
interrupt_queue: Optional[asyncio.Queue] = None,
tracer: dict = {}, tracer: dict = {},
): ):
response, user_input_message = None, None response, user_input_message = None, None
@@ -140,6 +142,14 @@ async def operate_environment(
logger.debug(f"{environment_type.value} operator cancelled by client disconnect") logger.debug(f"{environment_type.value} operator cancelled by client disconnect")
break break
# Add interrupt query to current operator run
if interrupt_query := get_message_from_queue(interrupt_queue):
# Add the interrupt query as a new user message to the operator agent's message history
logger.info(f"Continuing operator run with the new instruction: {interrupt_query}")
operator_agent.messages.append(AgentMessage(role="user", content=interrupt_query))
async for result in send_status_func(f"**Incorporate New Instruction**: {interrupt_query}"):
yield result
iterations += 1 iterations += 1
# 1. Get current environment state # 1. Get current environment state

View File

@@ -672,6 +672,7 @@ async def event_generator(
common: CommonQueryParams, common: CommonQueryParams,
headers: Headers, headers: Headers,
request_obj: Request | WebSocket, request_obj: Request | WebSocket,
interrupt_queue: asyncio.Queue = None,
): ):
# Access the parameters from the body # Access the parameters from the body
q = body.q q = body.q
@@ -688,7 +689,6 @@ async def event_generator(
timezone = body.timezone timezone = body.timezone
raw_images = body.images raw_images = body.images
raw_query_files = body.files raw_query_files = body.files
interrupt_flag = body.interrupt
start_time = time.perf_counter() start_time = time.perf_counter()
ttft = None ttft = None
@@ -955,34 +955,6 @@ async def event_generator(
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
chat_history = conversation.messages chat_history = conversation.messages
# If interrupt flag is set, wait for the previous turn to be saved before proceeding
if interrupt_flag:
max_wait_time = 20.0 # seconds
wait_interval = 0.3 # seconds
wait_start = wait_current = time.time()
while wait_current - wait_start < max_wait_time:
# Refresh conversation to check if interrupted message saved to DB
conversation = await ConversationAdapters.aget_conversation_by_user(
user,
client_application=user_scope.client_app,
conversation_id=conversation_id,
)
if (
conversation
and conversation.messages
and conversation.messages[-1].by == "khoj"
and not conversation.messages[-1].message
):
logger.info(f"Detected interrupted message save to conversation {conversation_id}.")
break
await asyncio.sleep(wait_interval)
wait_current = time.time()
if wait_current - wait_start >= max_wait_time:
logger.warning(
f"Timeout waiting to load interrupted context from conversation {conversation_id}. Proceed without previous context."
)
# If interrupted message in DB # If interrupted message in DB
if ( if (
conversation conversation
@@ -1061,6 +1033,7 @@ async def event_generator(
query_files=attached_file_context, query_files=attached_file_context,
tracer=tracer, tracer=tracer,
cancellation_event=cancellation_event, cancellation_event=cancellation_event,
interrupt_queue=interrupt_queue,
): ):
if isinstance(research_result, ResearchIteration): if isinstance(research_result, ResearchIteration):
if research_result.summarizedResult: if research_result.summarizedResult:
@@ -1491,13 +1464,25 @@ async def chat_ws(
) )
image_rate_limiter = ApiImageRateLimiter(max_images=10, max_combined_size_mb=20) image_rate_limiter = ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)
# Shared interrupt queue for communicating interrupts to ongoing research
interrupt_queue: asyncio.Queue = asyncio.Queue()
current_task = None current_task = None
try: try:
while True: while True:
data = await websocket.receive_json() data = await websocket.receive_json()
# Handle regular chat messages # Check if this is an interrupt message
if data.get("type") == "interrupt":
if current_task and not current_task.done():
# Send interrupt signal to the ongoing task
await interrupt_queue.put(data.get("query", ""))
logger.info(f"Interrupt signal sent to ongoing task for user {websocket.scope['user'].object.id}")
await websocket.send_text(json.dumps({"type": "interrupt_acknowledged"}))
else:
logger.info(f"No ongoing task to interrupt for user {websocket.scope['user'].object.id}")
continue
# Handle regular chat messages - ensure data has required fields # Handle regular chat messages - ensure data has required fields
if "q" not in data: if "q" not in data:
await websocket.send_text(json.dumps({"error": "Missing required field 'q' in chat message"})) await websocket.send_text(json.dumps({"error": "Missing required field 'q' in chat message"}))
@@ -1523,7 +1508,7 @@ async def chat_ws(
pass pass
# Create a new task for processing the chat request # Create a new task for processing the chat request
current_task = asyncio.create_task(process_chat_request(websocket, body, common)) current_task = asyncio.create_task(process_chat_request(websocket, body, common, interrupt_queue))
except WebSocketDisconnect: except WebSocketDisconnect:
logger.info(f"WebSocket disconnected for user {websocket.scope['user'].object.id}") logger.info(f"WebSocket disconnected for user {websocket.scope['user'].object.id}")
@@ -1540,6 +1525,7 @@ async def process_chat_request(
websocket: WebSocket, websocket: WebSocket,
body: ChatRequestBody, body: ChatRequestBody,
common: CommonQueryParams, common: CommonQueryParams,
interrupt_queue: asyncio.Queue,
): ):
"""Process a single chat request with interrupt support""" """Process a single chat request with interrupt support"""
try: try:
@@ -1550,6 +1536,7 @@ async def process_chat_request(
common, common,
websocket.headers, websocket.headers,
websocket, websocket,
interrupt_queue,
) )
async for event in response_iterator: async for event in response_iterator:
await websocket.send_text(event) await websocket.send_text(event)

View File

@@ -2600,6 +2600,17 @@ async def read_chat_stream(response_iterator: AsyncGenerator[str, None]) -> Dict
} }
def get_message_from_queue(queue: Optional[asyncio.Queue]) -> Optional[str]:
    """Return the next pending message from the queue without waiting.

    Args:
        queue: The asyncio queue to poll. May be None when the caller
            was not given a queue.

    Returns:
        The oldest message currently in the queue, or None when no queue
        was supplied or the queue is empty.
    """
    # Explicit None check: absence of a queue is the only "no queue" case
    if queue is None:
        return None

    try:
        # Non-blocking check for message in the queue
        return queue.get_nowait()
    except asyncio.QueueEmpty:
        # Nothing pending right now; the caller carries on unchanged
        return None
def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False): def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False):
user_picture = request.session.get("user", {}).get("picture") user_picture = request.session.get("user", {}).get("picture")
is_active = has_required_scope(request, ["premium"]) is_active = has_required_scope(request, ["premium"])

View File

@@ -15,6 +15,7 @@ from khoj.processor.conversation.utils import (
ResearchIteration, ResearchIteration,
ToolCall, ToolCall,
construct_iteration_history, construct_iteration_history,
construct_structured_message,
construct_tool_chat_history, construct_tool_chat_history,
load_complex_json, load_complex_json,
) )
@@ -24,6 +25,7 @@ from khoj.processor.tools.run_code import run_code
from khoj.routers.helpers import ( from khoj.routers.helpers import (
ChatEvent, ChatEvent,
generate_summary_from_files, generate_summary_from_files,
get_message_from_queue,
grep_files, grep_files,
list_files, list_files,
search_documents, search_documents,
@@ -74,7 +76,7 @@ async def apick_next_tool(
): ):
previous_iteration = previous_iterations[-1] previous_iteration = previous_iterations[-1]
yield ResearchIteration( yield ResearchIteration(
query=query, query=ToolCall(name=previous_iteration.query.name, args={"query": query}, id=previous_iteration.query.id), # type: ignore
context=previous_iteration.context, context=previous_iteration.context,
onlineContext=previous_iteration.onlineContext, onlineContext=previous_iteration.onlineContext,
codeContext=previous_iteration.codeContext, codeContext=previous_iteration.codeContext,
@@ -221,6 +223,7 @@ async def research(
tracer: dict = {}, tracer: dict = {},
query_files: str = None, query_files: str = None,
cancellation_event: Optional[asyncio.Event] = None, cancellation_event: Optional[asyncio.Event] = None,
interrupt_queue: Optional[asyncio.Queue] = None,
): ):
max_document_searches = 7 max_document_searches = 7
max_online_searches = 3 max_online_searches = 3
@@ -241,6 +244,22 @@ async def research(
logger.debug(f"Research cancelled. User {user} disconnected client.") logger.debug(f"Research cancelled. User {user} disconnected client.")
break break
# Update the query for the current research iteration
if interrupt_query := get_message_from_queue(interrupt_queue):
# Add the interrupt query as a new user message to the research conversation history
logger.info(
f"Continuing research with the previous {len(previous_iterations)} iterations and new instruction: {interrupt_query}"
)
previous_iterations_history = construct_iteration_history(
previous_iterations, query, query_images, query_files
)
research_conversation_history += previous_iterations_history
query = interrupt_query
previous_iterations = []
async for result in send_status_func(f"**Incorporate New Instruction**: {interrupt_query}"):
yield result
online_results: Dict = dict() online_results: Dict = dict()
code_results: Dict = dict() code_results: Dict = dict()
document_results: List[Dict[str, str]] = [] document_results: List[Dict[str, str]] = []
@@ -428,6 +447,7 @@ async def research(
agent=agent, agent=agent,
query_files=query_files, query_files=query_files,
cancellation_event=cancellation_event, cancellation_event=cancellation_event,
interrupt_queue=interrupt_queue,
tracer=tracer, tracer=tracer,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:

View File

@@ -168,7 +168,6 @@ class ChatRequestBody(BaseModel):
images: Optional[list[str]] = None images: Optional[list[str]] = None
files: Optional[list[FileAttachment]] = [] files: Optional[list[FileAttachment]] = []
create_new: Optional[bool] = False create_new: Optional[bool] = False
interrupt: Optional[bool] = False
class Entry: class Entry: