diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index bf72121c..15023388 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -682,6 +682,7 @@ async def chat( timezone = body.timezone raw_images = body.images raw_query_files = body.files + interrupt_flag = body.interrupt async def event_generator(q: str, images: list[str]): start_time = time.perf_counter() @@ -920,6 +921,34 @@ async def chat( user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") meta_log = conversation.conversation_log + # If interrupt flag is set, wait for the previous turn to be saved before proceeding + if interrupt_flag: + max_wait_time = 20.0 # seconds + wait_interval = 0.3 # seconds + wait_start = wait_current = time.time() + while wait_current - wait_start < max_wait_time: + # Refresh conversation to check if interrupted message saved to DB + conversation = await ConversationAdapters.aget_conversation_by_user( + user, + client_application=request.user.client_app, + conversation_id=conversation_id, + ) + if ( + conversation + and conversation.messages + and conversation.messages[-1].by == "khoj" + and not conversation.messages[-1].message + ): + logger.info(f"Detected interrupted message save to conversation {conversation_id}.") + break + await asyncio.sleep(wait_interval) + wait_current = time.time() + + if wait_current - wait_start >= max_wait_time: + logger.warning( + f"Timeout waiting to load interrupted context from conversation {conversation_id}. Proceed without previous context." + ) + # If interrupted message in DB if ( conversation diff --git a/src/khoj/utils/rawconfig.py b/src/khoj/utils/rawconfig.py index fc637ea5..e0248a66 100644 --- a/src/khoj/utils/rawconfig.py +++ b/src/khoj/utils/rawconfig.py @@ -168,6 +168,7 @@ class ChatRequestBody(BaseModel): images: Optional[list[str]] = None files: Optional[list[FileAttachment]] = [] create_new: Optional[bool] = False + interrupt: Optional[bool] = False class Entry: