From 7b7b1830b7fc79911711f6affd1bb15fe4671a9c Mon Sep 17 00:00:00 2001 From: Debanjum Date: Fri, 27 Jun 2025 00:27:47 -0700 Subject: [PATCH] Make callers only share new messages to append to chat logs - Chat history is now retrieved and updated with the new messages just before the write. This reduces the chance of message loss from conflicting writes, where the last writer to save to the conversation wins the conflict. - Having callers pass the full chat history was a problematic artifact of old code. Removing it should reduce the conflict surface area. - Interrupts and live chat could each hit this issue, for different reasons. --- src/khoj/database/adapters/__init__.py | 5 +++-- src/khoj/database/models/__init__.py | 28 ++++++++++++++++++++++++ src/khoj/processor/conversation/utils.py | 11 +++++----- src/khoj/routers/api_chat.py | 17 ++++---------- 4 files changed, 40 insertions(+), 21 deletions(-) diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index fe53a73c..14eaadfe 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -1465,7 +1465,7 @@ class ConversationAdapters: @require_valid_user async def save_conversation( user: KhojUser, - chat_history: List[ChatMessageModel], + new_messages: List[ChatMessageModel], client_application: ClientApplication = None, conversation_id: str = None, user_message: str = None, @@ -1480,7 +1480,8 @@ class ConversationAdapters: await Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst() ) - conversation_log = {"chat": [msg.model_dump() for msg in chat_history]} + existing_messages = conversation.messages if conversation else [] + conversation_log = {"chat": [msg.model_dump() for msg in existing_messages + new_messages]} cleaned_conversation_log = clean_object_for_db(conversation_log) if conversation: conversation.conversation_log = cleaned_conversation_log diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 8d903338..7f80459a 
100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -677,6 +677,34 @@ class Conversation(DbBaseModel): continue return validated_messages + async def pop_message(self, interrupted: bool = False) -> Optional[ChatMessageModel]: + """ + Remove and return the last message from the conversation log, persisting the change to the database. + When interrupted is True, we only drop the last message if it was an interrupted message by khoj. + """ + chat_log = self.conversation_log.get("chat", []) + + if not chat_log: + return None + + last_message = chat_log[-1] + is_interrupted_msg = last_message.get("by") == "khoj" and not last_message.get("message") + # When handling an interruption, only pop if the last message is an empty one by khoj. + if interrupted and not is_interrupted_msg: + return None + + # Pop the last message, save the conversation, and then return the message. + popped_message_dict = chat_log.pop() + await self.asave() + + # Try to validate and return the popped message as a Pydantic model + try: + return ChatMessageModel.model_validate(popped_message_dict) + except ValidationError as e: + logger.warning(f"Popped an invalid message from conversation. The removal has been saved. Error: {e}") + # The invalid message was removed and saved, but we can't return a valid model. 
+ return None + class PublicConversation(DbBaseModel): source_owner = models.ForeignKey(KhojUser, on_delete=models.CASCADE) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index b4dd5e9c..6edd84b2 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -435,7 +435,6 @@ async def save_to_conversation_log( q: str, chat_response: str, user: KhojUser, - chat_history: List[ChatMessageModel], user_message_time: str = None, compiled_references: List[Dict[str, Any]] = [], online_results: Dict[str, Any] = {}, @@ -481,22 +480,22 @@ async def save_to_conversation_log( khoj_message_metadata["mermaidjsDiagram"] = generated_mermaidjs_diagram try: - updated_conversation = message_to_log( + new_messages = message_to_log( user_message=q, chat_response=chat_response, user_message_metadata=user_message_metadata, khoj_message_metadata=khoj_message_metadata, - chat_history=chat_history, + chat_history=[], ) except ValidationError as e: - updated_conversation = None + new_messages = None logger.error(f"Error constructing chat history: {e}") db_conversation = None - if updated_conversation: + if new_messages: db_conversation = await ConversationAdapters.save_conversation( user, - updated_conversation, + new_messages, client_application=client_application, conversation_id=conversation_id, user_message=q, diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index e602ab2c..6d444b92 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -738,6 +738,7 @@ async def event_generator( generated_mermaidjs_diagram: str = None generated_asset_results: Dict = dict() program_execution_context: List[str] = [] + user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") # Create a task to monitor for disconnections disconnect_monitor_task = None @@ -757,7 +758,6 @@ async def event_generator( q, chat_response="", user=user, - chat_history=chat_history, 
compiled_references=compiled_references, online_results=online_results, code_results=code_results, @@ -772,6 +772,7 @@ async def event_generator( generated_images=generated_images, raw_generated_files=generated_asset_results, generated_mermaidjs_diagram=generated_mermaidjs_diagram, + user_message_time=user_message_time, tracer=tracer, ) ) @@ -789,7 +790,6 @@ async def event_generator( q, chat_response="", user=user, - chat_history=chat_history, compiled_references=compiled_references, online_results=online_results, code_results=code_results, @@ -804,6 +804,7 @@ async def event_generator( generated_images=generated_images, raw_generated_files=generated_asset_results, generated_mermaidjs_diagram=generated_mermaidjs_diagram, + user_message_time=user_message_time, tracer=tracer, ) ) @@ -952,18 +953,11 @@ async def event_generator( location = None if city or region or country or country_code: location = LocationData(city=city, region=region, country=country, country_code=country_code) - user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") chat_history = conversation.messages # If interrupted message in DB - if ( - conversation - and conversation.messages - and conversation.messages[-1].by == "khoj" - and not conversation.messages[-1].message - ): + if last_message := await conversation.pop_message(interrupted=True): # Populate context from interrupted message - last_message = conversation.messages[-1] online_results = {key: val.model_dump() for key, val in last_message.onlineContext.items() or []} code_results = {key: val.model_dump() for key, val in last_message.codeContext.items() or []} compiled_references = [ref.model_dump() for ref in last_message.context or []] @@ -974,8 +968,6 @@ async def event_generator( ] operator_results = [OperatorRun(**iter_dict) for iter_dict in last_message.operatorContext or []] train_of_thought = [thought.model_dump() for thought in last_message.trainOfThought or []] - # Drop the interrupted message from conversation history 
- chat_history.pop() logger.info(f"Loaded interrupted partial context from conversation {conversation_id}.") if conversation_commands == [ConversationCommand.Default]: @@ -1414,7 +1406,6 @@ async def event_generator( q, chat_response=full_response, user=user, - chat_history=chat_history, compiled_references=compiled_references, online_results=online_results, code_results=code_results,