mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Store event delimiter in chat event enum for reuse
This commit is contained in:
@@ -385,6 +385,7 @@ class ChatEvent(Enum):
|
||||
USAGE = "usage"
|
||||
END_RESPONSE = "end_response"
|
||||
INTERRUPT = "interrupt"
|
||||
END_EVENT = "␃🔚␗"
|
||||
|
||||
|
||||
def message_to_log(
|
||||
|
||||
@@ -44,7 +44,7 @@ async def operate_environment(
|
||||
query_files: str = None, # TODO: Handle query files
|
||||
cancellation_event: Optional[asyncio.Event] = None,
|
||||
interrupt_queue: Optional[asyncio.Queue] = None,
|
||||
abort_message: Optional[str] = "␃🔚␗",
|
||||
abort_message: Optional[str] = ChatEvent.END_EVENT.value,
|
||||
tracer: dict = {},
|
||||
):
|
||||
response, user_input_message = None, None
|
||||
|
||||
@@ -704,7 +704,6 @@ async def event_generator(
|
||||
train_of_thought = []
|
||||
cancellation_event = asyncio.Event()
|
||||
child_interrupt_queue: asyncio.Queue = asyncio.Queue(maxsize=10)
|
||||
event_delimiter = "␃🔚␗"
|
||||
|
||||
tracer: dict = {
|
||||
"mid": turn_id,
|
||||
@@ -791,7 +790,7 @@ async def event_generator(
|
||||
|
||||
# Check if any interrupt query is received
|
||||
if interrupt_query := get_message_from_queue(parent_interrupt_queue):
|
||||
if interrupt_query == event_delimiter:
|
||||
if interrupt_query == ChatEvent.END_EVENT.value:
|
||||
cancellation_event.set()
|
||||
logger.debug(f"Chat cancelled by user {user} via interrupt queue.")
|
||||
else:
|
||||
@@ -872,7 +871,7 @@ async def event_generator(
|
||||
)
|
||||
finally:
|
||||
if not cancellation_event.is_set():
|
||||
yield event_delimiter
|
||||
yield ChatEvent.END_EVENT.value
|
||||
# Cancel the disconnect monitor task if it is still running
|
||||
if cancellation_event.is_set() or event_type == ChatEvent.END_RESPONSE:
|
||||
await cancel_disconnect_monitor()
|
||||
@@ -1044,7 +1043,7 @@ async def event_generator(
|
||||
tracer=tracer,
|
||||
cancellation_event=cancellation_event,
|
||||
interrupt_queue=child_interrupt_queue,
|
||||
abort_message=event_delimiter,
|
||||
abort_message=ChatEvent.END_EVENT.value,
|
||||
):
|
||||
if isinstance(research_result, ResearchIteration):
|
||||
if research_result.summarizedResult:
|
||||
@@ -1510,8 +1509,7 @@ async def chat_ws(
|
||||
if data.get("type") == "interrupt":
|
||||
if current_task and not current_task.done():
|
||||
# Send interrupt signal to the ongoing task
|
||||
abort_message = "␃🔚␗"
|
||||
await interrupt_queue.put(data.get("query") or abort_message)
|
||||
await interrupt_queue.put(data.get("query") or ChatEvent.END_EVENT.value)
|
||||
logger.info(
|
||||
f"Interrupt signal sent to ongoing task for user {websocket.scope['user'].object.id} with query: {data.get('query')}"
|
||||
)
|
||||
|
||||
@@ -2617,7 +2617,6 @@ class MessageProcessor:
|
||||
|
||||
async def read_chat_stream(response_iterator: AsyncGenerator[str, None]) -> Dict[str, Any]:
|
||||
processor = MessageProcessor()
|
||||
event_delimiter = "␃🔚␗"
|
||||
buffer = ""
|
||||
|
||||
async for chunk in response_iterator:
|
||||
@@ -2625,9 +2624,9 @@ async def read_chat_stream(response_iterator: AsyncGenerator[str, None]) -> Dict
|
||||
buffer += chunk
|
||||
|
||||
# Once the buffer contains a complete event
|
||||
while event_delimiter in buffer:
|
||||
while ChatEvent.END_EVENT.value in buffer:
|
||||
# Extract the event from the buffer
|
||||
event, buffer = buffer.split(event_delimiter, 1)
|
||||
event, buffer = buffer.split(ChatEvent.END_EVENT.value, 1)
|
||||
# Process the event
|
||||
if event:
|
||||
processor.process_message_chunk(event)
|
||||
|
||||
@@ -224,7 +224,7 @@ async def research(
|
||||
query_files: str = None,
|
||||
cancellation_event: Optional[asyncio.Event] = None,
|
||||
interrupt_queue: Optional[asyncio.Queue] = None,
|
||||
abort_message: str = "␃🔚␗",
|
||||
abort_message: str = ChatEvent.END_EVENT.value,
|
||||
):
|
||||
max_document_searches = 7
|
||||
max_online_searches = 3
|
||||
|
||||
Reference in New Issue
Block a user