From 65d9ad6cb25d285c0807ec2aaf381ac667f97ed4 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Tue, 3 Jun 2025 17:36:00 -0700 Subject: [PATCH] Use tool calls to enforce response schema for anthropic models - Converts response schema into a anthropic tool call definition. - Works with simple enums without needing to rely on $defs, $refs as unsupported by Anthropic API - Do not force specific tool use as not supported with deep thought This puts anthropic models on parity with openai, gemini models for response schema following. Reduces need for complex json response parsing on khoj end. --- .../conversation/anthropic/anthropic_chat.py | 3 +- .../processor/conversation/anthropic/utils.py | 91 ++++++++++++++++++- src/khoj/routers/helpers.py | 1 + src/khoj/routers/research.py | 5 + 4 files changed, 95 insertions(+), 5 deletions(-) diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index 655a5baa..3d34573e 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -117,7 +117,7 @@ def extract_questions_anthropic( def anthropic_send_message_to_model( - messages, api_key, api_base_url, model, response_type="text", deepthought=False, tracer={} + messages, api_key, api_base_url, model, response_type="text", response_schema=None, deepthought=False, tracer={} ): """ Send message to model @@ -130,6 +130,7 @@ def anthropic_send_message_to_model( api_key=api_key, api_base_url=api_base_url, response_type=response_type, + response_schema=response_schema, deepthought=deepthought, tracer=tracer, ) diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py index 02658edc..ed525eeb 100644 --- a/src/khoj/processor/conversation/anthropic/utils.py +++ b/src/khoj/processor/conversation/anthropic/utils.py @@ -1,9 +1,13 @@ +import json import logging +from copy import deepcopy +from textwrap import dedent from time import perf_counter -from typing import AsyncGenerator, Dict, List +from typing import AsyncGenerator, Dict, List, Optional, Type import anthropic from langchain_core.messages.chat import ChatMessage +from pydantic import BaseModel from tenacity import ( before_sleep_log, retry, @@ -52,6 +56,7 @@ def anthropic_completion_with_backoff( model_kwargs: dict | None = None, max_tokens: int | None = None, response_type: str = "text", + response_schema: BaseModel | None = None, deepthought: bool = False, tracer: dict = {}, ) -> str: @@ -63,13 +68,16 @@ def anthropic_completion_with_backoff( formatted_messages, system = format_messages_for_anthropic(messages, system_prompt) aggregated_response = "" - if response_type == "json_object" and not deepthought: + final_message = None + model_kwargs = model_kwargs or dict() + if response_schema: + tool = create_anthropic_tool_definition(response_schema=response_schema) + model_kwargs["tools"] = [tool] + elif response_type == "json_object" and not deepthought: # Prefill model response with '{' to make it output a valid JSON object. Not supported with extended thinking. formatted_messages.append(anthropic.types.MessageParam(role="assistant", content="{")) aggregated_response += "{" - final_message = None - model_kwargs = model_kwargs or dict() if system: model_kwargs["system"] = system @@ -92,6 +100,12 @@ def anthropic_completion_with_backoff( aggregated_response += text final_message = stream.get_final_message() + # Extract first tool call from final message + for item in final_message.content: + if item.type == "tool_use": + aggregated_response = json.dumps(item.input) + break + # Calculate cost of chat input_tokens = final_message.usage.input_tokens output_tokens = final_message.usage.output_tokens @@ -305,5 +319,74 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: st return formatted_messages, system +def create_anthropic_tool_definition( + response_schema: Type[BaseModel], + tool_name: str = None, + tool_description: Optional[str] = None, +) -> anthropic.types.ToolParam: + """ + Converts a response schema BaseModel class into an Anthropic tool definition dictionary. + + This format is expected by Anthropic's API when defining tools the model can use. + + Args: + response_schema: The Pydantic BaseModel class to convert. + This class defines the response schema for the tool. + tool_name: The name for the Anthropic tool (e.g., "get_weather", "plan_next_step"). + tool_description: Optional description for the Anthropic tool. + If None, it attempts to use the Pydantic model's docstring. + If that's also missing, a fallback description is generated. + + Returns: + An tool definition for Anthropic's API. + """ + model_schema = response_schema.model_json_schema() + + name = tool_name or response_schema.__name__.lower() + description = tool_description + if description is None: + docstring = response_schema.__doc__ + if docstring: + description = dedent(docstring).strip() + else: + # Fallback description if no explicit one or docstring is provided + description = f"Tool named '{name}' accepts specified parameters." + + # Process properties to inline enums and remove $defs dependency + processed_properties = {} + original_properties = model_schema.get("properties", {}) + defs = model_schema.get("$defs", {}) + + for prop_name, prop_schema in original_properties.items(): + current_prop_schema = deepcopy(prop_schema) # Work on a copy + # Check for enums defined directly in the property for simpler direct enum definitions. + if "$ref" in current_prop_schema: + ref_path = current_prop_schema["$ref"] + if ref_path.startswith("#/$defs/"): + def_name = ref_path.split("/")[-1] + if def_name in defs and "enum" in defs[def_name]: + enum_def = defs[def_name] + current_prop_schema["enum"] = enum_def["enum"] + current_prop_schema["type"] = enum_def.get("type", "string") + if "description" not in current_prop_schema and "description" in enum_def: + current_prop_schema["description"] = enum_def["description"] + del current_prop_schema["$ref"] # Remove the $ref as it's been inlined + + processed_properties[prop_name] = current_prop_schema + + # The input_schema for Anthropic tools is a JSON Schema object. + # Pydantic's model_json_schema() provides most of what's needed. + input_schema = { + "type": "object", + "properties": processed_properties, + } + + # Include 'required' fields if specified in the Pydantic model + if "required" in model_schema and model_schema["required"]: + input_schema["required"] = model_schema["required"] + + return anthropic.types.ToolParam(name=name, description=description, input_schema=input_schema) + + def is_reasoning_model(model_name: str) -> bool: return any(model_name.startswith(model) for model in REASONING_MODELS) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 6d0d0064..392f5025 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1232,6 +1232,7 @@ async def send_message_to_model_wrapper( api_key=api_key, model=chat_model_name, response_type=response_type, + response_schema=response_schema, deepthought=deepthought, api_base_url=api_base_url, tracer=tracer, diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 83ec141a..494a01b6 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -71,6 +71,11 @@ class PlanningResponse(BaseModel): # Create and return a customized response model with the enum class PlanningResponseWithTool(PlanningResponse): + """ + Use the scratchpad to reason about which tool to use next and the query to send to the tool. + Pick tool from provided options and your query to send to the tool. + """ + tool: tool_enum = Field(..., description="Name of the tool to use") query: str = Field(..., description="Detailed query for the selected tool")