mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Add function calling support to Anthropic, Gemini and OpenAI models
Previously these models could use response schema but not tools use capabilities provided by these AI model APIs. This change allows chat actors to use the function calling feature to specify which tools the LLM by these providers can call. This should help simplify tool definition and structure context in forms that these LLMs natively understand. (i.e in tool_call - tool_result ~chatml format).
This commit is contained in:
@@ -22,7 +22,15 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def anthropic_send_message_to_model(
|
||||
messages, api_key, api_base_url, model, response_type="text", response_schema=None, deepthought=False, tracer={}
|
||||
messages,
|
||||
api_key,
|
||||
api_base_url,
|
||||
model,
|
||||
response_type="text",
|
||||
response_schema=None,
|
||||
tools=None,
|
||||
deepthought=False,
|
||||
tracer={},
|
||||
):
|
||||
"""
|
||||
Send message to model
|
||||
@@ -36,6 +44,7 @@ def anthropic_send_message_to_model(
|
||||
api_base_url=api_base_url,
|
||||
response_type=response_type,
|
||||
response_schema=response_schema,
|
||||
tools=tools,
|
||||
deepthought=deepthought,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from textwrap import dedent
|
||||
from time import perf_counter
|
||||
from typing import AsyncGenerator, Dict, List, Optional, Type
|
||||
from typing import AsyncGenerator, Dict, List
|
||||
|
||||
import anthropic
|
||||
from langchain_core.messages.chat import ChatMessage
|
||||
@@ -18,7 +16,9 @@ from tenacity import (
|
||||
|
||||
from khoj.processor.conversation.utils import (
|
||||
ResponseWithThought,
|
||||
ToolDefinition,
|
||||
commit_conversation_trace,
|
||||
create_tool_definition,
|
||||
get_image_from_base64,
|
||||
get_image_from_url,
|
||||
)
|
||||
@@ -57,6 +57,7 @@ def anthropic_completion_with_backoff(
|
||||
max_tokens: int | None = None,
|
||||
response_type: str = "text",
|
||||
response_schema: BaseModel | None = None,
|
||||
tools: List[ToolDefinition] = None,
|
||||
deepthought: bool = False,
|
||||
tracer: dict = {},
|
||||
) -> str:
|
||||
@@ -70,9 +71,19 @@ def anthropic_completion_with_backoff(
|
||||
aggregated_response = ""
|
||||
final_message = None
|
||||
model_kwargs = model_kwargs or dict()
|
||||
if response_schema:
|
||||
tool = create_anthropic_tool_definition(response_schema=response_schema)
|
||||
model_kwargs["tools"] = [tool]
|
||||
|
||||
# Configure structured output
|
||||
if tools:
|
||||
# Convert tools to Anthropic format
|
||||
model_kwargs["tools"] = [
|
||||
anthropic.types.ToolParam(name=tool.name, description=tool.description, input_schema=tool.schema)
|
||||
for tool in tools
|
||||
]
|
||||
elif response_schema:
|
||||
tool = create_tool_definition(response_schema)
|
||||
model_kwargs["tools"] = [
|
||||
anthropic.types.ToolParam(name=tool.name, description=tool.description, input_schema=tool.schema)
|
||||
]
|
||||
elif response_type == "json_object" and not (is_reasoning_model(model_name) and deepthought):
|
||||
# Prefill model response with '{' to make it output a valid JSON object. Not supported with extended thinking.
|
||||
formatted_messages.append(anthropic.types.MessageParam(role="assistant", content="{"))
|
||||
@@ -103,7 +114,7 @@ def anthropic_completion_with_backoff(
|
||||
# Extract first tool call from final message
|
||||
for item in final_message.content:
|
||||
if item.type == "tool_use":
|
||||
aggregated_response = json.dumps(item.input)
|
||||
aggregated_response = json.dumps([{"name": item.name, "args": item.input}])
|
||||
break
|
||||
|
||||
# Calculate cost of chat
|
||||
@@ -326,74 +337,5 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: st
|
||||
return formatted_messages, system
|
||||
|
||||
|
||||
def create_anthropic_tool_definition(
|
||||
response_schema: Type[BaseModel],
|
||||
tool_name: str = None,
|
||||
tool_description: Optional[str] = None,
|
||||
) -> anthropic.types.ToolParam:
|
||||
"""
|
||||
Converts a response schema BaseModel class into an Anthropic tool definition dictionary.
|
||||
|
||||
This format is expected by Anthropic's API when defining tools the model can use.
|
||||
|
||||
Args:
|
||||
response_schema: The Pydantic BaseModel class to convert.
|
||||
This class defines the response schema for the tool.
|
||||
tool_name: The name for the Anthropic tool (e.g., "get_weather", "plan_next_step").
|
||||
tool_description: Optional description for the Anthropic tool.
|
||||
If None, it attempts to use the Pydantic model's docstring.
|
||||
If that's also missing, a fallback description is generated.
|
||||
|
||||
Returns:
|
||||
An tool definition for Anthropic's API.
|
||||
"""
|
||||
model_schema = response_schema.model_json_schema()
|
||||
|
||||
name = tool_name or response_schema.__name__.lower()
|
||||
description = tool_description
|
||||
if description is None:
|
||||
docstring = response_schema.__doc__
|
||||
if docstring:
|
||||
description = dedent(docstring).strip()
|
||||
else:
|
||||
# Fallback description if no explicit one or docstring is provided
|
||||
description = f"Tool named '{name}' accepts specified parameters."
|
||||
|
||||
# Process properties to inline enums and remove $defs dependency
|
||||
processed_properties = {}
|
||||
original_properties = model_schema.get("properties", {})
|
||||
defs = model_schema.get("$defs", {})
|
||||
|
||||
for prop_name, prop_schema in original_properties.items():
|
||||
current_prop_schema = deepcopy(prop_schema) # Work on a copy
|
||||
# Check for enums defined directly in the property for simpler direct enum definitions.
|
||||
if "$ref" in current_prop_schema:
|
||||
ref_path = current_prop_schema["$ref"]
|
||||
if ref_path.startswith("#/$defs/"):
|
||||
def_name = ref_path.split("/")[-1]
|
||||
if def_name in defs and "enum" in defs[def_name]:
|
||||
enum_def = defs[def_name]
|
||||
current_prop_schema["enum"] = enum_def["enum"]
|
||||
current_prop_schema["type"] = enum_def.get("type", "string")
|
||||
if "description" not in current_prop_schema and "description" in enum_def:
|
||||
current_prop_schema["description"] = enum_def["description"]
|
||||
del current_prop_schema["$ref"] # Remove the $ref as it's been inlined
|
||||
|
||||
processed_properties[prop_name] = current_prop_schema
|
||||
|
||||
# The input_schema for Anthropic tools is a JSON Schema object.
|
||||
# Pydantic's model_json_schema() provides most of what's needed.
|
||||
input_schema = {
|
||||
"type": "object",
|
||||
"properties": processed_properties,
|
||||
}
|
||||
|
||||
# Include 'required' fields if specified in the Pydantic model
|
||||
if "required" in model_schema and model_schema["required"]:
|
||||
input_schema["required"] = model_schema["required"]
|
||||
|
||||
return anthropic.types.ToolParam(name=name, description=description, input_schema=input_schema)
|
||||
|
||||
|
||||
def is_reasoning_model(model_name: str) -> bool:
|
||||
return any(model_name.startswith(model) for model in REASONING_MODELS)
|
||||
|
||||
@@ -28,6 +28,7 @@ def gemini_send_message_to_model(
|
||||
api_base_url=None,
|
||||
response_type="text",
|
||||
response_schema=None,
|
||||
tools=None,
|
||||
model_kwargs=None,
|
||||
deepthought=False,
|
||||
tracer={},
|
||||
@@ -37,8 +38,10 @@ def gemini_send_message_to_model(
|
||||
"""
|
||||
model_kwargs = {}
|
||||
|
||||
if tools:
|
||||
model_kwargs["tools"] = tools
|
||||
# Monitor for flakiness in 1.5+ models. This would cause unwanted behavior and terminate response early in 1.5 models.
|
||||
if response_type == "json_object" and not model.startswith("gemini-1.5"):
|
||||
elif response_type == "json_object" and not model.startswith("gemini-1.5"):
|
||||
model_kwargs["response_mime_type"] = "application/json"
|
||||
if response_schema:
|
||||
model_kwargs["response_schema"] = response_schema
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
from copy import deepcopy
|
||||
from time import perf_counter
|
||||
from typing import AsyncGenerator, AsyncIterator, Dict
|
||||
from typing import AsyncGenerator, AsyncIterator, Dict, List
|
||||
|
||||
import httpx
|
||||
from google import genai
|
||||
@@ -22,6 +23,7 @@ from tenacity import (
|
||||
|
||||
from khoj.processor.conversation.utils import (
|
||||
ResponseWithThought,
|
||||
ToolDefinition,
|
||||
commit_conversation_trace,
|
||||
get_image_from_base64,
|
||||
get_image_from_url,
|
||||
@@ -95,7 +97,7 @@ def gemini_completion_with_backoff(
|
||||
temperature=1.2,
|
||||
api_key=None,
|
||||
api_base_url: str = None,
|
||||
model_kwargs=None,
|
||||
model_kwargs={},
|
||||
deepthought=False,
|
||||
tracer={},
|
||||
) -> str:
|
||||
@@ -107,9 +109,12 @@ def gemini_completion_with_backoff(
|
||||
formatted_messages, system_instruction = format_messages_for_gemini(messages, system_prompt)
|
||||
response_thoughts: str | None = None
|
||||
|
||||
# format model response schema
|
||||
# Configure structured output
|
||||
tools = None
|
||||
response_schema = None
|
||||
if model_kwargs and model_kwargs.get("response_schema"):
|
||||
if model_kwargs.get("tools"):
|
||||
tools = to_gemini_tools(model_kwargs["tools"])
|
||||
elif model_kwargs.get("response_schema"):
|
||||
response_schema = clean_response_schema(model_kwargs["response_schema"])
|
||||
|
||||
thinking_config = None
|
||||
@@ -127,8 +132,9 @@ def gemini_completion_with_backoff(
|
||||
thinking_config=thinking_config,
|
||||
max_output_tokens=max_output_tokens,
|
||||
safety_settings=SAFETY_SETTINGS,
|
||||
response_mime_type=model_kwargs.get("response_mime_type", "text/plain") if model_kwargs else "text/plain",
|
||||
response_mime_type=model_kwargs.get("response_mime_type", "text/plain"),
|
||||
response_schema=response_schema,
|
||||
tools=tools,
|
||||
seed=seed,
|
||||
top_p=0.95,
|
||||
http_options=gtypes.HttpOptions(client_args={"timeout": httpx.Timeout(30.0, read=60.0)}),
|
||||
@@ -137,7 +143,14 @@ def gemini_completion_with_backoff(
|
||||
try:
|
||||
# Generate the response
|
||||
response = client.models.generate_content(model=model_name, config=config, contents=formatted_messages)
|
||||
response_text = response.text
|
||||
if response.function_calls:
|
||||
function_calls = [
|
||||
{"name": function_call.name, "args": function_call.args} for function_call in response.function_calls
|
||||
]
|
||||
response_text = json.dumps(function_calls)
|
||||
else:
|
||||
# If no function calls, use the text response
|
||||
response_text = response.text
|
||||
except gerrors.ClientError as e:
|
||||
response = None
|
||||
response_text, _ = handle_gemini_response(e.args)
|
||||
@@ -404,3 +417,21 @@ def is_reasoning_model(model_name: str) -> bool:
|
||||
Check if the model is a reasoning model.
|
||||
"""
|
||||
return model_name.startswith("gemini-2.5")
|
||||
|
||||
|
||||
def to_gemini_tools(tools: List[ToolDefinition]) -> List[gtypes.ToolDict] | None:
|
||||
"Transform tool definitions from standard format to Gemini format."
|
||||
gemini_tools = [
|
||||
gtypes.ToolDict(
|
||||
function_declarations=[
|
||||
gtypes.FunctionDeclarationDict(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
parameters=tool.schema,
|
||||
)
|
||||
for tool in tools
|
||||
]
|
||||
)
|
||||
]
|
||||
|
||||
return gemini_tools or None
|
||||
|
||||
@@ -1,21 +1,21 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from openai.lib._pydantic import _ensure_strict_json_schema
|
||||
from pydantic import BaseModel
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from khoj.database.models import Agent, ChatMessageModel, ChatModel
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.openai.utils import (
|
||||
chat_completion_with_backoff,
|
||||
clean_response_schema,
|
||||
completion_with_backoff,
|
||||
get_openai_api_json_support,
|
||||
get_structured_output_support,
|
||||
to_openai_tools,
|
||||
)
|
||||
from khoj.processor.conversation.utils import (
|
||||
JsonSupport,
|
||||
OperatorRun,
|
||||
ResponseWithThought,
|
||||
StructuredOutputSupport,
|
||||
ToolDefinition,
|
||||
generate_chatml_messages_with_context,
|
||||
messages_to_print,
|
||||
)
|
||||
@@ -32,6 +32,7 @@ def send_message_to_model(
|
||||
model,
|
||||
response_type="text",
|
||||
response_schema=None,
|
||||
tools: list[ToolDefinition] = None,
|
||||
deepthought=False,
|
||||
api_base_url=None,
|
||||
tracer: dict = {},
|
||||
@@ -40,9 +41,11 @@ def send_message_to_model(
|
||||
Send message to model
|
||||
"""
|
||||
|
||||
model_kwargs = {}
|
||||
json_support = get_openai_api_json_support(model, api_base_url)
|
||||
if response_schema and json_support == JsonSupport.SCHEMA:
|
||||
model_kwargs: Dict[str, Any] = {}
|
||||
json_support = get_structured_output_support(model, api_base_url)
|
||||
if tools and json_support == StructuredOutputSupport.TOOL:
|
||||
model_kwargs["tools"] = to_openai_tools(tools)
|
||||
elif response_schema and json_support >= StructuredOutputSupport.SCHEMA:
|
||||
# Drop unsupported fields from schema passed to OpenAI APi
|
||||
cleaned_response_schema = clean_response_schema(response_schema)
|
||||
model_kwargs["response_format"] = {
|
||||
@@ -53,7 +56,7 @@ def send_message_to_model(
|
||||
"strict": True,
|
||||
},
|
||||
}
|
||||
elif response_type == "json_object" and json_support == JsonSupport.OBJECT:
|
||||
elif response_type == "json_object" and json_support == StructuredOutputSupport.OBJECT:
|
||||
model_kwargs["response_format"] = {"type": response_type}
|
||||
|
||||
# Get Response from GPT
|
||||
@@ -171,30 +174,3 @@ async def converse_openai(
|
||||
tracer=tracer,
|
||||
):
|
||||
yield chunk
|
||||
|
||||
|
||||
def clean_response_schema(schema: BaseModel | dict) -> dict:
|
||||
"""
|
||||
Format response schema to be compatible with OpenAI API.
|
||||
|
||||
Clean the response schema by removing unsupported fields.
|
||||
"""
|
||||
# Normalize schema to OpenAI compatible JSON schema format
|
||||
schema_json = schema if isinstance(schema, dict) else schema.model_json_schema()
|
||||
schema_json = _ensure_strict_json_schema(schema_json, path=(), root=schema_json)
|
||||
|
||||
# Recursively drop unsupported fields from schema passed to OpenAI API
|
||||
# See https://platform.openai.com/docs/guides/structured-outputs#supported-schemas
|
||||
fields_to_exclude = ["minItems", "maxItems"]
|
||||
if isinstance(schema_json, dict) and isinstance(schema_json.get("properties"), dict):
|
||||
for _, prop_value in schema_json["properties"].items():
|
||||
if isinstance(prop_value, dict):
|
||||
# Remove specified fields from direct properties
|
||||
for field in fields_to_exclude:
|
||||
prop_value.pop(field, None)
|
||||
# Recursively remove specified fields from child properties
|
||||
if "items" in prop_value and isinstance(prop_value["items"], dict):
|
||||
clean_response_schema(prop_value["items"])
|
||||
|
||||
# Return cleaned schema
|
||||
return schema_json
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from copy import deepcopy
|
||||
@@ -9,6 +10,7 @@ from urllib.parse import urlparse
|
||||
import httpx
|
||||
import openai
|
||||
from langchain_core.messages.chat import ChatMessage
|
||||
from openai.lib._pydantic import _ensure_strict_json_schema
|
||||
from openai.lib.streaming.chat import (
|
||||
ChatCompletionStream,
|
||||
ChatCompletionStreamEvent,
|
||||
@@ -20,6 +22,7 @@ from openai.types.chat.chat_completion_chunk import (
|
||||
Choice,
|
||||
ChoiceDelta,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
@@ -30,8 +33,9 @@ from tenacity import (
|
||||
)
|
||||
|
||||
from khoj.processor.conversation.utils import (
|
||||
JsonSupport,
|
||||
ResponseWithThought,
|
||||
StructuredOutputSupport,
|
||||
ToolDefinition,
|
||||
commit_conversation_trace,
|
||||
)
|
||||
from khoj.utils.helpers import (
|
||||
@@ -131,6 +135,8 @@ def completion_with_backoff(
|
||||
aggregated_response += chunk.delta
|
||||
elif chunk.type == "thought.delta":
|
||||
pass
|
||||
elif chunk.type == "tool_calls.function.arguments.done":
|
||||
aggregated_response = json.dumps([{"name": chunk.name, "args": chunk.arguments}])
|
||||
else:
|
||||
# Non-streaming chat completion
|
||||
chunk = client.beta.chat.completions.parse(
|
||||
@@ -190,6 +196,7 @@ async def chat_completion_with_backoff(
|
||||
deepthought=False,
|
||||
model_kwargs: dict = {},
|
||||
tracer: dict = {},
|
||||
tools=None,
|
||||
) -> AsyncGenerator[ResponseWithThought, None]:
|
||||
client_key = f"{openai_api_key}--{api_base_url}"
|
||||
client = openai_async_clients.get(client_key)
|
||||
@@ -258,6 +265,8 @@ async def chat_completion_with_backoff(
|
||||
read_timeout = 300 if is_local_api(api_base_url) else 60
|
||||
if os.getenv("KHOJ_LLM_SEED"):
|
||||
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
|
||||
if tools:
|
||||
model_kwargs["tools"] = tools
|
||||
|
||||
aggregated_response = ""
|
||||
final_chunk = None
|
||||
@@ -327,16 +336,16 @@ async def chat_completion_with_backoff(
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
|
||||
|
||||
def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport:
|
||||
def get_structured_output_support(model_name: str, api_base_url: str = None) -> StructuredOutputSupport:
|
||||
if model_name.startswith("deepseek-reasoner"):
|
||||
return JsonSupport.NONE
|
||||
return StructuredOutputSupport.NONE
|
||||
if api_base_url:
|
||||
host = urlparse(api_base_url).hostname
|
||||
if host and host.endswith(".ai.azure.com"):
|
||||
return JsonSupport.OBJECT
|
||||
return StructuredOutputSupport.OBJECT
|
||||
if host == "api.deepinfra.com":
|
||||
return JsonSupport.OBJECT
|
||||
return JsonSupport.SCHEMA
|
||||
return StructuredOutputSupport.OBJECT
|
||||
return StructuredOutputSupport.TOOL
|
||||
|
||||
|
||||
def format_message_for_api(messages: List[ChatMessage], api_base_url: str) -> List[dict]:
|
||||
@@ -708,3 +717,47 @@ def add_qwen_no_think_tag(formatted_messages: List[dict]) -> None:
|
||||
if isinstance(content_part, dict) and content_part.get("type") == "text":
|
||||
content_part["text"] += " /no_think"
|
||||
break
|
||||
|
||||
|
||||
def to_openai_tools(tools: List[ToolDefinition]) -> List[Dict] | None:
|
||||
"Transform tool definitions from standard format to OpenAI format."
|
||||
openai_tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": clean_response_schema(tool.schema),
|
||||
},
|
||||
}
|
||||
for tool in tools
|
||||
]
|
||||
|
||||
return openai_tools or None
|
||||
|
||||
|
||||
def clean_response_schema(schema: BaseModel | dict) -> dict:
|
||||
"""
|
||||
Format response schema to be compatible with OpenAI API.
|
||||
|
||||
Clean the response schema by removing unsupported fields.
|
||||
"""
|
||||
# Normalize schema to OpenAI compatible JSON schema format
|
||||
schema_json = schema if isinstance(schema, dict) else schema.model_json_schema()
|
||||
schema_json = _ensure_strict_json_schema(schema_json, path=(), root=schema_json)
|
||||
|
||||
# Recursively drop unsupported fields from schema passed to OpenAI API
|
||||
# See https://platform.openai.com/docs/guides/structured-outputs#supported-schemas
|
||||
fields_to_exclude = ["minItems", "maxItems"]
|
||||
if isinstance(schema_json, dict) and isinstance(schema_json.get("properties"), dict):
|
||||
for _, prop_value in schema_json["properties"].items():
|
||||
if isinstance(prop_value, dict):
|
||||
# Remove specified fields from direct properties
|
||||
for field in fields_to_exclude:
|
||||
prop_value.pop(field, None)
|
||||
# Recursively remove specified fields from child properties
|
||||
if "items" in prop_value and isinstance(prop_value["items"], dict):
|
||||
clean_response_schema(prop_value["items"])
|
||||
|
||||
# Return cleaned schema
|
||||
return schema_json
|
||||
|
||||
@@ -6,11 +6,13 @@ import mimetypes
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Union
|
||||
from textwrap import dedent
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union
|
||||
|
||||
import PIL.Image
|
||||
import pyjson5
|
||||
@@ -1149,13 +1151,93 @@ def messages_to_print(messages: list[ChatMessage], max_length: int = 70) -> str:
|
||||
return "\n".join([f"{json.dumps(safe_serialize(message.content))[:max_length]}..." for message in messages])
|
||||
|
||||
|
||||
class JsonSupport(int, Enum):
|
||||
class StructuredOutputSupport(int, Enum):
|
||||
NONE = 0
|
||||
OBJECT = 1
|
||||
SCHEMA = 2
|
||||
TOOL = 3
|
||||
|
||||
|
||||
class ResponseWithThought:
|
||||
def __init__(self, response: str = None, thought: str = None):
|
||||
self.response = response
|
||||
self.thought = thought
|
||||
|
||||
|
||||
class ToolDefinition:
|
||||
def __init__(self, name: str, description: str, schema: dict):
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.schema = schema
|
||||
|
||||
|
||||
def create_tool_definition(
|
||||
schema: Type[BaseModel],
|
||||
name: str = None,
|
||||
description: Optional[str] = None,
|
||||
) -> ToolDefinition:
|
||||
"""
|
||||
Converts a response schema BaseModel class into a normalized tool definition.
|
||||
|
||||
A standard AI provider agnostic tool format to specify tools the model can use.
|
||||
Common logic used across models is kept here. AI provider specific adaptations
|
||||
should be handled in provider code.
|
||||
|
||||
Args:
|
||||
response_schema: The Pydantic BaseModel class to convert.
|
||||
This class defines the response schema for the tool.
|
||||
tool_name: The name for the AI model tool (e.g., "get_weather", "plan_next_step").
|
||||
tool_description: Optional description for the AI model tool.
|
||||
If None, it attempts to use the Pydantic model's docstring.
|
||||
If that's also missing, a fallback description is generated.
|
||||
|
||||
Returns:
|
||||
A normalized tool definition for AI model APIs.
|
||||
"""
|
||||
raw_schema_dict = schema.model_json_schema()
|
||||
|
||||
name = name or schema.__name__.lower()
|
||||
description = description
|
||||
if description is None:
|
||||
docstring = schema.__doc__
|
||||
if docstring:
|
||||
description = dedent(docstring).strip()
|
||||
else:
|
||||
# Fallback description if no explicit one or docstring is provided
|
||||
description = f"Tool named '{name}' accepts specified parameters."
|
||||
|
||||
# Process properties to inline enums and remove $defs dependency
|
||||
processed_properties = {}
|
||||
original_properties = raw_schema_dict.get("properties", {})
|
||||
defs = raw_schema_dict.get("$defs", {})
|
||||
|
||||
for prop_name, prop_schema in original_properties.items():
|
||||
current_prop_schema = deepcopy(prop_schema) # Work on a copy
|
||||
# Check for enums defined directly in the property for simpler direct enum definitions.
|
||||
if "$ref" in current_prop_schema:
|
||||
ref_path = current_prop_schema["$ref"]
|
||||
if ref_path.startswith("#/$defs/"):
|
||||
def_name = ref_path.split("/")[-1]
|
||||
if def_name in defs and "enum" in defs[def_name]:
|
||||
enum_def = defs[def_name]
|
||||
current_prop_schema["enum"] = enum_def["enum"]
|
||||
current_prop_schema["type"] = enum_def.get("type", "string")
|
||||
if "description" not in current_prop_schema and "description" in enum_def:
|
||||
current_prop_schema["description"] = enum_def["description"]
|
||||
del current_prop_schema["$ref"] # Remove the $ref as it's been inlined
|
||||
|
||||
processed_properties[prop_name] = current_prop_schema
|
||||
|
||||
# Generate the compiled schema dictionary for the tool definition.
|
||||
compiled_schema = {
|
||||
"type": "object",
|
||||
"properties": processed_properties,
|
||||
# Generate content in the order in which the schema properties were defined
|
||||
"property_ordering": list(schema.model_fields.keys()),
|
||||
}
|
||||
|
||||
# Include 'required' fields if specified in the Pydantic model
|
||||
if "required" in raw_schema_dict and raw_schema_dict["required"]:
|
||||
compiled_schema["required"] = raw_schema_dict["required"]
|
||||
|
||||
return ToolDefinition(name=name, description=description, schema=compiled_schema)
|
||||
|
||||
@@ -101,6 +101,7 @@ from khoj.processor.conversation.utils import (
|
||||
OperatorRun,
|
||||
ResearchIteration,
|
||||
ResponseWithThought,
|
||||
ToolDefinition,
|
||||
clean_json,
|
||||
clean_mermaidjs,
|
||||
construct_chat_history,
|
||||
@@ -1439,6 +1440,7 @@ async def send_message_to_model_wrapper(
|
||||
system_message: str = "",
|
||||
response_type: str = "text",
|
||||
response_schema: BaseModel = None,
|
||||
tools: List[ToolDefinition] = None,
|
||||
deepthought: bool = False,
|
||||
user: KhojUser = None,
|
||||
query_images: List[str] = None,
|
||||
@@ -1506,6 +1508,7 @@ async def send_message_to_model_wrapper(
|
||||
model=chat_model_name,
|
||||
response_type=response_type,
|
||||
response_schema=response_schema,
|
||||
tools=tools,
|
||||
deepthought=deepthought,
|
||||
api_base_url=api_base_url,
|
||||
tracer=tracer,
|
||||
@@ -1517,6 +1520,7 @@ async def send_message_to_model_wrapper(
|
||||
model=chat_model_name,
|
||||
response_type=response_type,
|
||||
response_schema=response_schema,
|
||||
tools=tools,
|
||||
deepthought=deepthought,
|
||||
api_base_url=api_base_url,
|
||||
tracer=tracer,
|
||||
@@ -1528,6 +1532,7 @@ async def send_message_to_model_wrapper(
|
||||
model=chat_model_name,
|
||||
response_type=response_type,
|
||||
response_schema=response_schema,
|
||||
tools=tools,
|
||||
deepthought=deepthought,
|
||||
api_base_url=api_base_url,
|
||||
tracer=tracer,
|
||||
|
||||
Reference in New Issue
Block a user