Make callers only share new messages to append to chat logs

- Chat history is retrieved and updated with new messages just before
  write. This reduces the chance of message loss from conflicting
  writes, where the last writer to save the conversation wins the conflict.
- Passing the full chat history to save was a problematic artifact of old
  code. Removing it should reduce the conflict surface area.
- Interrupts and live chat could hit this issue for different reasons.
This commit is contained in:
Debanjum
2025-06-27 00:27:47 -07:00
parent eaed0c839e
commit 7b7b1830b7
4 changed files with 40 additions and 21 deletions

View File

@@ -1465,7 +1465,7 @@ class ConversationAdapters:
@require_valid_user @require_valid_user
async def save_conversation( async def save_conversation(
user: KhojUser, user: KhojUser,
chat_history: List[ChatMessageModel], new_messages: List[ChatMessageModel],
client_application: ClientApplication = None, client_application: ClientApplication = None,
conversation_id: str = None, conversation_id: str = None,
user_message: str = None, user_message: str = None,
@@ -1480,7 +1480,8 @@ class ConversationAdapters:
await Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst() await Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst()
) )
conversation_log = {"chat": [msg.model_dump() for msg in chat_history]} existing_messages = conversation.messages if conversation else []
conversation_log = {"chat": [msg.model_dump() for msg in existing_messages + new_messages]}
cleaned_conversation_log = clean_object_for_db(conversation_log) cleaned_conversation_log = clean_object_for_db(conversation_log)
if conversation: if conversation:
conversation.conversation_log = cleaned_conversation_log conversation.conversation_log = cleaned_conversation_log

View File

@@ -677,6 +677,34 @@ class Conversation(DbBaseModel):
continue continue
return validated_messages return validated_messages
async def pop_message(self, interrupted: bool = False) -> Optional[ChatMessageModel]:
    """
    Remove and return the last message from the conversation log, persisting the change to the database.

    When interrupted is True, only drop the last message if it looks like an interrupted
    response by khoj (i.e. a "khoj" message with an empty body); otherwise leave the log
    untouched and return None.

    Returns the popped message as a validated ChatMessageModel, or None when nothing was
    popped or the popped entry failed validation (the removal is persisted either way).
    """
    # NOTE: chat_log aliases the list stored at self.conversation_log["chat"] (when the
    # key exists), so popping from it mutates the JSON field that asave() persists below.
    # If "chat" is absent, .get returns a fresh empty list and we bail out before popping.
    chat_log = self.conversation_log.get("chat", [])
    if not chat_log:
        return None
    last_message = chat_log[-1]
    # An interrupted response is stored as a "khoj" message with an empty/missing body.
    is_interrupted_msg = last_message.get("by") == "khoj" and not last_message.get("message")
    # When handling an interruption, only pop if the last message is an empty one by khoj.
    if interrupted and not is_interrupted_msg:
        return None
    # Pop the last message, save the conversation, and then return the message.
    popped_message_dict = chat_log.pop()
    await self.asave()
    # Try to validate and return the popped message as a Pydantic model
    try:
        return ChatMessageModel.model_validate(popped_message_dict)
    except ValidationError as e:
        logger.warning(f"Popped an invalid message from conversation. The removal has been saved. Error: {e}")
        # The invalid message was removed and saved, but we can't return a valid model.
        return None
class PublicConversation(DbBaseModel): class PublicConversation(DbBaseModel):
source_owner = models.ForeignKey(KhojUser, on_delete=models.CASCADE) source_owner = models.ForeignKey(KhojUser, on_delete=models.CASCADE)

View File

@@ -435,7 +435,6 @@ async def save_to_conversation_log(
q: str, q: str,
chat_response: str, chat_response: str,
user: KhojUser, user: KhojUser,
chat_history: List[ChatMessageModel],
user_message_time: str = None, user_message_time: str = None,
compiled_references: List[Dict[str, Any]] = [], compiled_references: List[Dict[str, Any]] = [],
online_results: Dict[str, Any] = {}, online_results: Dict[str, Any] = {},
@@ -481,22 +480,22 @@ async def save_to_conversation_log(
khoj_message_metadata["mermaidjsDiagram"] = generated_mermaidjs_diagram khoj_message_metadata["mermaidjsDiagram"] = generated_mermaidjs_diagram
try: try:
updated_conversation = message_to_log( new_messages = message_to_log(
user_message=q, user_message=q,
chat_response=chat_response, chat_response=chat_response,
user_message_metadata=user_message_metadata, user_message_metadata=user_message_metadata,
khoj_message_metadata=khoj_message_metadata, khoj_message_metadata=khoj_message_metadata,
chat_history=chat_history, chat_history=[],
) )
except ValidationError as e: except ValidationError as e:
updated_conversation = None new_messages = None
logger.error(f"Error constructing chat history: {e}") logger.error(f"Error constructing chat history: {e}")
db_conversation = None db_conversation = None
if updated_conversation: if new_messages:
db_conversation = await ConversationAdapters.save_conversation( db_conversation = await ConversationAdapters.save_conversation(
user, user,
updated_conversation, new_messages,
client_application=client_application, client_application=client_application,
conversation_id=conversation_id, conversation_id=conversation_id,
user_message=q, user_message=q,

View File

@@ -738,6 +738,7 @@ async def event_generator(
generated_mermaidjs_diagram: str = None generated_mermaidjs_diagram: str = None
generated_asset_results: Dict = dict() generated_asset_results: Dict = dict()
program_execution_context: List[str] = [] program_execution_context: List[str] = []
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
# Create a task to monitor for disconnections # Create a task to monitor for disconnections
disconnect_monitor_task = None disconnect_monitor_task = None
@@ -757,7 +758,6 @@ async def event_generator(
q, q,
chat_response="", chat_response="",
user=user, user=user,
chat_history=chat_history,
compiled_references=compiled_references, compiled_references=compiled_references,
online_results=online_results, online_results=online_results,
code_results=code_results, code_results=code_results,
@@ -772,6 +772,7 @@ async def event_generator(
generated_images=generated_images, generated_images=generated_images,
raw_generated_files=generated_asset_results, raw_generated_files=generated_asset_results,
generated_mermaidjs_diagram=generated_mermaidjs_diagram, generated_mermaidjs_diagram=generated_mermaidjs_diagram,
user_message_time=user_message_time,
tracer=tracer, tracer=tracer,
) )
) )
@@ -789,7 +790,6 @@ async def event_generator(
q, q,
chat_response="", chat_response="",
user=user, user=user,
chat_history=chat_history,
compiled_references=compiled_references, compiled_references=compiled_references,
online_results=online_results, online_results=online_results,
code_results=code_results, code_results=code_results,
@@ -804,6 +804,7 @@ async def event_generator(
generated_images=generated_images, generated_images=generated_images,
raw_generated_files=generated_asset_results, raw_generated_files=generated_asset_results,
generated_mermaidjs_diagram=generated_mermaidjs_diagram, generated_mermaidjs_diagram=generated_mermaidjs_diagram,
user_message_time=user_message_time,
tracer=tracer, tracer=tracer,
) )
) )
@@ -952,18 +953,11 @@ async def event_generator(
location = None location = None
if city or region or country or country_code: if city or region or country or country_code:
location = LocationData(city=city, region=region, country=country, country_code=country_code) location = LocationData(city=city, region=region, country=country, country_code=country_code)
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
chat_history = conversation.messages chat_history = conversation.messages
# If interrupted message in DB # If interrupted message in DB
if ( if last_message := await conversation.pop_message(interrupted=True):
conversation
and conversation.messages
and conversation.messages[-1].by == "khoj"
and not conversation.messages[-1].message
):
# Populate context from interrupted message # Populate context from interrupted message
last_message = conversation.messages[-1]
online_results = {key: val.model_dump() for key, val in last_message.onlineContext.items() or []} online_results = {key: val.model_dump() for key, val in last_message.onlineContext.items() or []}
code_results = {key: val.model_dump() for key, val in last_message.codeContext.items() or []} code_results = {key: val.model_dump() for key, val in last_message.codeContext.items() or []}
compiled_references = [ref.model_dump() for ref in last_message.context or []] compiled_references = [ref.model_dump() for ref in last_message.context or []]
@@ -974,8 +968,6 @@ async def event_generator(
] ]
operator_results = [OperatorRun(**iter_dict) for iter_dict in last_message.operatorContext or []] operator_results = [OperatorRun(**iter_dict) for iter_dict in last_message.operatorContext or []]
train_of_thought = [thought.model_dump() for thought in last_message.trainOfThought or []] train_of_thought = [thought.model_dump() for thought in last_message.trainOfThought or []]
# Drop the interrupted message from conversation history
chat_history.pop()
logger.info(f"Loaded interrupted partial context from conversation {conversation_id}.") logger.info(f"Loaded interrupted partial context from conversation {conversation_id}.")
if conversation_commands == [ConversationCommand.Default]: if conversation_commands == [ConversationCommand.Default]:
@@ -1414,7 +1406,6 @@ async def event_generator(
q, q,
chat_response=full_response, chat_response=full_response,
user=user, user=user,
chat_history=chat_history,
compiled_references=compiled_references, compiled_references=compiled_references,
online_results=online_results, online_results=online_results,
code_results=code_results, code_results=code_results,