Make callers only share new messages to append to chat logs

- Chat history is retrieved and updated with new messages just before
  write. This is to reduce chance of message loss due to conflicting
  writes making last to save to conversation win conflict.
- This was problematic artifact of old code. Removing it should reduce
  conflict surface area.
- Interrupts and live chat could hit this issue due to different reasons
This commit is contained in:
Debanjum
2025-06-27 00:27:47 -07:00
parent eaed0c839e
commit 7b7b1830b7
4 changed files with 40 additions and 21 deletions

View File

@@ -1465,7 +1465,7 @@ class ConversationAdapters:
@require_valid_user
async def save_conversation(
user: KhojUser,
chat_history: List[ChatMessageModel],
new_messages: List[ChatMessageModel],
client_application: ClientApplication = None,
conversation_id: str = None,
user_message: str = None,
@@ -1480,7 +1480,8 @@ class ConversationAdapters:
await Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst()
)
conversation_log = {"chat": [msg.model_dump() for msg in chat_history]}
existing_messages = conversation.messages if conversation else []
conversation_log = {"chat": [msg.model_dump() for msg in existing_messages + new_messages]}
cleaned_conversation_log = clean_object_for_db(conversation_log)
if conversation:
conversation.conversation_log = cleaned_conversation_log

View File

@@ -677,6 +677,34 @@ class Conversation(DbBaseModel):
continue
return validated_messages
async def pop_message(self, interrupted: bool = False) -> Optional[ChatMessageModel]:
"""
Remove and return the last message from the conversation log, persisting the change to the database.
When interrupted is True, we only drop the last message if it was an interrupted message by khoj.
"""
chat_log = self.conversation_log.get("chat", [])
if not chat_log:
return None
last_message = chat_log[-1]
is_interrupted_msg = last_message.get("by") == "khoj" and not last_message.get("message")
# When handling an interruption, only pop if the last message is an empty one by khoj.
if interrupted and not is_interrupted_msg:
return None
# Pop the last message, save the conversation, and then return the message.
popped_message_dict = chat_log.pop()
await self.asave()
# Try to validate and return the popped message as a Pydantic model
try:
return ChatMessageModel.model_validate(popped_message_dict)
except ValidationError as e:
logger.warning(f"Popped an invalid message from conversation. The removal has been saved. Error: {e}")
# The invalid message was removed and saved, but we can't return a valid model.
return None
class PublicConversation(DbBaseModel):
source_owner = models.ForeignKey(KhojUser, on_delete=models.CASCADE)

View File

@@ -435,7 +435,6 @@ async def save_to_conversation_log(
q: str,
chat_response: str,
user: KhojUser,
chat_history: List[ChatMessageModel],
user_message_time: str = None,
compiled_references: List[Dict[str, Any]] = [],
online_results: Dict[str, Any] = {},
@@ -481,22 +480,22 @@ async def save_to_conversation_log(
khoj_message_metadata["mermaidjsDiagram"] = generated_mermaidjs_diagram
try:
updated_conversation = message_to_log(
new_messages = message_to_log(
user_message=q,
chat_response=chat_response,
user_message_metadata=user_message_metadata,
khoj_message_metadata=khoj_message_metadata,
chat_history=chat_history,
chat_history=[],
)
except ValidationError as e:
updated_conversation = None
new_messages = None
logger.error(f"Error constructing chat history: {e}")
db_conversation = None
if updated_conversation:
if new_messages:
db_conversation = await ConversationAdapters.save_conversation(
user,
updated_conversation,
new_messages,
client_application=client_application,
conversation_id=conversation_id,
user_message=q,

View File

@@ -738,6 +738,7 @@ async def event_generator(
generated_mermaidjs_diagram: str = None
generated_asset_results: Dict = dict()
program_execution_context: List[str] = []
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
# Create a task to monitor for disconnections
disconnect_monitor_task = None
@@ -757,7 +758,6 @@ async def event_generator(
q,
chat_response="",
user=user,
chat_history=chat_history,
compiled_references=compiled_references,
online_results=online_results,
code_results=code_results,
@@ -772,6 +772,7 @@ async def event_generator(
generated_images=generated_images,
raw_generated_files=generated_asset_results,
generated_mermaidjs_diagram=generated_mermaidjs_diagram,
user_message_time=user_message_time,
tracer=tracer,
)
)
@@ -789,7 +790,6 @@ async def event_generator(
q,
chat_response="",
user=user,
chat_history=chat_history,
compiled_references=compiled_references,
online_results=online_results,
code_results=code_results,
@@ -804,6 +804,7 @@ async def event_generator(
generated_images=generated_images,
raw_generated_files=generated_asset_results,
generated_mermaidjs_diagram=generated_mermaidjs_diagram,
user_message_time=user_message_time,
tracer=tracer,
)
)
@@ -952,18 +953,11 @@ async def event_generator(
location = None
if city or region or country or country_code:
location = LocationData(city=city, region=region, country=country, country_code=country_code)
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
chat_history = conversation.messages
# If interrupted message in DB
if (
conversation
and conversation.messages
and conversation.messages[-1].by == "khoj"
and not conversation.messages[-1].message
):
if last_message := await conversation.pop_message(interrupted=True):
# Populate context from interrupted message
last_message = conversation.messages[-1]
online_results = {key: val.model_dump() for key, val in last_message.onlineContext.items() or []}
code_results = {key: val.model_dump() for key, val in last_message.codeContext.items() or []}
compiled_references = [ref.model_dump() for ref in last_message.context or []]
@@ -974,8 +968,6 @@ async def event_generator(
]
operator_results = [OperatorRun(**iter_dict) for iter_dict in last_message.operatorContext or []]
train_of_thought = [thought.model_dump() for thought in last_message.trainOfThought or []]
# Drop the interrupted message from conversation history
chat_history.pop()
logger.info(f"Loaded interrupted partial context from conversation {conversation_id}.")
if conversation_commands == [ConversationCommand.Default]:
@@ -1414,7 +1406,6 @@ async def event_generator(
q,
chat_response=full_response,
user=user,
chat_history=chat_history,
compiled_references=compiled_references,
online_results=online_results,
code_results=code_results,