mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Use tool calls to enforce response schema for Anthropic models
- Converts the response schema into an Anthropic tool call definition. - Works with simple enums without needing to rely on $defs/$refs, which are unsupported by the Anthropic API. - Does not force specific tool use, as that is not supported with deep thought. This puts Anthropic models on parity with OpenAI and Gemini models for response schema following, and reduces the need for complex JSON response parsing on Khoj's end.
This commit is contained in:
@@ -117,7 +117,7 @@ def extract_questions_anthropic(
|
|||||||
|
|
||||||
|
|
||||||
def anthropic_send_message_to_model(
|
def anthropic_send_message_to_model(
|
||||||
messages, api_key, api_base_url, model, response_type="text", deepthought=False, tracer={}
|
messages, api_key, api_base_url, model, response_type="text", response_schema=None, deepthought=False, tracer={}
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Send message to model
|
Send message to model
|
||||||
@@ -130,6 +130,7 @@ def anthropic_send_message_to_model(
|
|||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
api_base_url=api_base_url,
|
api_base_url=api_base_url,
|
||||||
response_type=response_type,
|
response_type=response_type,
|
||||||
|
response_schema=response_schema,
|
||||||
deepthought=deepthought,
|
deepthought=deepthought,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,9 +1,13 @@
|
|||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
from copy import deepcopy
|
||||||
|
from textwrap import dedent
|
||||||
from time import perf_counter
|
from time import perf_counter
|
||||||
from typing import AsyncGenerator, Dict, List
|
from typing import AsyncGenerator, Dict, List, Optional, Type
|
||||||
|
|
||||||
import anthropic
|
import anthropic
|
||||||
from langchain_core.messages.chat import ChatMessage
|
from langchain_core.messages.chat import ChatMessage
|
||||||
|
from pydantic import BaseModel
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
before_sleep_log,
|
before_sleep_log,
|
||||||
retry,
|
retry,
|
||||||
@@ -52,6 +56,7 @@ def anthropic_completion_with_backoff(
|
|||||||
model_kwargs: dict | None = None,
|
model_kwargs: dict | None = None,
|
||||||
max_tokens: int | None = None,
|
max_tokens: int | None = None,
|
||||||
response_type: str = "text",
|
response_type: str = "text",
|
||||||
|
response_schema: BaseModel | None = None,
|
||||||
deepthought: bool = False,
|
deepthought: bool = False,
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
) -> str:
|
) -> str:
|
||||||
@@ -63,13 +68,16 @@ def anthropic_completion_with_backoff(
|
|||||||
formatted_messages, system = format_messages_for_anthropic(messages, system_prompt)
|
formatted_messages, system = format_messages_for_anthropic(messages, system_prompt)
|
||||||
|
|
||||||
aggregated_response = ""
|
aggregated_response = ""
|
||||||
if response_type == "json_object" and not deepthought:
|
final_message = None
|
||||||
|
model_kwargs = model_kwargs or dict()
|
||||||
|
if response_schema:
|
||||||
|
tool = create_anthropic_tool_definition(response_schema=response_schema)
|
||||||
|
model_kwargs["tools"] = [tool]
|
||||||
|
elif response_type == "json_object" and not deepthought:
|
||||||
# Prefill model response with '{' to make it output a valid JSON object. Not supported with extended thinking.
|
# Prefill model response with '{' to make it output a valid JSON object. Not supported with extended thinking.
|
||||||
formatted_messages.append(anthropic.types.MessageParam(role="assistant", content="{"))
|
formatted_messages.append(anthropic.types.MessageParam(role="assistant", content="{"))
|
||||||
aggregated_response += "{"
|
aggregated_response += "{"
|
||||||
|
|
||||||
final_message = None
|
|
||||||
model_kwargs = model_kwargs or dict()
|
|
||||||
if system:
|
if system:
|
||||||
model_kwargs["system"] = system
|
model_kwargs["system"] = system
|
||||||
|
|
||||||
@@ -92,6 +100,12 @@ def anthropic_completion_with_backoff(
|
|||||||
aggregated_response += text
|
aggregated_response += text
|
||||||
final_message = stream.get_final_message()
|
final_message = stream.get_final_message()
|
||||||
|
|
||||||
|
# Extract first tool call from final message
|
||||||
|
for item in final_message.content:
|
||||||
|
if item.type == "tool_use":
|
||||||
|
aggregated_response = json.dumps(item.input)
|
||||||
|
break
|
||||||
|
|
||||||
# Calculate cost of chat
|
# Calculate cost of chat
|
||||||
input_tokens = final_message.usage.input_tokens
|
input_tokens = final_message.usage.input_tokens
|
||||||
output_tokens = final_message.usage.output_tokens
|
output_tokens = final_message.usage.output_tokens
|
||||||
@@ -305,5 +319,74 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: st
|
|||||||
return formatted_messages, system
|
return formatted_messages, system
|
||||||
|
|
||||||
|
|
||||||
|
def create_anthropic_tool_definition(
    response_schema: Type[BaseModel],
    tool_name: Optional[str] = None,
    tool_description: Optional[str] = None,
) -> anthropic.types.ToolParam:
    """
    Convert a response schema BaseModel class into an Anthropic tool definition.

    This format is expected by Anthropic's API when defining tools the model can use.
    Enum fields referenced via ``$ref``/``$defs`` in the Pydantic JSON schema are
    inlined into the property itself, since the Anthropic API does not support
    ``$defs``-based references in tool input schemas.

    Args:
        response_schema: The Pydantic BaseModel class to convert.
            This class defines the response schema for the tool.
        tool_name: The name for the Anthropic tool (e.g., "get_weather", "plan_next_step").
            Defaults to the lowercased class name of the schema.
        tool_description: Optional description for the Anthropic tool.
            If None, it attempts to use the Pydantic model's docstring.
            If that's also missing, a fallback description is generated.

    Returns:
        A tool definition for Anthropic's API.
    """
    model_schema = response_schema.model_json_schema()

    name = tool_name or response_schema.__name__.lower()
    description = tool_description
    if description is None:
        docstring = response_schema.__doc__
        if docstring:
            description = dedent(docstring).strip()
        else:
            # Fallback description if no explicit one or docstring is provided
            description = f"Tool named '{name}' accepts specified parameters."

    # Process properties to inline enums and remove the $defs dependency
    processed_properties = {}
    original_properties = model_schema.get("properties", {})
    defs = model_schema.get("$defs", {})

    for prop_name, prop_schema in original_properties.items():
        current_prop_schema = deepcopy(prop_schema)  # Work on a copy
        # Inline simple enum definitions referenced via $ref so the emitted
        # schema is self-contained (Anthropic does not resolve $defs).
        if "$ref" in current_prop_schema:
            ref_path = current_prop_schema["$ref"]
            if ref_path.startswith("#/$defs/"):
                def_name = ref_path.split("/")[-1]
                if def_name in defs and "enum" in defs[def_name]:
                    enum_def = defs[def_name]
                    current_prop_schema["enum"] = enum_def["enum"]
                    current_prop_schema["type"] = enum_def.get("type", "string")
                    if "description" not in current_prop_schema and "description" in enum_def:
                        current_prop_schema["description"] = enum_def["description"]
                    del current_prop_schema["$ref"]  # Remove the $ref as it's been inlined
                # NOTE(review): a $ref to a non-enum $def is left as-is here, and the
                # emitted schema carries no $defs section — confirm all nested models
                # used with this helper are simple enums.
        processed_properties[prop_name] = current_prop_schema

    # The input_schema for Anthropic tools is a JSON Schema object.
    # Pydantic's model_json_schema() provides most of what's needed.
    input_schema = {
        "type": "object",
        "properties": processed_properties,
    }

    # Include 'required' fields if specified in the Pydantic model
    if "required" in model_schema and model_schema["required"]:
        input_schema["required"] = model_schema["required"]

    return anthropic.types.ToolParam(name=name, description=description, input_schema=input_schema)
|
||||||
|
|
||||||
|
|
||||||
def is_reasoning_model(model_name: str) -> bool:
    """Check whether the given model name matches a known reasoning model prefix."""
    for known_prefix in REASONING_MODELS:
        if model_name.startswith(known_prefix):
            return True
    return False
|
||||||
|
|||||||
@@ -1232,6 +1232,7 @@ async def send_message_to_model_wrapper(
|
|||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
model=chat_model_name,
|
model=chat_model_name,
|
||||||
response_type=response_type,
|
response_type=response_type,
|
||||||
|
response_schema=response_schema,
|
||||||
deepthought=deepthought,
|
deepthought=deepthought,
|
||||||
api_base_url=api_base_url,
|
api_base_url=api_base_url,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
|
|||||||
@@ -71,6 +71,11 @@ class PlanningResponse(BaseModel):
|
|||||||
|
|
||||||
# Create and return a customized response model with the enum
|
# Create and return a customized response model with the enum
|
||||||
class PlanningResponseWithTool(PlanningResponse):
|
class PlanningResponseWithTool(PlanningResponse):
|
||||||
|
"""
|
||||||
|
Use the scratchpad to reason about which tool to use next and the query to send to the tool.
|
||||||
|
Pick tool from provided options and your query to send to the tool.
|
||||||
|
"""
|
||||||
|
|
||||||
tool: tool_enum = Field(..., description="Name of the tool to use")
|
tool: tool_enum = Field(..., description="Name of the tool to use")
|
||||||
query: str = Field(..., description="Detailed query for the selected tool")
|
query: str = Field(..., description="Detailed query for the selected tool")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user