Add function calling support to Anthropic, Gemini and OpenAI models

Previously these models could use response schemas but not the tool-use
capabilities provided by these AI model APIs.

This change allows chat actors to use the function calling feature to
specify which tools the LLMs from these providers can call.

This should help simplify tool definition and structure context in
forms that these LLMs natively understand.
(i.e. in a tool_call/tool_result format, similar to ChatML).
This commit is contained in:
Debanjum
2025-06-05 19:26:20 -07:00
parent 9607f2e87c
commit b888d5e65e
8 changed files with 230 additions and 129 deletions

View File

@@ -22,7 +22,15 @@ logger = logging.getLogger(__name__)
def anthropic_send_message_to_model(
messages, api_key, api_base_url, model, response_type="text", response_schema=None, deepthought=False, tracer={}
messages,
api_key,
api_base_url,
model,
response_type="text",
response_schema=None,
tools=None,
deepthought=False,
tracer={},
):
"""
Send message to model
@@ -36,6 +44,7 @@ def anthropic_send_message_to_model(
api_base_url=api_base_url,
response_type=response_type,
response_schema=response_schema,
tools=tools,
deepthought=deepthought,
tracer=tracer,
)

View File

@@ -1,9 +1,7 @@
import json
import logging
from copy import deepcopy
from textwrap import dedent
from time import perf_counter
from typing import AsyncGenerator, Dict, List, Optional, Type
from typing import AsyncGenerator, Dict, List
import anthropic
from langchain_core.messages.chat import ChatMessage
@@ -18,7 +16,9 @@ from tenacity import (
from khoj.processor.conversation.utils import (
ResponseWithThought,
ToolDefinition,
commit_conversation_trace,
create_tool_definition,
get_image_from_base64,
get_image_from_url,
)
@@ -57,6 +57,7 @@ def anthropic_completion_with_backoff(
max_tokens: int | None = None,
response_type: str = "text",
response_schema: BaseModel | None = None,
tools: List[ToolDefinition] = None,
deepthought: bool = False,
tracer: dict = {},
) -> str:
@@ -70,9 +71,19 @@ def anthropic_completion_with_backoff(
aggregated_response = ""
final_message = None
model_kwargs = model_kwargs or dict()
if response_schema:
tool = create_anthropic_tool_definition(response_schema=response_schema)
model_kwargs["tools"] = [tool]
# Configure structured output
if tools:
# Convert tools to Anthropic format
model_kwargs["tools"] = [
anthropic.types.ToolParam(name=tool.name, description=tool.description, input_schema=tool.schema)
for tool in tools
]
elif response_schema:
tool = create_tool_definition(response_schema)
model_kwargs["tools"] = [
anthropic.types.ToolParam(name=tool.name, description=tool.description, input_schema=tool.schema)
]
elif response_type == "json_object" and not (is_reasoning_model(model_name) and deepthought):
# Prefill model response with '{' to make it output a valid JSON object. Not supported with extended thinking.
formatted_messages.append(anthropic.types.MessageParam(role="assistant", content="{"))
@@ -103,7 +114,7 @@ def anthropic_completion_with_backoff(
# Extract first tool call from final message
for item in final_message.content:
if item.type == "tool_use":
aggregated_response = json.dumps(item.input)
aggregated_response = json.dumps([{"name": item.name, "args": item.input}])
break
# Calculate cost of chat
@@ -326,74 +337,5 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: st
return formatted_messages, system
def create_anthropic_tool_definition(
    response_schema: Type[BaseModel],
    tool_name: str = None,
    tool_description: Optional[str] = None,
) -> anthropic.types.ToolParam:
    """
    Build an Anthropic tool definition from a Pydantic response schema.

    Args:
        response_schema: The Pydantic BaseModel class defining the tool's input schema.
        tool_name: Name for the Anthropic tool (e.g., "get_weather", "plan_next_step").
            Defaults to the lowercased class name of the schema.
        tool_description: Optional description for the tool. Falls back to the
            model's docstring, then to a generated placeholder.

    Returns:
        A tool definition for Anthropic's API.
    """
    model_schema = response_schema.model_json_schema()
    name = tool_name or response_schema.__name__.lower()

    # Resolve the tool description: explicit argument > model docstring > fallback text.
    description = tool_description
    if description is None:
        docstring = response_schema.__doc__
        if docstring:
            description = dedent(docstring).strip()
        else:
            description = f"Tool named '{name}' accepts specified parameters."

    # Inline enum $refs so the resulting schema does not depend on a $defs section.
    defs = model_schema.get("$defs", {})
    processed_properties = {}
    for prop_name, prop_schema in model_schema.get("properties", {}).items():
        prop = deepcopy(prop_schema)  # work on a copy to avoid mutating the source schema
        ref_path = prop.get("$ref", "")
        if ref_path.startswith("#/$defs/"):
            def_name = ref_path.rsplit("/", 1)[-1]
            enum_def = defs.get(def_name)
            if enum_def is not None and "enum" in enum_def:
                prop["enum"] = enum_def["enum"]
                prop["type"] = enum_def.get("type", "string")
                if "description" in enum_def:
                    prop.setdefault("description", enum_def["description"])
                del prop["$ref"]  # the reference has been inlined
        processed_properties[prop_name] = prop

    # Anthropic expects the tool's input_schema to be a JSON Schema object.
    input_schema = {
        "type": "object",
        "properties": processed_properties,
    }
    # Carry over 'required' fields when the Pydantic model specifies them.
    required = model_schema.get("required")
    if required:
        input_schema["required"] = required
    return anthropic.types.ToolParam(name=name, description=description, input_schema=input_schema)
def is_reasoning_model(model_name: str) -> bool:
    """Check whether the given model name belongs to a known reasoning-model family."""
    # True when the name begins with any prefix listed in REASONING_MODELS.
    return any(map(model_name.startswith, REASONING_MODELS))

View File

@@ -28,6 +28,7 @@ def gemini_send_message_to_model(
api_base_url=None,
response_type="text",
response_schema=None,
tools=None,
model_kwargs=None,
deepthought=False,
tracer={},
@@ -37,8 +38,10 @@ def gemini_send_message_to_model(
"""
model_kwargs = {}
if tools:
model_kwargs["tools"] = tools
# Monitor for flakiness in 1.5+ models. This would cause unwanted behavior and terminate response early in 1.5 models.
if response_type == "json_object" and not model.startswith("gemini-1.5"):
elif response_type == "json_object" and not model.startswith("gemini-1.5"):
model_kwargs["response_mime_type"] = "application/json"
if response_schema:
model_kwargs["response_schema"] = response_schema

View File

@@ -1,9 +1,10 @@
import json
import logging
import os
import random
from copy import deepcopy
from time import perf_counter
from typing import AsyncGenerator, AsyncIterator, Dict
from typing import AsyncGenerator, AsyncIterator, Dict, List
import httpx
from google import genai
@@ -22,6 +23,7 @@ from tenacity import (
from khoj.processor.conversation.utils import (
ResponseWithThought,
ToolDefinition,
commit_conversation_trace,
get_image_from_base64,
get_image_from_url,
@@ -95,7 +97,7 @@ def gemini_completion_with_backoff(
temperature=1.2,
api_key=None,
api_base_url: str = None,
model_kwargs=None,
model_kwargs={},
deepthought=False,
tracer={},
) -> str:
@@ -107,9 +109,12 @@ def gemini_completion_with_backoff(
formatted_messages, system_instruction = format_messages_for_gemini(messages, system_prompt)
response_thoughts: str | None = None
# format model response schema
# Configure structured output
tools = None
response_schema = None
if model_kwargs and model_kwargs.get("response_schema"):
if model_kwargs.get("tools"):
tools = to_gemini_tools(model_kwargs["tools"])
elif model_kwargs.get("response_schema"):
response_schema = clean_response_schema(model_kwargs["response_schema"])
thinking_config = None
@@ -127,8 +132,9 @@ def gemini_completion_with_backoff(
thinking_config=thinking_config,
max_output_tokens=max_output_tokens,
safety_settings=SAFETY_SETTINGS,
response_mime_type=model_kwargs.get("response_mime_type", "text/plain") if model_kwargs else "text/plain",
response_mime_type=model_kwargs.get("response_mime_type", "text/plain"),
response_schema=response_schema,
tools=tools,
seed=seed,
top_p=0.95,
http_options=gtypes.HttpOptions(client_args={"timeout": httpx.Timeout(30.0, read=60.0)}),
@@ -137,7 +143,14 @@ def gemini_completion_with_backoff(
try:
# Generate the response
response = client.models.generate_content(model=model_name, config=config, contents=formatted_messages)
response_text = response.text
if response.function_calls:
function_calls = [
{"name": function_call.name, "args": function_call.args} for function_call in response.function_calls
]
response_text = json.dumps(function_calls)
else:
# If no function calls, use the text response
response_text = response.text
except gerrors.ClientError as e:
response = None
response_text, _ = handle_gemini_response(e.args)
@@ -404,3 +417,21 @@ def is_reasoning_model(model_name: str) -> bool:
Check if the model is a reasoning model.
"""
return model_name.startswith("gemini-2.5")
def to_gemini_tools(tools: List[ToolDefinition]) -> List[gtypes.ToolDict] | None:
    """Transform tool definitions from standard format to Gemini format.

    Args:
        tools: Normalized, provider-agnostic tool definitions to convert.

    Returns:
        A single-element list containing one ToolDict that wraps all function
        declarations, or None when no tools are given.
    """
    # Guard empty input explicitly: the previous `gemini_tools or None` check could
    # never trigger, because the outer list literal always held one (possibly empty)
    # ToolDict and was therefore always truthy.
    if not tools:
        return None
    return [
        gtypes.ToolDict(
            function_declarations=[
                gtypes.FunctionDeclarationDict(
                    name=tool.name,
                    description=tool.description,
                    parameters=tool.schema,
                )
                for tool in tools
            ]
        )
    ]

View File

@@ -1,21 +1,21 @@
import logging
from datetime import datetime
from typing import AsyncGenerator, Dict, List, Optional
from openai.lib._pydantic import _ensure_strict_json_schema
from pydantic import BaseModel
from typing import Any, AsyncGenerator, Dict, List, Optional
from khoj.database.models import Agent, ChatMessageModel, ChatModel
from khoj.processor.conversation import prompts
from khoj.processor.conversation.openai.utils import (
chat_completion_with_backoff,
clean_response_schema,
completion_with_backoff,
get_openai_api_json_support,
get_structured_output_support,
to_openai_tools,
)
from khoj.processor.conversation.utils import (
JsonSupport,
OperatorRun,
ResponseWithThought,
StructuredOutputSupport,
ToolDefinition,
generate_chatml_messages_with_context,
messages_to_print,
)
@@ -32,6 +32,7 @@ def send_message_to_model(
model,
response_type="text",
response_schema=None,
tools: list[ToolDefinition] = None,
deepthought=False,
api_base_url=None,
tracer: dict = {},
@@ -40,9 +41,11 @@ def send_message_to_model(
Send message to model
"""
model_kwargs = {}
json_support = get_openai_api_json_support(model, api_base_url)
if response_schema and json_support == JsonSupport.SCHEMA:
model_kwargs: Dict[str, Any] = {}
json_support = get_structured_output_support(model, api_base_url)
if tools and json_support == StructuredOutputSupport.TOOL:
model_kwargs["tools"] = to_openai_tools(tools)
elif response_schema and json_support >= StructuredOutputSupport.SCHEMA:
# Drop unsupported fields from schema passed to OpenAI APi
cleaned_response_schema = clean_response_schema(response_schema)
model_kwargs["response_format"] = {
@@ -53,7 +56,7 @@ def send_message_to_model(
"strict": True,
},
}
elif response_type == "json_object" and json_support == JsonSupport.OBJECT:
elif response_type == "json_object" and json_support == StructuredOutputSupport.OBJECT:
model_kwargs["response_format"] = {"type": response_type}
# Get Response from GPT
@@ -171,30 +174,3 @@ async def converse_openai(
tracer=tracer,
):
yield chunk
def clean_response_schema(schema: BaseModel | dict) -> dict:
    """
    Format response schema to be compatible with OpenAI API.

    Normalizes the schema to strict JSON-schema form, then recursively strips
    fields OpenAI structured outputs do not support.
    See https://platform.openai.com/docs/guides/structured-outputs#supported-schemas
    """
    # Accept either a Pydantic model class or an already-built JSON schema dict.
    schema_json = schema.model_json_schema() if not isinstance(schema, dict) else schema
    schema_json = _ensure_strict_json_schema(schema_json, path=(), root=schema_json)

    unsupported_fields = ("minItems", "maxItems")
    properties = schema_json.get("properties") if isinstance(schema_json, dict) else None
    if isinstance(properties, dict):
        for prop_value in properties.values():
            if not isinstance(prop_value, dict):
                continue
            # Strip unsupported fields from this property in place.
            for field in unsupported_fields:
                prop_value.pop(field, None)
            # Recurse into array item schemas to clean nested properties.
            items = prop_value.get("items")
            if isinstance(items, dict):
                clean_response_schema(items)
    return schema_json

View File

@@ -1,3 +1,4 @@
import json
import logging
import os
from copy import deepcopy
@@ -9,6 +10,7 @@ from urllib.parse import urlparse
import httpx
import openai
from langchain_core.messages.chat import ChatMessage
from openai.lib._pydantic import _ensure_strict_json_schema
from openai.lib.streaming.chat import (
ChatCompletionStream,
ChatCompletionStreamEvent,
@@ -20,6 +22,7 @@ from openai.types.chat.chat_completion_chunk import (
Choice,
ChoiceDelta,
)
from pydantic import BaseModel
from tenacity import (
before_sleep_log,
retry,
@@ -30,8 +33,9 @@ from tenacity import (
)
from khoj.processor.conversation.utils import (
JsonSupport,
ResponseWithThought,
StructuredOutputSupport,
ToolDefinition,
commit_conversation_trace,
)
from khoj.utils.helpers import (
@@ -131,6 +135,8 @@ def completion_with_backoff(
aggregated_response += chunk.delta
elif chunk.type == "thought.delta":
pass
elif chunk.type == "tool_calls.function.arguments.done":
aggregated_response = json.dumps([{"name": chunk.name, "args": chunk.arguments}])
else:
# Non-streaming chat completion
chunk = client.beta.chat.completions.parse(
@@ -190,6 +196,7 @@ async def chat_completion_with_backoff(
deepthought=False,
model_kwargs: dict = {},
tracer: dict = {},
tools=None,
) -> AsyncGenerator[ResponseWithThought, None]:
client_key = f"{openai_api_key}--{api_base_url}"
client = openai_async_clients.get(client_key)
@@ -258,6 +265,8 @@ async def chat_completion_with_backoff(
read_timeout = 300 if is_local_api(api_base_url) else 60
if os.getenv("KHOJ_LLM_SEED"):
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
if tools:
model_kwargs["tools"] = tools
aggregated_response = ""
final_chunk = None
@@ -327,16 +336,16 @@ async def chat_completion_with_backoff(
commit_conversation_trace(messages, aggregated_response, tracer)
def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport:
def get_structured_output_support(model_name: str, api_base_url: str = None) -> StructuredOutputSupport:
if model_name.startswith("deepseek-reasoner"):
return JsonSupport.NONE
return StructuredOutputSupport.NONE
if api_base_url:
host = urlparse(api_base_url).hostname
if host and host.endswith(".ai.azure.com"):
return JsonSupport.OBJECT
return StructuredOutputSupport.OBJECT
if host == "api.deepinfra.com":
return JsonSupport.OBJECT
return JsonSupport.SCHEMA
return StructuredOutputSupport.OBJECT
return StructuredOutputSupport.TOOL
def format_message_for_api(messages: List[ChatMessage], api_base_url: str) -> List[dict]:
@@ -708,3 +717,47 @@ def add_qwen_no_think_tag(formatted_messages: List[dict]) -> None:
if isinstance(content_part, dict) and content_part.get("type") == "text":
content_part["text"] += " /no_think"
break
def to_openai_tools(tools: List[ToolDefinition]) -> List[Dict] | None:
    """Transform tool definitions from standard format to OpenAI format."""
    openai_tools = []
    for tool in tools:
        # Each OpenAI tool is a "function" spec carrying a cleaned JSON schema.
        openai_tools.append(
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": clean_response_schema(tool.schema),
                },
            }
        )
    # An empty input yields None so callers can skip setting the tools kwarg.
    return openai_tools or None
def clean_response_schema(schema: BaseModel | dict) -> dict:
    """
    Format response schema to be compatible with OpenAI API.

    Normalizes the schema to strict JSON-schema form, then recursively strips
    fields OpenAI structured outputs do not support.
    See https://platform.openai.com/docs/guides/structured-outputs#supported-schemas
    """
    # Accept either a Pydantic model class or an already-built JSON schema dict.
    schema_json = schema.model_json_schema() if not isinstance(schema, dict) else schema
    schema_json = _ensure_strict_json_schema(schema_json, path=(), root=schema_json)

    unsupported_fields = ("minItems", "maxItems")
    properties = schema_json.get("properties") if isinstance(schema_json, dict) else None
    if isinstance(properties, dict):
        for prop_value in properties.values():
            if not isinstance(prop_value, dict):
                continue
            # Strip unsupported fields from this property in place.
            for field in unsupported_fields:
                prop_value.pop(field, None)
            # Recurse into array item schemas to clean nested properties.
            items = prop_value.get("items")
            if isinstance(items, dict):
                clean_response_schema(items)
    return schema_json

View File

@@ -6,11 +6,13 @@ import mimetypes
import os
import re
import uuid
from copy import deepcopy
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from io import BytesIO
from typing import Any, Callable, Dict, List, Literal, Optional, Union
from textwrap import dedent
from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union
import PIL.Image
import pyjson5
@@ -1149,13 +1151,93 @@ def messages_to_print(messages: list[ChatMessage], max_length: int = 70) -> str:
return "\n".join([f"{json.dumps(safe_serialize(message.content))[:max_length]}..." for message in messages])
class JsonSupport(int, Enum):
class StructuredOutputSupport(int, Enum):
    """Levels of structured-output support by an AI model API, ordered from none to full tool calling."""

    NONE = 0  # No structured output support
    OBJECT = 1  # Can emit a valid JSON object
    SCHEMA = 2  # Can follow a caller-provided JSON response schema
    TOOL = 3  # Supports native tool / function calling
class ResponseWithThought:
    """Container pairing a model's final response text with its optional reasoning trace."""

    def __init__(self, response: Optional[str] = None, thought: Optional[str] = None):
        # response: the model's user-facing answer text
        self.response = response
        # thought: the model's intermediate reasoning, when the API exposes it
        self.thought = thought
class ToolDefinition:
    """AI-provider-agnostic description of a tool an LLM may call."""

    def __init__(self, name: str, description: str, schema: dict):
        # name: identifier of the tool (e.g. "get_weather")
        self.name = name
        # description: natural-language explanation of what the tool does
        self.description = description
        # schema: JSON schema describing the tool's input parameters
        self.schema = schema
def create_tool_definition(
    schema: Type[BaseModel],
    name: str = None,
    description: Optional[str] = None,
) -> ToolDefinition:
    """
    Converts a response schema BaseModel class into a normalized tool definition.

    A standard AI provider agnostic tool format to specify tools the model can use.
    Common logic used across models is kept here. AI provider specific adaptations
    should be handled in provider code.

    Args:
        schema: The Pydantic BaseModel class to convert.
            This class defines the input schema for the tool.
        name: The name for the AI model tool (e.g., "get_weather", "plan_next_step").
            Defaults to the lowercased class name of the schema.
        description: Optional description for the AI model tool.
            If None, it attempts to use the Pydantic model's docstring.
            If that's also missing, a fallback description is generated.

    Returns:
        A normalized tool definition for AI model APIs.
    """
    raw_schema_dict = schema.model_json_schema()
    name = name or schema.__name__.lower()
    # Resolve description: explicit argument > model docstring > generated fallback.
    if description is None:
        docstring = schema.__doc__
        if docstring:
            description = dedent(docstring).strip()
        else:
            # Fallback description if no explicit one or docstring is provided
            description = f"Tool named '{name}' accepts specified parameters."

    # Process properties to inline enums and remove the $defs dependency, since
    # some provider APIs do not resolve $ref/$defs inside tool schemas.
    processed_properties = {}
    defs = raw_schema_dict.get("$defs", {})
    for prop_name, prop_schema in raw_schema_dict.get("properties", {}).items():
        current_prop_schema = deepcopy(prop_schema)  # work on a copy
        if "$ref" in current_prop_schema:
            ref_path = current_prop_schema["$ref"]
            if ref_path.startswith("#/$defs/"):
                def_name = ref_path.split("/")[-1]
                if def_name in defs and "enum" in defs[def_name]:
                    # Inline the referenced enum definition into the property.
                    enum_def = defs[def_name]
                    current_prop_schema["enum"] = enum_def["enum"]
                    current_prop_schema["type"] = enum_def.get("type", "string")
                    if "description" not in current_prop_schema and "description" in enum_def:
                        current_prop_schema["description"] = enum_def["description"]
                    del current_prop_schema["$ref"]  # remove the $ref as it's been inlined
        processed_properties[prop_name] = current_prop_schema

    # Generate the compiled schema dictionary for the tool definition.
    compiled_schema = {
        "type": "object",
        "properties": processed_properties,
        # Generate content in the order in which the schema properties were defined
        "property_ordering": list(schema.model_fields.keys()),
    }
    # Include 'required' fields if specified in the Pydantic model
    if raw_schema_dict.get("required"):
        compiled_schema["required"] = raw_schema_dict["required"]
    return ToolDefinition(name=name, description=description, schema=compiled_schema)

View File

@@ -101,6 +101,7 @@ from khoj.processor.conversation.utils import (
OperatorRun,
ResearchIteration,
ResponseWithThought,
ToolDefinition,
clean_json,
clean_mermaidjs,
construct_chat_history,
@@ -1439,6 +1440,7 @@ async def send_message_to_model_wrapper(
system_message: str = "",
response_type: str = "text",
response_schema: BaseModel = None,
tools: List[ToolDefinition] = None,
deepthought: bool = False,
user: KhojUser = None,
query_images: List[str] = None,
@@ -1506,6 +1508,7 @@ async def send_message_to_model_wrapper(
model=chat_model_name,
response_type=response_type,
response_schema=response_schema,
tools=tools,
deepthought=deepthought,
api_base_url=api_base_url,
tracer=tracer,
@@ -1517,6 +1520,7 @@ async def send_message_to_model_wrapper(
model=chat_model_name,
response_type=response_type,
response_schema=response_schema,
tools=tools,
deepthought=deepthought,
api_base_url=api_base_url,
tracer=tracer,
@@ -1528,6 +1532,7 @@ async def send_message_to_model_wrapper(
model=chat_model_name,
response_type=response_type,
response_schema=response_schema,
tools=tools,
deepthought=deepthought,
api_base_url=api_base_url,
tracer=tracer,