mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-08 05:39:13 +00:00
Handle passing interrupt messages from api to chat actors on server
This commit is contained in:
@@ -384,6 +384,7 @@ class ChatEvent(Enum):
|
|||||||
METADATA = "metadata"
|
METADATA = "metadata"
|
||||||
USAGE = "usage"
|
USAGE = "usage"
|
||||||
END_RESPONSE = "end_response"
|
END_RESPONSE = "end_response"
|
||||||
|
INTERRUPT = "interrupt"
|
||||||
|
|
||||||
|
|
||||||
def message_to_log(
|
def message_to_log(
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from typing import Callable, List, Optional
|
|||||||
from khoj.database.adapters import AgentAdapters, ConversationAdapters
|
from khoj.database.adapters import AgentAdapters, ConversationAdapters
|
||||||
from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser
|
from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser
|
||||||
from khoj.processor.conversation.utils import (
|
from khoj.processor.conversation.utils import (
|
||||||
|
AgentMessage,
|
||||||
OperatorRun,
|
OperatorRun,
|
||||||
construct_chat_history_for_operator,
|
construct_chat_history_for_operator,
|
||||||
)
|
)
|
||||||
@@ -22,7 +23,7 @@ from khoj.processor.operator.operator_environment_base import (
|
|||||||
)
|
)
|
||||||
from khoj.processor.operator.operator_environment_browser import BrowserEnvironment
|
from khoj.processor.operator.operator_environment_browser import BrowserEnvironment
|
||||||
from khoj.processor.operator.operator_environment_computer import ComputerEnvironment
|
from khoj.processor.operator.operator_environment_computer import ComputerEnvironment
|
||||||
from khoj.routers.helpers import ChatEvent
|
from khoj.routers.helpers import ChatEvent, get_message_from_queue
|
||||||
from khoj.utils.helpers import timer
|
from khoj.utils.helpers import timer
|
||||||
from khoj.utils.rawconfig import LocationData
|
from khoj.utils.rawconfig import LocationData
|
||||||
|
|
||||||
@@ -42,6 +43,7 @@ async def operate_environment(
|
|||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
query_files: str = None, # TODO: Handle query files
|
query_files: str = None, # TODO: Handle query files
|
||||||
cancellation_event: Optional[asyncio.Event] = None,
|
cancellation_event: Optional[asyncio.Event] = None,
|
||||||
|
interrupt_queue: Optional[asyncio.Queue] = None,
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
response, user_input_message = None, None
|
response, user_input_message = None, None
|
||||||
@@ -140,6 +142,14 @@ async def operate_environment(
|
|||||||
logger.debug(f"{environment_type.value} operator cancelled by client disconnect")
|
logger.debug(f"{environment_type.value} operator cancelled by client disconnect")
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# Add interrupt query to current operator run
|
||||||
|
if interrupt_query := get_message_from_queue(interrupt_queue):
|
||||||
|
# Add the interrupt query as a new user message to the research conversation history
|
||||||
|
logger.info(f"Continuing operator run with the new instruction: {interrupt_query}")
|
||||||
|
operator_agent.messages.append(AgentMessage(role="user", content=interrupt_query))
|
||||||
|
async for result in send_status_func(f"**Incorporate New Instruction**: {interrupt_query}"):
|
||||||
|
yield result
|
||||||
|
|
||||||
iterations += 1
|
iterations += 1
|
||||||
|
|
||||||
# 1. Get current environment state
|
# 1. Get current environment state
|
||||||
|
|||||||
@@ -672,6 +672,7 @@ async def event_generator(
|
|||||||
common: CommonQueryParams,
|
common: CommonQueryParams,
|
||||||
headers: Headers,
|
headers: Headers,
|
||||||
request_obj: Request | WebSocket,
|
request_obj: Request | WebSocket,
|
||||||
|
interrupt_queue: asyncio.Queue = None,
|
||||||
):
|
):
|
||||||
# Access the parameters from the body
|
# Access the parameters from the body
|
||||||
q = body.q
|
q = body.q
|
||||||
@@ -688,7 +689,6 @@ async def event_generator(
|
|||||||
timezone = body.timezone
|
timezone = body.timezone
|
||||||
raw_images = body.images
|
raw_images = body.images
|
||||||
raw_query_files = body.files
|
raw_query_files = body.files
|
||||||
interrupt_flag = body.interrupt
|
|
||||||
|
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
ttft = None
|
ttft = None
|
||||||
@@ -955,34 +955,6 @@ async def event_generator(
|
|||||||
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
chat_history = conversation.messages
|
chat_history = conversation.messages
|
||||||
|
|
||||||
# If interrupt flag is set, wait for the previous turn to be saved before proceeding
|
|
||||||
if interrupt_flag:
|
|
||||||
max_wait_time = 20.0 # seconds
|
|
||||||
wait_interval = 0.3 # seconds
|
|
||||||
wait_start = wait_current = time.time()
|
|
||||||
while wait_current - wait_start < max_wait_time:
|
|
||||||
# Refresh conversation to check if interrupted message saved to DB
|
|
||||||
conversation = await ConversationAdapters.aget_conversation_by_user(
|
|
||||||
user,
|
|
||||||
client_application=user_scope.client_app,
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
conversation
|
|
||||||
and conversation.messages
|
|
||||||
and conversation.messages[-1].by == "khoj"
|
|
||||||
and not conversation.messages[-1].message
|
|
||||||
):
|
|
||||||
logger.info(f"Detected interrupted message save to conversation {conversation_id}.")
|
|
||||||
break
|
|
||||||
await asyncio.sleep(wait_interval)
|
|
||||||
wait_current = time.time()
|
|
||||||
|
|
||||||
if wait_current - wait_start >= max_wait_time:
|
|
||||||
logger.warning(
|
|
||||||
f"Timeout waiting to load interrupted context from conversation {conversation_id}. Proceed without previous context."
|
|
||||||
)
|
|
||||||
|
|
||||||
# If interrupted message in DB
|
# If interrupted message in DB
|
||||||
if (
|
if (
|
||||||
conversation
|
conversation
|
||||||
@@ -1061,6 +1033,7 @@ async def event_generator(
|
|||||||
query_files=attached_file_context,
|
query_files=attached_file_context,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
cancellation_event=cancellation_event,
|
cancellation_event=cancellation_event,
|
||||||
|
interrupt_queue=interrupt_queue,
|
||||||
):
|
):
|
||||||
if isinstance(research_result, ResearchIteration):
|
if isinstance(research_result, ResearchIteration):
|
||||||
if research_result.summarizedResult:
|
if research_result.summarizedResult:
|
||||||
@@ -1491,13 +1464,25 @@ async def chat_ws(
|
|||||||
)
|
)
|
||||||
image_rate_limiter = ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)
|
image_rate_limiter = ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)
|
||||||
|
|
||||||
|
# Shared interrupt queue for communicating interrupts to ongoing research
|
||||||
|
interrupt_queue: asyncio.Queue = asyncio.Queue()
|
||||||
current_task = None
|
current_task = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
data = await websocket.receive_json()
|
data = await websocket.receive_json()
|
||||||
|
|
||||||
# Handle regular chat messages
|
# Check if this is an interrupt message
|
||||||
|
if data.get("type") == "interrupt":
|
||||||
|
if current_task and not current_task.done():
|
||||||
|
# Send interrupt signal to the ongoing task
|
||||||
|
await interrupt_queue.put(data.get("query", ""))
|
||||||
|
logger.info(f"Interrupt signal sent to ongoing task for user {websocket.scope['user'].object.id}")
|
||||||
|
await websocket.send_text(json.dumps({"type": "interrupt_acknowledged"}))
|
||||||
|
else:
|
||||||
|
logger.info(f"No ongoing task to interrupt for user {websocket.scope['user'].object.id}")
|
||||||
|
continue
|
||||||
|
|
||||||
# Handle regular chat messages - ensure data has required fields
|
# Handle regular chat messages - ensure data has required fields
|
||||||
if "q" not in data:
|
if "q" not in data:
|
||||||
await websocket.send_text(json.dumps({"error": "Missing required field 'q' in chat message"}))
|
await websocket.send_text(json.dumps({"error": "Missing required field 'q' in chat message"}))
|
||||||
@@ -1523,7 +1508,7 @@ async def chat_ws(
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
# Create a new task for processing the chat request
|
# Create a new task for processing the chat request
|
||||||
current_task = asyncio.create_task(process_chat_request(websocket, body, common))
|
current_task = asyncio.create_task(process_chat_request(websocket, body, common, interrupt_queue))
|
||||||
|
|
||||||
except WebSocketDisconnect:
|
except WebSocketDisconnect:
|
||||||
logger.info(f"WebSocket disconnected for user {websocket.scope['user'].object.id}")
|
logger.info(f"WebSocket disconnected for user {websocket.scope['user'].object.id}")
|
||||||
@@ -1540,6 +1525,7 @@ async def process_chat_request(
|
|||||||
websocket: WebSocket,
|
websocket: WebSocket,
|
||||||
body: ChatRequestBody,
|
body: ChatRequestBody,
|
||||||
common: CommonQueryParams,
|
common: CommonQueryParams,
|
||||||
|
interrupt_queue: asyncio.Queue,
|
||||||
):
|
):
|
||||||
"""Process a single chat request with interrupt support"""
|
"""Process a single chat request with interrupt support"""
|
||||||
try:
|
try:
|
||||||
@@ -1550,6 +1536,7 @@ async def process_chat_request(
|
|||||||
common,
|
common,
|
||||||
websocket.headers,
|
websocket.headers,
|
||||||
websocket,
|
websocket,
|
||||||
|
interrupt_queue,
|
||||||
)
|
)
|
||||||
async for event in response_iterator:
|
async for event in response_iterator:
|
||||||
await websocket.send_text(event)
|
await websocket.send_text(event)
|
||||||
|
|||||||
@@ -2600,6 +2600,17 @@ async def read_chat_stream(response_iterator: AsyncGenerator[str, None]) -> Dict
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_message_from_queue(queue: asyncio.Queue) -> Optional[str]:
|
||||||
|
"""Get any message in queue if available."""
|
||||||
|
if not queue:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
# Non-blocking check for message in the queue
|
||||||
|
return queue.get_nowait()
|
||||||
|
except asyncio.QueueEmpty:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False):
|
def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False):
|
||||||
user_picture = request.session.get("user", {}).get("picture")
|
user_picture = request.session.get("user", {}).get("picture")
|
||||||
is_active = has_required_scope(request, ["premium"])
|
is_active = has_required_scope(request, ["premium"])
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from khoj.processor.conversation.utils import (
|
|||||||
ResearchIteration,
|
ResearchIteration,
|
||||||
ToolCall,
|
ToolCall,
|
||||||
construct_iteration_history,
|
construct_iteration_history,
|
||||||
|
construct_structured_message,
|
||||||
construct_tool_chat_history,
|
construct_tool_chat_history,
|
||||||
load_complex_json,
|
load_complex_json,
|
||||||
)
|
)
|
||||||
@@ -24,6 +25,7 @@ from khoj.processor.tools.run_code import run_code
|
|||||||
from khoj.routers.helpers import (
|
from khoj.routers.helpers import (
|
||||||
ChatEvent,
|
ChatEvent,
|
||||||
generate_summary_from_files,
|
generate_summary_from_files,
|
||||||
|
get_message_from_queue,
|
||||||
grep_files,
|
grep_files,
|
||||||
list_files,
|
list_files,
|
||||||
search_documents,
|
search_documents,
|
||||||
@@ -74,7 +76,7 @@ async def apick_next_tool(
|
|||||||
):
|
):
|
||||||
previous_iteration = previous_iterations[-1]
|
previous_iteration = previous_iterations[-1]
|
||||||
yield ResearchIteration(
|
yield ResearchIteration(
|
||||||
query=query,
|
query=ToolCall(name=previous_iteration.query.name, args={"query": query}, id=previous_iteration.query.id), # type: ignore
|
||||||
context=previous_iteration.context,
|
context=previous_iteration.context,
|
||||||
onlineContext=previous_iteration.onlineContext,
|
onlineContext=previous_iteration.onlineContext,
|
||||||
codeContext=previous_iteration.codeContext,
|
codeContext=previous_iteration.codeContext,
|
||||||
@@ -221,6 +223,7 @@ async def research(
|
|||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
query_files: str = None,
|
query_files: str = None,
|
||||||
cancellation_event: Optional[asyncio.Event] = None,
|
cancellation_event: Optional[asyncio.Event] = None,
|
||||||
|
interrupt_queue: Optional[asyncio.Queue] = None,
|
||||||
):
|
):
|
||||||
max_document_searches = 7
|
max_document_searches = 7
|
||||||
max_online_searches = 3
|
max_online_searches = 3
|
||||||
@@ -241,6 +244,22 @@ async def research(
|
|||||||
logger.debug(f"Research cancelled. User {user} disconnected client.")
|
logger.debug(f"Research cancelled. User {user} disconnected client.")
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# Update the query for the current research iteration
|
||||||
|
if interrupt_query := get_message_from_queue(interrupt_queue):
|
||||||
|
# Add the interrupt query as a new user message to the research conversation history
|
||||||
|
logger.info(
|
||||||
|
f"Continuing research with the previous {len(previous_iterations)} iterations and new instruction: {interrupt_query}"
|
||||||
|
)
|
||||||
|
previous_iterations_history = construct_iteration_history(
|
||||||
|
previous_iterations, query, query_images, query_files
|
||||||
|
)
|
||||||
|
research_conversation_history += previous_iterations_history
|
||||||
|
query = interrupt_query
|
||||||
|
previous_iterations = []
|
||||||
|
|
||||||
|
async for result in send_status_func(f"**Incorporate New Instruction**: {interrupt_query}"):
|
||||||
|
yield result
|
||||||
|
|
||||||
online_results: Dict = dict()
|
online_results: Dict = dict()
|
||||||
code_results: Dict = dict()
|
code_results: Dict = dict()
|
||||||
document_results: List[Dict[str, str]] = []
|
document_results: List[Dict[str, str]] = []
|
||||||
@@ -428,6 +447,7 @@ async def research(
|
|||||||
agent=agent,
|
agent=agent,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
cancellation_event=cancellation_event,
|
cancellation_event=cancellation_event,
|
||||||
|
interrupt_queue=interrupt_queue,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
):
|
):
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
|
|||||||
@@ -168,7 +168,6 @@ class ChatRequestBody(BaseModel):
|
|||||||
images: Optional[list[str]] = None
|
images: Optional[list[str]] = None
|
||||||
files: Optional[list[FileAttachment]] = []
|
files: Optional[list[FileAttachment]] = []
|
||||||
create_new: Optional[bool] = False
|
create_new: Optional[bool] = False
|
||||||
interrupt: Optional[bool] = False
|
|
||||||
|
|
||||||
|
|
||||||
class Entry:
|
class Entry:
|
||||||
|
|||||||
Reference in New Issue
Block a user