Track, reuse raw model response for multi-turn conversations

This should avoid the need to reformat the Khoj standardized tool call
for cache hits and should satisfy AI model API requirements.

Previously, multi-turn tool use calls to Anthropic reasoning models
would fail, as they needed their thoughts to be passed back. Other AI
model providers can have other requirements.

Passing back the raw response as is should satisfy the default case.

Tracking the raw response should make it easy to apply any formatting
required before sending the previous response back, if any AI model
provider requires that.

Details
---
- Raw response content is passed back in ResponseWithThoughts.
- Research iteration stores this and puts it into model response
  ChatMessageModel when constructing iteration history when it is
  present.
  Fall back to using the parsed tool call when the raw response isn't present.
- No need to format tool call messages for anthropic models as we're
  passing the raw response as is.
This commit is contained in:
Debanjum
2025-06-19 15:21:58 -07:00
parent 7cd496ac19
commit c2ab75efef
4 changed files with 19 additions and 30 deletions

View File

@@ -118,6 +118,9 @@ def anthropic_completion_with_backoff(
aggregated_response += chunk.delta.text
final_message = stream.get_final_message()
# Track raw content of model response to reuse for cache hits in multi-turn chats
raw_content = [item.model_dump() for item in final_message.content]
# Extract all tool calls if tools are enabled
if tools:
tool_calls = [
@@ -154,7 +157,7 @@ def anthropic_completion_with_backoff(
if is_promptrace_enabled():
commit_conversation_trace(messages, aggregated_response, tracer)
return ResponseWithThought(text=aggregated_response, thought=thoughts)
return ResponseWithThought(text=aggregated_response, thought=thoughts, raw_content=raw_content)
@retry(
@@ -289,18 +292,7 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: st
# Handle tool call and tool result message types from additional_kwargs
message_type = message.additional_kwargs.get("message_type")
if message_type == "tool_call":
# Convert tool_call to Anthropic tool_use format
content = []
for part in message.content:
content.append(
{
"type": "tool_use",
"id": part.pop("id"),
"name": part.pop("name"),
"input": part,
}
)
message.content = content
pass
elif message_type == "tool_result":
# Convert tool_result to Anthropic tool_result format
content = []

View File

@@ -108,7 +108,7 @@ def gemini_completion_with_backoff(
gemini_clients[api_key] = client
formatted_messages, system_instruction = format_messages_for_gemini(messages, system_prompt)
response_thoughts: str | None = None
raw_content, response_text, response_thoughts = [], "", None
# Configure structured output
tools = None
@@ -144,6 +144,7 @@ def gemini_completion_with_backoff(
try:
# Generate the response
response = client.models.generate_content(model=model_name, config=config, contents=formatted_messages)
raw_content = [part.model_dump() for part in response.candidates[0].content.parts or []]
if response.function_calls:
function_calls = [
ToolCall(name=function_call.name, args=function_call.args, id=function_call.id).__dict__
@@ -190,7 +191,7 @@ def gemini_completion_with_backoff(
if is_promptrace_enabled():
commit_conversation_trace(messages, response_text, tracer)
return ResponseWithThought(text=response_text, thought=response_thoughts)
return ResponseWithThought(text=response_text, thought=response_thoughts, raw_content=raw_content)
@retry(
@@ -376,11 +377,7 @@ def format_messages_for_gemini(
# Handle tool call and tool result message types from additional_kwargs
message_type = message.additional_kwargs.get("message_type")
if message_type == "tool_call":
# Convert tool_call to Gemini function call format
tool_call_msg_content = []
for part in message.content:
tool_call_msg_content.append(gtypes.Part.from_function_call(name=part["name"], args=part["args"]))
message.content = tool_call_msg_content
pass
elif message_type == "tool_result":
# Convert tool_result to Gemini function response format
# Need to find the corresponding function call from previous messages

View File

@@ -154,6 +154,7 @@ class ResearchIteration:
operatorContext: dict | OperatorRun = None,
summarizedResult: str = None,
warning: str = None,
raw_response: list = None,
):
self.query = ToolCall(**query) if isinstance(query, dict) else query
self.context = context
@@ -162,6 +163,7 @@ class ResearchIteration:
self.operatorContext = OperatorRun(**operatorContext) if isinstance(operatorContext, dict) else operatorContext
self.summarizedResult = summarizedResult
self.warning = warning
self.raw_response = raw_response
def to_dict(self) -> dict:
data = vars(self).copy()
@@ -185,7 +187,7 @@ def construct_iteration_history(
iteration_history += [
ChatMessageModel(
by="khoj",
message=[iteration.query.__dict__],
message=iteration.raw_response or [iteration.query.__dict__],
intent=Intent(type="tool_call", query=query),
),
ChatMessageModel(
@@ -1187,6 +1189,7 @@ class StructuredOutputSupport(int, Enum):
class ResponseWithThought:
def __init__(self, text: str = None, thought: str = None):
def __init__(self, text: str = None, thought: str = None, raw_content: list = None):
self.text = text
self.thought = thought
self.raw_content = raw_content

View File

@@ -142,7 +142,7 @@ async def apick_next_tool(
try:
with timer("Chat actor: Infer information sources to refer", logger):
raw_response = await send_message_to_model_wrapper(
response = await send_message_to_model_wrapper(
query="",
system_message=function_planning_prompt,
chat_history=chat_and_research_history,
@@ -165,7 +165,7 @@ async def apick_next_tool(
try:
# Try parse the response as function call response to infer next tool to use.
# TODO: Handle multiple tool calls.
response_text = raw_response.text
response_text = response.text
parsed_response = [ToolCall(**item) for item in load_complex_json(response_text)][0]
except Exception as e:
# Otherwise assume the model has decided to end the research run and respond to the user.
@@ -184,14 +184,11 @@ async def apick_next_tool(
if (parsed_response.name, dict_to_tuple(parsed_response.args)) in previous_tool_query_combinations:
warning = f"Repeated tool, query combination detected. Skipping iteration. Try something different."
# Only send client status updates if we'll execute this iteration and model has thoughts to share.
elif send_status_func and not is_none_or_empty(raw_response.thought):
async for event in send_status_func(raw_response.thought):
elif send_status_func and not is_none_or_empty(response.thought):
async for event in send_status_func(response.thought):
yield {ChatEvent.STATUS: event}
yield ResearchIteration(
query=parsed_response,
warning=warning,
)
yield ResearchIteration(query=parsed_response, warning=warning, raw_response=response.raw_content)
async def research(