From c2ab75efef55097b00961361c725e2725cc3b1ad Mon Sep 17 00:00:00 2001 From: Debanjum Date: Thu, 19 Jun 2025 15:21:58 -0700 Subject: [PATCH] Track, reuse raw model response for multi-turn conversations This should avoid the need to reformat the Khoj standardized tool call for cache hits and to satisfy AI model API requirements. Previously, multi-turn tool use calls to Anthropic reasoning models would fail, as they needed their thoughts to be passed back. Other AI model providers can have other requirements. Passing back the raw response as-is should satisfy the default case. Tracking the raw response should make it easy to apply any formatting required before sending the previous response back, if any AI model provider requires that. Details --- - Raw response content is passed back in ResponseWithThought. - Research iteration stores this and puts it into the model response ChatMessageModel when constructing iteration history, when it is present. Fall back to using the parsed tool call when the raw response isn't present. - No need to format tool call messages for Anthropic models as we're passing the raw response as-is. 
--- .../processor/conversation/anthropic/utils.py | 18 +++++------------- .../processor/conversation/google/utils.py | 11 ++++------- src/khoj/processor/conversation/utils.py | 7 +++++-- src/khoj/routers/research.py | 13 +++++-------- 4 files changed, 19 insertions(+), 30 deletions(-) diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py index f2b46d63..97175253 100644 --- a/src/khoj/processor/conversation/anthropic/utils.py +++ b/src/khoj/processor/conversation/anthropic/utils.py @@ -118,6 +118,9 @@ def anthropic_completion_with_backoff( aggregated_response += chunk.delta.text final_message = stream.get_final_message() + # Track raw content of model response to reuse for cache hits in multi-turn chats + raw_content = [item.model_dump() for item in final_message.content] + # Extract all tool calls if tools are enabled if tools: tool_calls = [ @@ -154,7 +157,7 @@ def anthropic_completion_with_backoff( if is_promptrace_enabled(): commit_conversation_trace(messages, aggregated_response, tracer) - return ResponseWithThought(text=aggregated_response, thought=thoughts) + return ResponseWithThought(text=aggregated_response, thought=thoughts, raw_content=raw_content) @retry( @@ -289,18 +292,7 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: st # Handle tool call and tool result message types from additional_kwargs message_type = message.additional_kwargs.get("message_type") if message_type == "tool_call": - # Convert tool_call to Anthropic tool_use format - content = [] - for part in message.content: - content.append( - { - "type": "tool_use", - "id": part.pop("id"), - "name": part.pop("name"), - "input": part, - } - ) - message.content = content + pass elif message_type == "tool_result": # Convert tool_result to Anthropic tool_result format content = [] diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py index 
9da8ed1d..4c184944 100644 --- a/src/khoj/processor/conversation/google/utils.py +++ b/src/khoj/processor/conversation/google/utils.py @@ -108,7 +108,7 @@ def gemini_completion_with_backoff( gemini_clients[api_key] = client formatted_messages, system_instruction = format_messages_for_gemini(messages, system_prompt) - response_thoughts: str | None = None + raw_content, response_text, response_thoughts = [], "", None # Configure structured output tools = None @@ -144,6 +144,7 @@ def gemini_completion_with_backoff( try: # Generate the response response = client.models.generate_content(model=model_name, config=config, contents=formatted_messages) + raw_content = [part.model_dump() for part in response.candidates[0].content.parts or []] if response.function_calls: function_calls = [ ToolCall(name=function_call.name, args=function_call.args, id=function_call.id).__dict__ @@ -190,7 +191,7 @@ def gemini_completion_with_backoff( if is_promptrace_enabled(): commit_conversation_trace(messages, response_text, tracer) - return ResponseWithThought(text=response_text, thought=response_thoughts) + return ResponseWithThought(text=response_text, thought=response_thoughts, raw_content=raw_content) @retry( @@ -376,11 +377,7 @@ def format_messages_for_gemini( # Handle tool call and tool result message types from additional_kwargs message_type = message.additional_kwargs.get("message_type") if message_type == "tool_call": - # Convert tool_call to Gemini function call format - tool_call_msg_content = [] - for part in message.content: - tool_call_msg_content.append(gtypes.Part.from_function_call(name=part["name"], args=part["args"])) - message.content = tool_call_msg_content + pass elif message_type == "tool_result": # Convert tool_result to Gemini function response format # Need to find the corresponding function call from previous messages diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index bb899fb1..81be2270 100644 --- 
a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -154,6 +154,7 @@ class ResearchIteration: operatorContext: dict | OperatorRun = None, summarizedResult: str = None, warning: str = None, + raw_response: list = None, ): self.query = ToolCall(**query) if isinstance(query, dict) else query self.context = context @@ -162,6 +163,7 @@ class ResearchIteration: self.operatorContext = OperatorRun(**operatorContext) if isinstance(operatorContext, dict) else operatorContext self.summarizedResult = summarizedResult self.warning = warning + self.raw_response = raw_response def to_dict(self) -> dict: data = vars(self).copy() @@ -185,7 +187,7 @@ def construct_iteration_history( iteration_history += [ ChatMessageModel( by="khoj", - message=[iteration.query.__dict__], + message=iteration.raw_response or [iteration.query.__dict__], intent=Intent(type="tool_call", query=query), ), ChatMessageModel( @@ -1187,6 +1189,7 @@ class StructuredOutputSupport(int, Enum): class ResponseWithThought: - def __init__(self, text: str = None, thought: str = None): + def __init__(self, text: str = None, thought: str = None, raw_content: list = None): self.text = text self.thought = thought + self.raw_content = raw_content diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index d473e30e..227f03bf 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -142,7 +142,7 @@ async def apick_next_tool( try: with timer("Chat actor: Infer information sources to refer", logger): - raw_response = await send_message_to_model_wrapper( + response = await send_message_to_model_wrapper( query="", system_message=function_planning_prompt, chat_history=chat_and_research_history, @@ -165,7 +165,7 @@ async def apick_next_tool( try: # Try parse the response as function call response to infer next tool to use. # TODO: Handle multiple tool calls. 
- response_text = raw_response.text + response_text = response.text parsed_response = [ToolCall(**item) for item in load_complex_json(response_text)][0] except Exception as e: # Otherwise assume the model has decided to end the research run and respond to the user. @@ -184,14 +184,11 @@ async def apick_next_tool( if (parsed_response.name, dict_to_tuple(parsed_response.args)) in previous_tool_query_combinations: warning = f"Repeated tool, query combination detected. Skipping iteration. Try something different." # Only send client status updates if we'll execute this iteration and model has thoughts to share. - elif send_status_func and not is_none_or_empty(raw_response.thought): - async for event in send_status_func(raw_response.thought): + elif send_status_func and not is_none_or_empty(response.thought): + async for event in send_status_func(response.thought): yield {ChatEvent.STATUS: event} - yield ResearchIteration( - query=parsed_response, - warning=warning, - ) + yield ResearchIteration(query=parsed_response, warning=warning, raw_response=response.raw_content) async def research(