mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Add cancellation support to research mode via asyncio.Event
This commit is contained in:
@@ -241,8 +241,7 @@ export default function Chat() {
|
|||||||
handleAbortedMessage();
|
handleAbortedMessage();
|
||||||
setTriggeredAbort(false);
|
setTriggeredAbort(false);
|
||||||
}
|
}
|
||||||
}),
|
}, [triggeredAbort]);
|
||||||
[triggeredAbort];
|
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (queryToProcess) {
|
if (queryToProcess) {
|
||||||
|
|||||||
@@ -683,14 +683,13 @@ async def chat(
|
|||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
ttft = None
|
ttft = None
|
||||||
chat_metadata: dict = {}
|
chat_metadata: dict = {}
|
||||||
connection_alive = True
|
|
||||||
user: KhojUser = request.user.object
|
user: KhojUser = request.user.object
|
||||||
is_subscribed = has_required_scope(request, ["premium"])
|
is_subscribed = has_required_scope(request, ["premium"])
|
||||||
event_delimiter = "␃🔚␗"
|
|
||||||
q = unquote(q)
|
q = unquote(q)
|
||||||
train_of_thought = []
|
train_of_thought = []
|
||||||
nonlocal conversation_id
|
nonlocal conversation_id
|
||||||
nonlocal raw_query_files
|
nonlocal raw_query_files
|
||||||
|
cancellation_event = asyncio.Event()
|
||||||
|
|
||||||
tracer: dict = {
|
tracer: dict = {
|
||||||
"mid": turn_id,
|
"mid": turn_id,
|
||||||
@@ -717,11 +716,33 @@ async def chat(
|
|||||||
for file in raw_query_files:
|
for file in raw_query_files:
|
||||||
query_files[file.name] = file.content
|
query_files[file.name] = file.content
|
||||||
|
|
||||||
|
# Create a task to monitor for disconnections
|
||||||
|
disconnect_monitor_task = None
|
||||||
|
|
||||||
|
async def monitor_disconnection():
    """Watch the request's receive channel and flag a client disconnect.

    Intended to run as a background task alongside response streaming.
    When an "http.disconnect" message is received, `cancellation_event`
    is set so the rest of the chat pipeline can stop early.

    Errors raised while reading from the channel are logged, not
    propagated: the monitor is best effort.
    """
    try:
        # Keep reading until the disconnect actually arrives. receive() may
        # hand back other message types first; giving up after the first
        # message would leave a later disconnect undetected.
        while not cancellation_event.is_set():
            msg = await request.receive()
            if msg["type"] == "http.disconnect":
                logger.debug(f"User {user} disconnected from {common.client} client.")
                cancellation_event.set()
    except Exception as e:
        # Include the traceback, matching the other error logs in this handler.
        logger.error(f"Error in disconnect monitor: {e}", exc_info=True)
|
||||||
|
|
||||||
|
# Cancel the disconnect monitor task if it is still running
|
||||||
|
async def cancel_disconnect_monitor():
    """Stop the background disconnect monitor if it is still running.

    Does nothing when the monitor task was never created or has already
    finished. Otherwise requests cancellation and waits for the task to
    wind down, absorbing the resulting `asyncio.CancelledError`.
    """
    monitor = disconnect_monitor_task
    # Nothing to cancel: task not started yet, or it already completed.
    if not monitor or monitor.done():
        return

    logger.debug(f"Cancelling disconnect monitor task for user {user}")
    monitor.cancel()
    try:
        # Await the task so the cancellation is fully processed before returning.
        await monitor
    except asyncio.CancelledError:
        pass
|
||||||
|
|
||||||
async def send_event(event_type: ChatEvent, data: str | dict):
|
async def send_event(event_type: ChatEvent, data: str | dict):
|
||||||
nonlocal connection_alive, ttft, train_of_thought
|
nonlocal ttft, train_of_thought
|
||||||
if not connection_alive or await request.is_disconnected():
|
event_delimiter = "␃🔚␗"
|
||||||
connection_alive = False
|
if cancellation_event.is_set():
|
||||||
logger.warning(f"User {user} disconnected from {common.client} client")
|
logger.debug(f"User {user} disconnected from {common.client} client. Setting cancellation event.")
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
if event_type == ChatEvent.END_LLM_RESPONSE:
|
if event_type == ChatEvent.END_LLM_RESPONSE:
|
||||||
@@ -746,17 +767,25 @@ async def chat(
|
|||||||
elif event_type == ChatEvent.REFERENCES or ChatEvent.METADATA or stream:
|
elif event_type == ChatEvent.REFERENCES or ChatEvent.METADATA or stream:
|
||||||
yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False)
|
yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False)
|
||||||
except asyncio.CancelledError as e:
|
except asyncio.CancelledError as e:
|
||||||
connection_alive = False
|
if cancellation_event.is_set():
|
||||||
logger.warn(f"User {user} disconnected from {common.client} client: {e}")
|
logger.debug(f"Request cancelled. User {user} disconnected from {common.client} client: {e}.")
|
||||||
return
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
connection_alive = False
|
if not cancellation_event.is_set():
|
||||||
logger.error(f"Failed to stream chat API response to {user} on {common.client}: {e}", exc_info=True)
|
logger.error(
|
||||||
return
|
f"Failed to stream chat API response to {user} on {common.client}: {e}.",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
finally:
|
finally:
|
||||||
yield event_delimiter
|
if not cancellation_event.is_set():
|
||||||
|
yield event_delimiter
|
||||||
|
# Cancel the disconnect monitor task if it is still running
|
||||||
|
if cancellation_event.is_set() or event_type == ChatEvent.END_RESPONSE:
|
||||||
|
await cancel_disconnect_monitor()
|
||||||
|
|
||||||
async def send_llm_response(response: str, usage: dict = None):
|
async def send_llm_response(response: str, usage: dict = None):
|
||||||
|
# Check if the client is still connected
|
||||||
|
if cancellation_event.is_set():
|
||||||
|
return
|
||||||
# Send Chat Response
|
# Send Chat Response
|
||||||
async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
|
async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
|
||||||
yield result
|
yield result
|
||||||
@@ -797,6 +826,9 @@ async def chat(
|
|||||||
metadata=chat_metadata,
|
metadata=chat_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Start the disconnect monitor in the background
|
||||||
|
disconnect_monitor_task = asyncio.create_task(monitor_disconnection())
|
||||||
|
|
||||||
if is_query_empty(q):
|
if is_query_empty(q):
|
||||||
async for result in send_llm_response("Please ask your query to get started.", tracer.get("usage")):
|
async for result in send_llm_response("Please ask your query to get started.", tracer.get("usage")):
|
||||||
yield result
|
yield result
|
||||||
@@ -914,6 +946,7 @@ async def chat(
|
|||||||
file_filters=conversation.file_filters if conversation else [],
|
file_filters=conversation.file_filters if conversation else [],
|
||||||
query_files=attached_file_context,
|
query_files=attached_file_context,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
|
cancellation_event=cancellation_event,
|
||||||
):
|
):
|
||||||
if isinstance(research_result, InformationCollectionIteration):
|
if isinstance(research_result, InformationCollectionIteration):
|
||||||
if research_result.summarizedResult:
|
if research_result.summarizedResult:
|
||||||
@@ -1288,6 +1321,13 @@ async def chat(
|
|||||||
async for result in send_event(ChatEvent.STATUS, error_message):
|
async for result in send_event(ChatEvent.STATUS, error_message):
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
|
# Check if the user has disconnected
|
||||||
|
if cancellation_event.is_set():
|
||||||
|
logger.debug(f"User {user} disconnected from {common.client} client. Stopping LLM response.")
|
||||||
|
# Cancel the disconnect monitor task if it is still running
|
||||||
|
await cancel_disconnect_monitor()
|
||||||
|
return
|
||||||
|
|
||||||
## Generate Text Output
|
## Generate Text Output
|
||||||
async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"):
|
async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"):
|
||||||
yield result
|
yield result
|
||||||
@@ -1320,14 +1360,12 @@ async def chat(
|
|||||||
tracer,
|
tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
continue_stream = True
|
|
||||||
async for item in llm_response:
|
async for item in llm_response:
|
||||||
# Should not happen with async generator, end is signaled by loop exit. Skip.
|
# Should not happen with async generator, end is signaled by loop exit. Skip.
|
||||||
if item is None:
|
if item is None:
|
||||||
continue
|
continue
|
||||||
if not connection_alive or not continue_stream:
|
if cancellation_event.is_set():
|
||||||
# Drain the generator if disconnected but keep processing internally
|
break
|
||||||
continue
|
|
||||||
message = item.response if isinstance(item, ResponseWithThought) else item
|
message = item.response if isinstance(item, ResponseWithThought) else item
|
||||||
if isinstance(item, ResponseWithThought) and item.thought:
|
if isinstance(item, ResponseWithThought) and item.thought:
|
||||||
async for result in send_event(ChatEvent.THOUGHT, item.thought):
|
async for result in send_event(ChatEvent.THOUGHT, item.thought):
|
||||||
@@ -1342,11 +1380,12 @@ async def chat(
|
|||||||
async for result in send_event(ChatEvent.MESSAGE, message):
|
async for result in send_event(ChatEvent.MESSAGE, message):
|
||||||
yield result
|
yield result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
continue_stream = False
|
if not cancellation_event.is_set():
|
||||||
logger.info(f"User {user} disconnected or error during streaming. Stopping send: {e}")
|
logger.warning(f"Error during streaming. Stopping send: {e}")
|
||||||
|
break
|
||||||
|
|
||||||
# Signal end of LLM response after the loop finishes
|
# Signal end of LLM response after the loop finishes
|
||||||
if connection_alive:
|
if not cancellation_event.is_set():
|
||||||
async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
|
async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
|
||||||
yield result
|
yield result
|
||||||
# Send Usage Metadata once llm interactions are complete
|
# Send Usage Metadata once llm interactions are complete
|
||||||
@@ -1357,6 +1396,9 @@ async def chat(
|
|||||||
yield result
|
yield result
|
||||||
logger.debug("Finished streaming response")
|
logger.debug("Finished streaming response")
|
||||||
|
|
||||||
|
# Cancel the disconnect monitor task if it is still running
|
||||||
|
await cancel_disconnect_monitor()
|
||||||
|
|
||||||
## Stream Text Response
|
## Stream Text Response
|
||||||
if stream:
|
if stream:
|
||||||
return StreamingResponse(event_generator(q, images=raw_images), media_type="text/plain")
|
return StreamingResponse(event_generator(q, images=raw_images), media_type="text/plain")
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@@ -205,11 +206,17 @@ async def execute_information_collection(
|
|||||||
file_filters: List[str] = [],
|
file_filters: List[str] = [],
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
query_files: str = None,
|
query_files: str = None,
|
||||||
|
cancellation_event: Optional[asyncio.Event] = None,
|
||||||
):
|
):
|
||||||
current_iteration = 0
|
current_iteration = 0
|
||||||
MAX_ITERATIONS = int(os.getenv("KHOJ_RESEARCH_ITERATIONS", 5))
|
MAX_ITERATIONS = int(os.getenv("KHOJ_RESEARCH_ITERATIONS", 5))
|
||||||
previous_iterations: List[InformationCollectionIteration] = []
|
previous_iterations: List[InformationCollectionIteration] = []
|
||||||
while current_iteration < MAX_ITERATIONS:
|
while current_iteration < MAX_ITERATIONS:
|
||||||
|
# Check for cancellation at the start of each iteration
|
||||||
|
if cancellation_event and cancellation_event.is_set():
|
||||||
|
logger.debug(f"User {user} disconnected client. Research cancelled.")
|
||||||
|
break
|
||||||
|
|
||||||
online_results: Dict = dict()
|
online_results: Dict = dict()
|
||||||
code_results: Dict = dict()
|
code_results: Dict = dict()
|
||||||
document_results: List[Dict[str, str]] = []
|
document_results: List[Dict[str, str]] = []
|
||||||
|
|||||||
Reference in New Issue
Block a user