mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Handle passing interrupt messages from api to chat actors on server
This commit is contained in:
@@ -384,6 +384,7 @@ class ChatEvent(Enum):
|
||||
METADATA = "metadata"
|
||||
USAGE = "usage"
|
||||
END_RESPONSE = "end_response"
|
||||
INTERRUPT = "interrupt"
|
||||
|
||||
|
||||
def message_to_log(
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Callable, List, Optional
|
||||
from khoj.database.adapters import AgentAdapters, ConversationAdapters
|
||||
from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser
|
||||
from khoj.processor.conversation.utils import (
|
||||
AgentMessage,
|
||||
OperatorRun,
|
||||
construct_chat_history_for_operator,
|
||||
)
|
||||
@@ -22,7 +23,7 @@ from khoj.processor.operator.operator_environment_base import (
|
||||
)
|
||||
from khoj.processor.operator.operator_environment_browser import BrowserEnvironment
|
||||
from khoj.processor.operator.operator_environment_computer import ComputerEnvironment
|
||||
from khoj.routers.helpers import ChatEvent
|
||||
from khoj.routers.helpers import ChatEvent, get_message_from_queue
|
||||
from khoj.utils.helpers import timer
|
||||
from khoj.utils.rawconfig import LocationData
|
||||
|
||||
@@ -42,6 +43,7 @@ async def operate_environment(
|
||||
agent: Agent = None,
|
||||
query_files: str = None, # TODO: Handle query files
|
||||
cancellation_event: Optional[asyncio.Event] = None,
|
||||
interrupt_queue: Optional[asyncio.Queue] = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
response, user_input_message = None, None
|
||||
@@ -140,6 +142,14 @@ async def operate_environment(
|
||||
logger.debug(f"{environment_type.value} operator cancelled by client disconnect")
|
||||
break
|
||||
|
||||
# Add interrupt query to current operator run
|
||||
if interrupt_query := get_message_from_queue(interrupt_queue):
|
||||
# Add the interrupt query as a new user message to the research conversation history
|
||||
logger.info(f"Continuing operator run with the new instruction: {interrupt_query}")
|
||||
operator_agent.messages.append(AgentMessage(role="user", content=interrupt_query))
|
||||
async for result in send_status_func(f"**Incorporate New Instruction**: {interrupt_query}"):
|
||||
yield result
|
||||
|
||||
iterations += 1
|
||||
|
||||
# 1. Get current environment state
|
||||
|
||||
@@ -672,6 +672,7 @@ async def event_generator(
|
||||
common: CommonQueryParams,
|
||||
headers: Headers,
|
||||
request_obj: Request | WebSocket,
|
||||
interrupt_queue: asyncio.Queue = None,
|
||||
):
|
||||
# Access the parameters from the body
|
||||
q = body.q
|
||||
@@ -688,7 +689,6 @@ async def event_generator(
|
||||
timezone = body.timezone
|
||||
raw_images = body.images
|
||||
raw_query_files = body.files
|
||||
interrupt_flag = body.interrupt
|
||||
|
||||
start_time = time.perf_counter()
|
||||
ttft = None
|
||||
@@ -955,34 +955,6 @@ async def event_generator(
|
||||
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
chat_history = conversation.messages
|
||||
|
||||
# If interrupt flag is set, wait for the previous turn to be saved before proceeding
|
||||
if interrupt_flag:
|
||||
max_wait_time = 20.0 # seconds
|
||||
wait_interval = 0.3 # seconds
|
||||
wait_start = wait_current = time.time()
|
||||
while wait_current - wait_start < max_wait_time:
|
||||
# Refresh conversation to check if interrupted message saved to DB
|
||||
conversation = await ConversationAdapters.aget_conversation_by_user(
|
||||
user,
|
||||
client_application=user_scope.client_app,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
if (
|
||||
conversation
|
||||
and conversation.messages
|
||||
and conversation.messages[-1].by == "khoj"
|
||||
and not conversation.messages[-1].message
|
||||
):
|
||||
logger.info(f"Detected interrupted message save to conversation {conversation_id}.")
|
||||
break
|
||||
await asyncio.sleep(wait_interval)
|
||||
wait_current = time.time()
|
||||
|
||||
if wait_current - wait_start >= max_wait_time:
|
||||
logger.warning(
|
||||
f"Timeout waiting to load interrupted context from conversation {conversation_id}. Proceed without previous context."
|
||||
)
|
||||
|
||||
# If interrupted message in DB
|
||||
if (
|
||||
conversation
|
||||
@@ -1061,6 +1033,7 @@ async def event_generator(
|
||||
query_files=attached_file_context,
|
||||
tracer=tracer,
|
||||
cancellation_event=cancellation_event,
|
||||
interrupt_queue=interrupt_queue,
|
||||
):
|
||||
if isinstance(research_result, ResearchIteration):
|
||||
if research_result.summarizedResult:
|
||||
@@ -1491,13 +1464,25 @@ async def chat_ws(
|
||||
)
|
||||
image_rate_limiter = ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)
|
||||
|
||||
# Shared interrupt queue for communicating interrupts to ongoing research
|
||||
interrupt_queue: asyncio.Queue = asyncio.Queue()
|
||||
current_task = None
|
||||
|
||||
try:
|
||||
while True:
|
||||
data = await websocket.receive_json()
|
||||
|
||||
# Handle regular chat messages
|
||||
# Check if this is an interrupt message
|
||||
if data.get("type") == "interrupt":
|
||||
if current_task and not current_task.done():
|
||||
# Send interrupt signal to the ongoing task
|
||||
await interrupt_queue.put(data.get("query", ""))
|
||||
logger.info(f"Interrupt signal sent to ongoing task for user {websocket.scope['user'].object.id}")
|
||||
await websocket.send_text(json.dumps({"type": "interrupt_acknowledged"}))
|
||||
else:
|
||||
logger.info(f"No ongoing task to interrupt for user {websocket.scope['user'].object.id}")
|
||||
continue
|
||||
|
||||
# Handle regular chat messages - ensure data has required fields
|
||||
if "q" not in data:
|
||||
await websocket.send_text(json.dumps({"error": "Missing required field 'q' in chat message"}))
|
||||
@@ -1523,7 +1508,7 @@ async def chat_ws(
|
||||
pass
|
||||
|
||||
# Create a new task for processing the chat request
|
||||
current_task = asyncio.create_task(process_chat_request(websocket, body, common))
|
||||
current_task = asyncio.create_task(process_chat_request(websocket, body, common, interrupt_queue))
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket disconnected for user {websocket.scope['user'].object.id}")
|
||||
@@ -1540,6 +1525,7 @@ async def process_chat_request(
|
||||
websocket: WebSocket,
|
||||
body: ChatRequestBody,
|
||||
common: CommonQueryParams,
|
||||
interrupt_queue: asyncio.Queue,
|
||||
):
|
||||
"""Process a single chat request with interrupt support"""
|
||||
try:
|
||||
@@ -1550,6 +1536,7 @@ async def process_chat_request(
|
||||
common,
|
||||
websocket.headers,
|
||||
websocket,
|
||||
interrupt_queue,
|
||||
)
|
||||
async for event in response_iterator:
|
||||
await websocket.send_text(event)
|
||||
|
||||
@@ -2600,6 +2600,17 @@ async def read_chat_stream(response_iterator: AsyncGenerator[str, None]) -> Dict
|
||||
}
|
||||
|
||||
|
||||
def get_message_from_queue(queue: asyncio.Queue) -> Optional[str]:
|
||||
"""Get any message in queue if available."""
|
||||
if not queue:
|
||||
return None
|
||||
try:
|
||||
# Non-blocking check for message in the queue
|
||||
return queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
return None
|
||||
|
||||
|
||||
def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False):
|
||||
user_picture = request.session.get("user", {}).get("picture")
|
||||
is_active = has_required_scope(request, ["premium"])
|
||||
|
||||
@@ -15,6 +15,7 @@ from khoj.processor.conversation.utils import (
|
||||
ResearchIteration,
|
||||
ToolCall,
|
||||
construct_iteration_history,
|
||||
construct_structured_message,
|
||||
construct_tool_chat_history,
|
||||
load_complex_json,
|
||||
)
|
||||
@@ -24,6 +25,7 @@ from khoj.processor.tools.run_code import run_code
|
||||
from khoj.routers.helpers import (
|
||||
ChatEvent,
|
||||
generate_summary_from_files,
|
||||
get_message_from_queue,
|
||||
grep_files,
|
||||
list_files,
|
||||
search_documents,
|
||||
@@ -74,7 +76,7 @@ async def apick_next_tool(
|
||||
):
|
||||
previous_iteration = previous_iterations[-1]
|
||||
yield ResearchIteration(
|
||||
query=query,
|
||||
query=ToolCall(name=previous_iteration.query.name, args={"query": query}, id=previous_iteration.query.id), # type: ignore
|
||||
context=previous_iteration.context,
|
||||
onlineContext=previous_iteration.onlineContext,
|
||||
codeContext=previous_iteration.codeContext,
|
||||
@@ -221,6 +223,7 @@ async def research(
|
||||
tracer: dict = {},
|
||||
query_files: str = None,
|
||||
cancellation_event: Optional[asyncio.Event] = None,
|
||||
interrupt_queue: Optional[asyncio.Queue] = None,
|
||||
):
|
||||
max_document_searches = 7
|
||||
max_online_searches = 3
|
||||
@@ -241,6 +244,22 @@ async def research(
|
||||
logger.debug(f"Research cancelled. User {user} disconnected client.")
|
||||
break
|
||||
|
||||
# Update the query for the current research iteration
|
||||
if interrupt_query := get_message_from_queue(interrupt_queue):
|
||||
# Add the interrupt query as a new user message to the research conversation history
|
||||
logger.info(
|
||||
f"Continuing research with the previous {len(previous_iterations)} iterations and new instruction: {interrupt_query}"
|
||||
)
|
||||
previous_iterations_history = construct_iteration_history(
|
||||
previous_iterations, query, query_images, query_files
|
||||
)
|
||||
research_conversation_history += previous_iterations_history
|
||||
query = interrupt_query
|
||||
previous_iterations = []
|
||||
|
||||
async for result in send_status_func(f"**Incorporate New Instruction**: {interrupt_query}"):
|
||||
yield result
|
||||
|
||||
online_results: Dict = dict()
|
||||
code_results: Dict = dict()
|
||||
document_results: List[Dict[str, str]] = []
|
||||
@@ -428,6 +447,7 @@ async def research(
|
||||
agent=agent,
|
||||
query_files=query_files,
|
||||
cancellation_event=cancellation_event,
|
||||
interrupt_queue=interrupt_queue,
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
|
||||
@@ -168,7 +168,6 @@ class ChatRequestBody(BaseModel):
|
||||
images: Optional[list[str]] = None
|
||||
files: Optional[list[FileAttachment]] = []
|
||||
create_new: Optional[bool] = False
|
||||
interrupt: Optional[bool] = False
|
||||
|
||||
|
||||
class Entry:
|
||||
|
||||
Reference in New Issue
Block a user