mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Track, reuse raw model response for multi-turn conversations
This should avoid the need to reformat the Khoj standardized tool call for cache hits and to satisfy AI model API requirements. Previously, multi-turn tool use calls to Anthropic reasoning models would fail, as they needed their thoughts to be passed back. Other AI model providers can have other requirements. Passing back the raw response as-is should satisfy the default case. Tracking the raw response should make it easy to apply any formatting required before sending the previous response back, if any AI model provider requires that. Details --- - Raw response content is passed back in ResponseWithThought. - Research iteration stores this and puts it into the model response ChatMessageModel when constructing iteration history, when it is present. Falls back to using the parsed tool call when the raw response isn't present. - No need to format tool call messages for Anthropic models as we're passing the raw response as-is.
This commit is contained in:
@@ -118,6 +118,9 @@ def anthropic_completion_with_backoff(
|
||||
aggregated_response += chunk.delta.text
|
||||
final_message = stream.get_final_message()
|
||||
|
||||
# Track raw content of model response to reuse for cache hits in multi-turn chats
|
||||
raw_content = [item.model_dump() for item in final_message.content]
|
||||
|
||||
# Extract all tool calls if tools are enabled
|
||||
if tools:
|
||||
tool_calls = [
|
||||
@@ -154,7 +157,7 @@ def anthropic_completion_with_backoff(
|
||||
if is_promptrace_enabled():
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
|
||||
return ResponseWithThought(text=aggregated_response, thought=thoughts)
|
||||
return ResponseWithThought(text=aggregated_response, thought=thoughts, raw_content=raw_content)
|
||||
|
||||
|
||||
@retry(
|
||||
@@ -289,18 +292,7 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: st
|
||||
# Handle tool call and tool result message types from additional_kwargs
|
||||
message_type = message.additional_kwargs.get("message_type")
|
||||
if message_type == "tool_call":
|
||||
# Convert tool_call to Anthropic tool_use format
|
||||
content = []
|
||||
for part in message.content:
|
||||
content.append(
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": part.pop("id"),
|
||||
"name": part.pop("name"),
|
||||
"input": part,
|
||||
}
|
||||
)
|
||||
message.content = content
|
||||
pass
|
||||
elif message_type == "tool_result":
|
||||
# Convert tool_result to Anthropic tool_result format
|
||||
content = []
|
||||
|
||||
@@ -108,7 +108,7 @@ def gemini_completion_with_backoff(
|
||||
gemini_clients[api_key] = client
|
||||
|
||||
formatted_messages, system_instruction = format_messages_for_gemini(messages, system_prompt)
|
||||
response_thoughts: str | None = None
|
||||
raw_content, response_text, response_thoughts = [], "", None
|
||||
|
||||
# Configure structured output
|
||||
tools = None
|
||||
@@ -144,6 +144,7 @@ def gemini_completion_with_backoff(
|
||||
try:
|
||||
# Generate the response
|
||||
response = client.models.generate_content(model=model_name, config=config, contents=formatted_messages)
|
||||
raw_content = [part.model_dump() for part in response.candidates[0].content.parts or []]
|
||||
if response.function_calls:
|
||||
function_calls = [
|
||||
ToolCall(name=function_call.name, args=function_call.args, id=function_call.id).__dict__
|
||||
@@ -190,7 +191,7 @@ def gemini_completion_with_backoff(
|
||||
if is_promptrace_enabled():
|
||||
commit_conversation_trace(messages, response_text, tracer)
|
||||
|
||||
return ResponseWithThought(text=response_text, thought=response_thoughts)
|
||||
return ResponseWithThought(text=response_text, thought=response_thoughts, raw_content=raw_content)
|
||||
|
||||
|
||||
@retry(
|
||||
@@ -376,11 +377,7 @@ def format_messages_for_gemini(
|
||||
# Handle tool call and tool result message types from additional_kwargs
|
||||
message_type = message.additional_kwargs.get("message_type")
|
||||
if message_type == "tool_call":
|
||||
# Convert tool_call to Gemini function call format
|
||||
tool_call_msg_content = []
|
||||
for part in message.content:
|
||||
tool_call_msg_content.append(gtypes.Part.from_function_call(name=part["name"], args=part["args"]))
|
||||
message.content = tool_call_msg_content
|
||||
pass
|
||||
elif message_type == "tool_result":
|
||||
# Convert tool_result to Gemini function response format
|
||||
# Need to find the corresponding function call from previous messages
|
||||
|
||||
@@ -154,6 +154,7 @@ class ResearchIteration:
|
||||
operatorContext: dict | OperatorRun = None,
|
||||
summarizedResult: str = None,
|
||||
warning: str = None,
|
||||
raw_response: list = None,
|
||||
):
|
||||
self.query = ToolCall(**query) if isinstance(query, dict) else query
|
||||
self.context = context
|
||||
@@ -162,6 +163,7 @@ class ResearchIteration:
|
||||
self.operatorContext = OperatorRun(**operatorContext) if isinstance(operatorContext, dict) else operatorContext
|
||||
self.summarizedResult = summarizedResult
|
||||
self.warning = warning
|
||||
self.raw_response = raw_response
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
data = vars(self).copy()
|
||||
@@ -185,7 +187,7 @@ def construct_iteration_history(
|
||||
iteration_history += [
|
||||
ChatMessageModel(
|
||||
by="khoj",
|
||||
message=[iteration.query.__dict__],
|
||||
message=iteration.raw_response or [iteration.query.__dict__],
|
||||
intent=Intent(type="tool_call", query=query),
|
||||
),
|
||||
ChatMessageModel(
|
||||
@@ -1187,6 +1189,7 @@ class StructuredOutputSupport(int, Enum):
|
||||
|
||||
|
||||
class ResponseWithThought:
    """Container for a chat model's response.

    Bundles the user-facing text, the model's reasoning (if exposed by the
    provider), and the raw provider response content so it can be replayed
    verbatim in multi-turn conversations (e.g. Anthropic reasoning models
    require their prior thoughts to be passed back).
    """

    def __init__(self, text: str = None, thought: str = None, raw_content: list = None):
        # Final response text shown to the user.
        self.text = text
        # Model's intermediate reasoning/thoughts, when available.
        self.thought = thought
        # Raw provider response content (list of content-part dicts),
        # reused as-is on subsequent turns for cache hits and to satisfy
        # provider-specific multi-turn requirements.
        self.raw_content = raw_content
|
||||
|
||||
@@ -142,7 +142,7 @@ async def apick_next_tool(
|
||||
|
||||
try:
|
||||
with timer("Chat actor: Infer information sources to refer", logger):
|
||||
raw_response = await send_message_to_model_wrapper(
|
||||
response = await send_message_to_model_wrapper(
|
||||
query="",
|
||||
system_message=function_planning_prompt,
|
||||
chat_history=chat_and_research_history,
|
||||
@@ -165,7 +165,7 @@ async def apick_next_tool(
|
||||
try:
|
||||
# Try parse the response as function call response to infer next tool to use.
|
||||
# TODO: Handle multiple tool calls.
|
||||
response_text = raw_response.text
|
||||
response_text = response.text
|
||||
parsed_response = [ToolCall(**item) for item in load_complex_json(response_text)][0]
|
||||
except Exception as e:
|
||||
# Otherwise assume the model has decided to end the research run and respond to the user.
|
||||
@@ -184,14 +184,11 @@ async def apick_next_tool(
|
||||
if (parsed_response.name, dict_to_tuple(parsed_response.args)) in previous_tool_query_combinations:
|
||||
warning = f"Repeated tool, query combination detected. Skipping iteration. Try something different."
|
||||
# Only send client status updates if we'll execute this iteration and model has thoughts to share.
|
||||
elif send_status_func and not is_none_or_empty(raw_response.thought):
|
||||
async for event in send_status_func(raw_response.thought):
|
||||
elif send_status_func and not is_none_or_empty(response.thought):
|
||||
async for event in send_status_func(response.thought):
|
||||
yield {ChatEvent.STATUS: event}
|
||||
|
||||
yield ResearchIteration(
|
||||
query=parsed_response,
|
||||
warning=warning,
|
||||
)
|
||||
yield ResearchIteration(query=parsed_response, warning=warning, raw_response=response.raw_content)
|
||||
|
||||
|
||||
async def research(
|
||||
|
||||
Reference in New Issue
Block a user