Append new chat messages on save and pop interrupted messages via the model.

* ConversationAdapters.save_conversation now receives only the new messages
  and appends them to the messages already stored on the conversation.
* Conversation.pop_message removes the last entry of the stored conversation
  log and saves the conversation. With interrupted=True it only does so when
  the last entry is an empty message by khoj.
* save_to_conversation_log no longer takes a chat_history argument.
* event_generator computes user_message_time once, up front, and passes it to
  the saves made on disconnect.

Review fix: chat_history is read from conversation.messages before
pop_message runs, and pop_message only changes the stored conversation log.
The interrupted message is therefore still dropped from the in-memory
chat_history with chat_history.pop(), as the code did before this patch.

NOTE(review): leading whitespace of the hunk lines below was reconstructed
from a whitespace-collapsed copy of this patch. Confirm it against the target
files, or apply with `git apply --ignore-whitespace`.

diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py
index fe53a73c..14eaadfe 100644
--- a/src/khoj/database/adapters/__init__.py
+++ b/src/khoj/database/adapters/__init__.py
@@ -1465,7 +1465,7 @@ class ConversationAdapters:
     @require_valid_user
     async def save_conversation(
         user: KhojUser,
-        chat_history: List[ChatMessageModel],
+        new_messages: List[ChatMessageModel],
         client_application: ClientApplication = None,
         conversation_id: str = None,
         user_message: str = None,
@@ -1480,7 +1480,8 @@ class ConversationAdapters:
                 await Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst()
             )
 
-        conversation_log = {"chat": [msg.model_dump() for msg in chat_history]}
+        existing_messages = conversation.messages if conversation else []
+        conversation_log = {"chat": [msg.model_dump() for msg in existing_messages + new_messages]}
         cleaned_conversation_log = clean_object_for_db(conversation_log)
         if conversation:
             conversation.conversation_log = cleaned_conversation_log
diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py
index 8d903338..7f80459a 100644
--- a/src/khoj/database/models/__init__.py
+++ b/src/khoj/database/models/__init__.py
@@ -677,6 +677,34 @@ class Conversation(DbBaseModel):
                 continue
         return validated_messages
 
+    async def pop_message(self, interrupted: bool = False) -> Optional[ChatMessageModel]:
+        """
+        Remove and return the last message from the conversation log, persisting the change to the database.
+        When interrupted is True, we only drop the last message if it was an interrupted message by khoj.
+        """
+        chat_log = self.conversation_log.get("chat", [])
+
+        if not chat_log:
+            return None
+
+        last_message = chat_log[-1]
+        is_interrupted_msg = last_message.get("by") == "khoj" and not last_message.get("message")
+        # When handling an interruption, only pop if the last message is an empty one by khoj.
+        if interrupted and not is_interrupted_msg:
+            return None
+
+        # Pop the last message, save the conversation, and then return the message.
+        popped_message_dict = chat_log.pop()
+        await self.asave()
+
+        # Try to validate and return the popped message as a Pydantic model
+        try:
+            return ChatMessageModel.model_validate(popped_message_dict)
+        except ValidationError as e:
+            logger.warning(f"Popped an invalid message from conversation. The removal has been saved. Error: {e}")
+            # The invalid message was removed and saved, but we can't return a valid model.
+            return None
+
 
 class PublicConversation(DbBaseModel):
     source_owner = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index b4dd5e9c..6edd84b2 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -435,7 +435,6 @@ async def save_to_conversation_log(
     q: str,
     chat_response: str,
     user: KhojUser,
-    chat_history: List[ChatMessageModel],
     user_message_time: str = None,
     compiled_references: List[Dict[str, Any]] = [],
     online_results: Dict[str, Any] = {},
@@ -481,22 +480,22 @@ async def save_to_conversation_log(
         khoj_message_metadata["mermaidjsDiagram"] = generated_mermaidjs_diagram
 
     try:
-        updated_conversation = message_to_log(
+        new_messages = message_to_log(
             user_message=q,
             chat_response=chat_response,
             user_message_metadata=user_message_metadata,
             khoj_message_metadata=khoj_message_metadata,
-            chat_history=chat_history,
+            chat_history=[],
         )
     except ValidationError as e:
-        updated_conversation = None
+        new_messages = None
         logger.error(f"Error constructing chat history: {e}")
 
     db_conversation = None
-    if updated_conversation:
+    if new_messages:
         db_conversation = await ConversationAdapters.save_conversation(
             user,
-            updated_conversation,
+            new_messages,
             client_application=client_application,
             conversation_id=conversation_id,
             user_message=q,
diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py
index e602ab2c..6d444b92 100644
--- a/src/khoj/routers/api_chat.py
+++ b/src/khoj/routers/api_chat.py
@@ -738,6 +738,7 @@ async def event_generator(
         generated_mermaidjs_diagram: str = None
         generated_asset_results: Dict = dict()
         program_execution_context: List[str] = []
+        user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
 
         # Create a task to monitor for disconnections
         disconnect_monitor_task = None
@@ -757,7 +758,6 @@ async def event_generator(
                                 q,
                                 chat_response="",
                                 user=user,
-                                chat_history=chat_history,
                                 compiled_references=compiled_references,
                                 online_results=online_results,
                                 code_results=code_results,
@@ -772,6 +772,7 @@ async def event_generator(
                                 generated_images=generated_images,
                                 raw_generated_files=generated_asset_results,
                                 generated_mermaidjs_diagram=generated_mermaidjs_diagram,
+                                user_message_time=user_message_time,
                                 tracer=tracer,
                             )
                         )
@@ -789,7 +790,6 @@ async def event_generator(
                             q,
                             chat_response="",
                             user=user,
-                            chat_history=chat_history,
                             compiled_references=compiled_references,
                             online_results=online_results,
                             code_results=code_results,
@@ -804,6 +804,7 @@ async def event_generator(
                             generated_images=generated_images,
                             raw_generated_files=generated_asset_results,
                             generated_mermaidjs_diagram=generated_mermaidjs_diagram,
+                            user_message_time=user_message_time,
                             tracer=tracer,
                         )
                     )
@@ -952,18 +953,11 @@ async def event_generator(
         location = None
         if city or region or country or country_code:
             location = LocationData(city=city, region=region, country=country, country_code=country_code)
-        user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
         chat_history = conversation.messages
 
         # If interrupted message in DB
-        if (
-            conversation
-            and conversation.messages
-            and conversation.messages[-1].by == "khoj"
-            and not conversation.messages[-1].message
-        ):
+        if last_message := await conversation.pop_message(interrupted=True):
             # Populate context from interrupted message
-            last_message = conversation.messages[-1]
             online_results = {key: val.model_dump() for key, val in last_message.onlineContext.items() or []}
             code_results = {key: val.model_dump() for key, val in last_message.codeContext.items() or []}
             compiled_references = [ref.model_dump() for ref in last_message.context or []]
@@ -974,8 +968,9 @@ async def event_generator(
             ]
             operator_results = [OperatorRun(**iter_dict) for iter_dict in last_message.operatorContext or []]
             train_of_thought = [thought.model_dump() for thought in last_message.trainOfThought or []]
-            # Drop the interrupted message from conversation history
-            chat_history.pop()
+            # pop_message only removed the interrupted message from the stored conversation log.
+            # chat_history was read before that, so drop the interrupted message from it as well.
+            chat_history.pop()
             logger.info(f"Loaded interrupted partial context from conversation {conversation_id}.")
 
         if conversation_commands == [ConversationCommand.Default]:
@@ -1414,7 +1409,6 @@ async def event_generator(
                 q,
                 chat_response=full_response,
                 user=user,
-                chat_history=chat_history,
                 compiled_references=compiled_references,
                 online_results=online_results,
                 code_results=code_results,