diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py index f2b46d63..97175253 100644 --- a/src/khoj/processor/conversation/anthropic/utils.py +++ b/src/khoj/processor/conversation/anthropic/utils.py @@ -118,6 +118,9 @@ def anthropic_completion_with_backoff( aggregated_response += chunk.delta.text final_message = stream.get_final_message() + # Track raw content of model response to reuse for cache hits in multi-turn chats + raw_content = [item.model_dump() for item in final_message.content] + # Extract all tool calls if tools are enabled if tools: tool_calls = [ @@ -154,7 +157,7 @@ def anthropic_completion_with_backoff( if is_promptrace_enabled(): commit_conversation_trace(messages, aggregated_response, tracer) - return ResponseWithThought(text=aggregated_response, thought=thoughts) + return ResponseWithThought(text=aggregated_response, thought=thoughts, raw_content=raw_content) @retry( @@ -289,18 +292,7 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: st # Handle tool call and tool result message types from additional_kwargs message_type = message.additional_kwargs.get("message_type") if message_type == "tool_call": - # Convert tool_call to Anthropic tool_use format - content = [] - for part in message.content: - content.append( - { - "type": "tool_use", - "id": part.pop("id"), - "name": part.pop("name"), - "input": part, - } - ) - message.content = content + pass elif message_type == "tool_result": # Convert tool_result to Anthropic tool_result format content = [] diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py index 9da8ed1d..4c184944 100644 --- a/src/khoj/processor/conversation/google/utils.py +++ b/src/khoj/processor/conversation/google/utils.py @@ -108,7 +108,7 @@ def gemini_completion_with_backoff( gemini_clients[api_key] = client formatted_messages, system_instruction = format_messages_for_gemini(messages, system_prompt) - response_thoughts: str | None = None + raw_content, response_text, response_thoughts = [], "", None # Configure structured output tools = None @@ -144,6 +144,7 @@ def gemini_completion_with_backoff( try: # Generate the response response = client.models.generate_content(model=model_name, config=config, contents=formatted_messages) + raw_content = [part.model_dump() for part in response.candidates[0].content.parts or []] if response.function_calls: function_calls = [ ToolCall(name=function_call.name, args=function_call.args, id=function_call.id).__dict__ @@ -190,7 +191,7 @@ def gemini_completion_with_backoff( if is_promptrace_enabled(): commit_conversation_trace(messages, response_text, tracer) - return ResponseWithThought(text=response_text, thought=response_thoughts) + return ResponseWithThought(text=response_text, thought=response_thoughts, raw_content=raw_content) @retry( @@ -376,11 +377,7 @@ def format_messages_for_gemini( # Handle tool call and tool result message types from additional_kwargs message_type = message.additional_kwargs.get("message_type") if message_type == "tool_call": - # Convert tool_call to Gemini function call format - tool_call_msg_content = [] - for part in message.content: - tool_call_msg_content.append(gtypes.Part.from_function_call(name=part["name"], args=part["args"])) - message.content = tool_call_msg_content + pass elif message_type == "tool_result": # Convert tool_result to Gemini function response format # Need to find the corresponding function call from previous messages diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index bb899fb1..81be2270 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -154,6 +154,7 @@ class ResearchIteration: operatorContext: dict | OperatorRun = None, summarizedResult: str = None, warning: str = None, + raw_response: list = None, ): self.query = ToolCall(**query) if isinstance(query, dict) else query self.context = context @@ -162,6 +163,7 @@ class ResearchIteration: self.operatorContext = OperatorRun(**operatorContext) if isinstance(operatorContext, dict) else operatorContext self.summarizedResult = summarizedResult self.warning = warning + self.raw_response = raw_response def to_dict(self) -> dict: data = vars(self).copy() @@ -185,7 +187,7 @@ def construct_iteration_history( iteration_history += [ ChatMessageModel( by="khoj", - message=[iteration.query.__dict__], + message=iteration.raw_response or [iteration.query.__dict__], intent=Intent(type="tool_call", query=query), ), ChatMessageModel( @@ -1187,6 +1189,7 @@ class StructuredOutputSupport(int, Enum): class ResponseWithThought: - def __init__(self, text: str = None, thought: str = None): + def __init__(self, text: str = None, thought: str = None, raw_content: list = None): self.text = text self.thought = thought + self.raw_content = raw_content diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index d473e30e..227f03bf 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -142,7 +142,7 @@ async def apick_next_tool( try: with timer("Chat actor: Infer information sources to refer", logger): - raw_response = await send_message_to_model_wrapper( + response = await send_message_to_model_wrapper( query="", system_message=function_planning_prompt, chat_history=chat_and_research_history, @@ -165,7 +165,7 @@ async def apick_next_tool( try: # Try parse the response as function call response to infer next tool to use. # TODO: Handle multiple tool calls. - response_text = raw_response.text + response_text = response.text parsed_response = [ToolCall(**item) for item in load_complex_json(response_text)][0] except Exception as e: # Otherwise assume the model has decided to end the research run and respond to the user. @@ -184,14 +184,11 @@ async def apick_next_tool( if (parsed_response.name, dict_to_tuple(parsed_response.args)) in previous_tool_query_combinations: warning = f"Repeated tool, query combination detected. Skipping iteration. Try something different." # Only send client status updates if we'll execute this iteration and model has thoughts to share. - elif send_status_func and not is_none_or_empty(raw_response.thought): - async for event in send_status_func(raw_response.thought): + elif send_status_func and not is_none_or_empty(response.thought): + async for event in send_status_func(response.thought): yield {ChatEvent.STATUS: event} - yield ResearchIteration( - query=parsed_response, - warning=warning, - ) + yield ResearchIteration(query=parsed_response, warning=warning, raw_response=response.raw_content) async def research(