mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Handle failed llm calls, message idempotency to improve retry success
- Deepcopy messages before formatting message for Anthropic to allow idempotency so retry on failure behaves as expected - Handle failed calls to pick next tools to pass failure warning and continue next research iteration. Previously if API call to pick next failed, the research run would crash - Add null response check for when Gemini models fail to respond
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from time import perf_counter
|
||||
from typing import AsyncGenerator, Dict, List
|
||||
|
||||
@@ -271,13 +272,14 @@ async def anthropic_chat_completion_with_backoff(
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
|
||||
|
||||
def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: str = None):
|
||||
def format_messages_for_anthropic(raw_messages: list[ChatMessage], system_prompt: str = None):
|
||||
"""
|
||||
Format messages for Anthropic
|
||||
"""
|
||||
# Extract system prompt
|
||||
system_prompt = system_prompt or ""
|
||||
for message in messages.copy():
|
||||
messages = deepcopy(raw_messages)
|
||||
for message in messages:
|
||||
if message.role == "system":
|
||||
if isinstance(message.content, list):
|
||||
system_prompt += "\n".join([part["text"] for part in message.content if part["type"] == "text"])
|
||||
@@ -295,7 +297,6 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: st
|
||||
elif len(messages) > 1 and messages[0].role == "assistant":
|
||||
messages = messages[1:]
|
||||
|
||||
# Convert image urls to base64 encoded images in Anthropic message format
|
||||
for message in messages:
|
||||
# Handle tool call and tool result message types from additional_kwargs
|
||||
message_type = message.additional_kwargs.get("message_type")
|
||||
@@ -313,6 +314,7 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: st
|
||||
}
|
||||
)
|
||||
message.content = content
|
||||
# Convert image urls to base64 encoded images in Anthropic message format
|
||||
elif isinstance(message.content, list):
|
||||
content = []
|
||||
# Sort the content. Anthropic models prefer that text comes after images.
|
||||
|
||||
@@ -144,7 +144,13 @@ def gemini_completion_with_backoff(
|
||||
try:
|
||||
# Generate the response
|
||||
response = client.models.generate_content(model=model_name, config=config, contents=formatted_messages)
|
||||
raw_content = [part.model_dump() for part in response.candidates[0].content.parts or []]
|
||||
if (
|
||||
not response.candidates
|
||||
or not response.candidates[0].content
|
||||
or response.candidates[0].content.parts is None
|
||||
):
|
||||
raise ValueError(f"Failed to get response from model.")
|
||||
raw_content = [part.model_dump() for part in response.candidates[0].content.parts]
|
||||
if response.function_calls:
|
||||
function_calls = [
|
||||
ToolCall(name=function_call.name, args=function_call.args, id=function_call.id).__dict__
|
||||
|
||||
@@ -10,7 +10,7 @@ from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Union
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import PIL.Image
|
||||
import pyjson5
|
||||
@@ -184,6 +184,16 @@ def construct_iteration_history(
|
||||
iteration_history.append(ChatMessageModel(by="you", message=query_message_content))
|
||||
|
||||
for iteration in previous_iterations:
|
||||
if not iteration.query:
|
||||
iteration_history.append(
|
||||
ChatMessageModel(
|
||||
by="you",
|
||||
message=iteration.summarizedResult
|
||||
or iteration.warning
|
||||
or "Please specify what you want to do next.",
|
||||
)
|
||||
)
|
||||
continue
|
||||
iteration_history += [
|
||||
ChatMessageModel(
|
||||
by="khoj",
|
||||
@@ -326,6 +336,17 @@ def construct_tool_chat_history(
|
||||
),
|
||||
}
|
||||
for iteration in previous_iterations:
|
||||
if not iteration.query:
|
||||
chat_history.append(
|
||||
ChatMessageModel(
|
||||
by="you",
|
||||
message=iteration.summarizedResult
|
||||
or iteration.warning
|
||||
or "Please specify what you want to do next.",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
# If a tool is provided use the inferred query extractor for that tool if available
|
||||
# If no tool is provided, use inferred query extractor for the tool used in the iteration
|
||||
# Fallback to base extractor if the tool does not have an inferred query extractor
|
||||
|
||||
Reference in New Issue
Block a user