From 70ad78990afde8578801887f626ac5b49ae7eccc Mon Sep 17 00:00:00 2001 From: sabaimran Date: Wed, 20 Mar 2024 15:04:14 +0530 Subject: [PATCH] Use a common method for sending a generic message to the client from the server in the ws connection --- src/khoj/routers/api_chat.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index dc240a78..6c218f83 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -269,6 +269,16 @@ async def websocket_endpoint( connection_alive = False logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread") + async def send_message(message: str): + nonlocal connection_alive + if not connection_alive: + return + try: + await websocket.send_text(message) + except ConnectionClosedOK: + connection_alive = False + logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread") + user: KhojUser = websocket.user.object conversation = await ConversationAdapters.aget_conversation_by_user( user, client_application=websocket.user.client_app, conversation_id=conversation_id @@ -346,7 +356,7 @@ async def websocket_endpoint( if ConversationCommand.Online in conversation_commands: try: - await send_status_update("Searching the web for relevant information 🌐") + await send_status_update("**Operation**: Searching the web for relevant information...") online_results = await search_online(defiltered_query, meta_log, location) online_searches = "".join([f"{query}" for query in online_results.keys()]) await send_status_update(f"**Online searches**: {online_searches}") @@ -363,6 +373,7 @@ async def websocket_endpoint( api="chat", metadata={"conversation_command": conversation_commands[0].value}, ) + await send_status_update("**Operation**: Augmenting your query and generating a superb image...") intent_type = "text-to-image" image, status_code, improved_image_prompt, image_url = await text_to_image( q, user, meta_log, location_data=location, references=compiled_references, online_results=online_results @@ -420,29 +431,19 @@ async def websocket_endpoint( ) iterator = AsyncIteratorWrapper(llm_response) - if connection_alive: - try: - await websocket.send_text("start_llm_response") - except ConnectionClosedOK: - connection_alive = False - logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread") + await send_message("start_llm_response") async for item in iterator: if item is None: break if connection_alive: try: - await websocket.send_text(f"{item}") + await send_message(f"{item}") except ConnectionClosedOK: connection_alive = False logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread") - if connection_alive: - try: - await websocket.send_text("end_llm_response") - except ConnectionClosedOK: - connection_alive = False - logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread") + await send_message("end_llm_response") @api_chat.get("", response_class=Response)