From 786b06bb3fcd0fa25a68a08b16755a2788240071 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Wed, 2 Jul 2025 22:08:30 -0700 Subject: [PATCH] Handle failed LLM calls, message idempotency to improve retry success - Deepcopy messages before formatting them for Anthropic to allow idempotency so retry on failure behaves as expected - Handle failed calls to pick the next tool by passing a failure warning and continuing to the next research iteration. Previously, if the API call to pick the next tool failed, the research run would crash - Add null response check for when Gemini models fail to respond --- .../processor/conversation/anthropic/utils.py | 8 ++++--- .../processor/conversation/google/utils.py | 8 ++++++- src/khoj/processor/conversation/utils.py | 23 ++++++++++++++++++- 3 files changed, 34 insertions(+), 5 deletions(-) diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py index 2842a270..fab7ed54 100644 --- a/src/khoj/processor/conversation/anthropic/utils.py +++ b/src/khoj/processor/conversation/anthropic/utils.py @@ -1,5 +1,6 @@ import json import logging +from copy import deepcopy from time import perf_counter from typing import AsyncGenerator, Dict, List @@ -271,13 +272,14 @@ async def anthropic_chat_completion_with_backoff( commit_conversation_trace(messages, aggregated_response, tracer) -def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: str = None): +def format_messages_for_anthropic(raw_messages: list[ChatMessage], system_prompt: str = None): """ Format messages for Anthropic """ # Extract system prompt system_prompt = system_prompt or "" - for message in messages.copy(): + messages = deepcopy(raw_messages) + for message in messages: if message.role == "system": if isinstance(message.content, list): system_prompt += "\n".join([part["text"] for part in message.content if part["type"] == "text"]) @@ -295,7 +297,6 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: st 
elif len(messages) > 1 and messages[0].role == "assistant": messages = messages[1:] - # Convert image urls to base64 encoded images in Anthropic message format for message in messages: # Handle tool call and tool result message types from additional_kwargs message_type = message.additional_kwargs.get("message_type") @@ -313,6 +314,7 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: st } ) message.content = content + # Convert image urls to base64 encoded images in Anthropic message format elif isinstance(message.content, list): content = [] # Sort the content. Anthropic models prefer that text comes after images. diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py index 4c184944..8105c014 100644 --- a/src/khoj/processor/conversation/google/utils.py +++ b/src/khoj/processor/conversation/google/utils.py @@ -144,7 +144,13 @@ def gemini_completion_with_backoff( try: # Generate the response response = client.models.generate_content(model=model_name, config=config, contents=formatted_messages) - raw_content = [part.model_dump() for part in response.candidates[0].content.parts or []] + if ( + not response.candidates + or not response.candidates[0].content + or response.candidates[0].content.parts is None + ): + raise ValueError(f"Failed to get response from model.") + raw_content = [part.model_dump() for part in response.candidates[0].content.parts] if response.function_calls: function_calls = [ ToolCall(name=function_call.name, args=function_call.args, id=function_call.id).__dict__ diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 81be2270..36390389 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -10,7 +10,7 @@ from dataclasses import dataclass from datetime import datetime from enum import Enum from io import BytesIO -from typing import Any, Callable, Dict, List, Literal, 
Optional, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union import PIL.Image import pyjson5 @@ -184,6 +184,16 @@ def construct_iteration_history( iteration_history.append(ChatMessageModel(by="you", message=query_message_content)) for iteration in previous_iterations: + if not iteration.query: + iteration_history.append( + ChatMessageModel( + by="you", + message=iteration.summarizedResult + or iteration.warning + or "Please specify what you want to do next.", + ) + ) + continue iteration_history += [ ChatMessageModel( by="khoj", @@ -326,6 +336,17 @@ def construct_tool_chat_history( ), } for iteration in previous_iterations: + if not iteration.query: + chat_history.append( + ChatMessageModel( + by="you", + message=iteration.summarizedResult + or iteration.warning + or "Please specify what you want to do next.", + ) + ) + continue + # If a tool is provided use the inferred query extractor for that tool if available # If no tool is provided, use inferred query extractor for the tool used in the iteration # Fallback to base extractor if the tool does not have an inferred query extractor