diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index 86c5d516..1c3b7166 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -22,7 +22,15 @@ logger = logging.getLogger(__name__) def anthropic_send_message_to_model( - messages, api_key, api_base_url, model, response_type="text", response_schema=None, deepthought=False, tracer={} + messages, + api_key, + api_base_url, + model, + response_type="text", + response_schema=None, + tools=None, + deepthought=False, + tracer={}, ): """ Send message to model @@ -36,6 +44,7 @@ def anthropic_send_message_to_model( api_base_url=api_base_url, response_type=response_type, response_schema=response_schema, + tools=tools, deepthought=deepthought, tracer=tracer, ) diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py index 796a90da..a9fd1db3 100644 --- a/src/khoj/processor/conversation/anthropic/utils.py +++ b/src/khoj/processor/conversation/anthropic/utils.py @@ -1,9 +1,7 @@ import json import logging -from copy import deepcopy -from textwrap import dedent from time import perf_counter -from typing import AsyncGenerator, Dict, List, Optional, Type +from typing import AsyncGenerator, Dict, List import anthropic from langchain_core.messages.chat import ChatMessage @@ -18,7 +16,9 @@ from tenacity import ( from khoj.processor.conversation.utils import ( ResponseWithThought, + ToolDefinition, commit_conversation_trace, + create_tool_definition, get_image_from_base64, get_image_from_url, ) @@ -57,6 +57,7 @@ def anthropic_completion_with_backoff( max_tokens: int | None = None, response_type: str = "text", response_schema: BaseModel | None = None, + tools: List[ToolDefinition] = None, deepthought: bool = False, tracer: dict = {}, ) -> str: @@ -70,9 +71,19 @@ def anthropic_completion_with_backoff( 
aggregated_response = "" final_message = None model_kwargs = model_kwargs or dict() - if response_schema: - tool = create_anthropic_tool_definition(response_schema=response_schema) - model_kwargs["tools"] = [tool] + + # Configure structured output + if tools: + # Convert tools to Anthropic format + model_kwargs["tools"] = [ + anthropic.types.ToolParam(name=tool.name, description=tool.description, input_schema=tool.schema) + for tool in tools + ] + elif response_schema: + tool = create_tool_definition(response_schema) + model_kwargs["tools"] = [ + anthropic.types.ToolParam(name=tool.name, description=tool.description, input_schema=tool.schema) + ] elif response_type == "json_object" and not (is_reasoning_model(model_name) and deepthought): # Prefill model response with '{' to make it output a valid JSON object. Not supported with extended thinking. formatted_messages.append(anthropic.types.MessageParam(role="assistant", content="{")) @@ -103,7 +114,7 @@ def anthropic_completion_with_backoff( # Extract first tool call from final message for item in final_message.content: if item.type == "tool_use": - aggregated_response = json.dumps(item.input) + aggregated_response = json.dumps([{"name": item.name, "args": item.input}]) break # Calculate cost of chat @@ -326,74 +337,5 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: st return formatted_messages, system -def create_anthropic_tool_definition( - response_schema: Type[BaseModel], - tool_name: str = None, - tool_description: Optional[str] = None, -) -> anthropic.types.ToolParam: - """ - Converts a response schema BaseModel class into an Anthropic tool definition dictionary. - - This format is expected by Anthropic's API when defining tools the model can use. - - Args: - response_schema: The Pydantic BaseModel class to convert. - This class defines the response schema for the tool. - tool_name: The name for the Anthropic tool (e.g., "get_weather", "plan_next_step"). 
- tool_description: Optional description for the Anthropic tool. - If None, it attempts to use the Pydantic model's docstring. - If that's also missing, a fallback description is generated. - - Returns: - An tool definition for Anthropic's API. - """ - model_schema = response_schema.model_json_schema() - - name = tool_name or response_schema.__name__.lower() - description = tool_description - if description is None: - docstring = response_schema.__doc__ - if docstring: - description = dedent(docstring).strip() - else: - # Fallback description if no explicit one or docstring is provided - description = f"Tool named '{name}' accepts specified parameters." - - # Process properties to inline enums and remove $defs dependency - processed_properties = {} - original_properties = model_schema.get("properties", {}) - defs = model_schema.get("$defs", {}) - - for prop_name, prop_schema in original_properties.items(): - current_prop_schema = deepcopy(prop_schema) # Work on a copy - # Check for enums defined directly in the property for simpler direct enum definitions. - if "$ref" in current_prop_schema: - ref_path = current_prop_schema["$ref"] - if ref_path.startswith("#/$defs/"): - def_name = ref_path.split("/")[-1] - if def_name in defs and "enum" in defs[def_name]: - enum_def = defs[def_name] - current_prop_schema["enum"] = enum_def["enum"] - current_prop_schema["type"] = enum_def.get("type", "string") - if "description" not in current_prop_schema and "description" in enum_def: - current_prop_schema["description"] = enum_def["description"] - del current_prop_schema["$ref"] # Remove the $ref as it's been inlined - - processed_properties[prop_name] = current_prop_schema - - # The input_schema for Anthropic tools is a JSON Schema object. - # Pydantic's model_json_schema() provides most of what's needed. 
- input_schema = { - "type": "object", - "properties": processed_properties, - } - - # Include 'required' fields if specified in the Pydantic model - if "required" in model_schema and model_schema["required"]: - input_schema["required"] = model_schema["required"] - - return anthropic.types.ToolParam(name=name, description=description, input_schema=input_schema) - - def is_reasoning_model(model_name: str) -> bool: return any(model_name.startswith(model) for model in REASONING_MODELS) diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index 6bd04790..aed144c0 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -28,6 +28,7 @@ def gemini_send_message_to_model( api_base_url=None, response_type="text", response_schema=None, + tools=None, model_kwargs=None, deepthought=False, tracer={}, @@ -37,8 +38,10 @@ def gemini_send_message_to_model( """ model_kwargs = {} + if tools: + model_kwargs["tools"] = tools # Monitor for flakiness in 1.5+ models. This would cause unwanted behavior and terminate response early in 1.5 models. 
- if response_type == "json_object" and not model.startswith("gemini-1.5"): + elif response_type == "json_object" and not model.startswith("gemini-1.5"): model_kwargs["response_mime_type"] = "application/json" if response_schema: model_kwargs["response_schema"] = response_schema diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py index 760b314e..14eb5119 100644 --- a/src/khoj/processor/conversation/google/utils.py +++ b/src/khoj/processor/conversation/google/utils.py @@ -1,9 +1,10 @@ +import json import logging import os import random from copy import deepcopy from time import perf_counter -from typing import AsyncGenerator, AsyncIterator, Dict +from typing import AsyncGenerator, AsyncIterator, Dict, List import httpx from google import genai @@ -22,6 +23,7 @@ from tenacity import ( from khoj.processor.conversation.utils import ( ResponseWithThought, + ToolDefinition, commit_conversation_trace, get_image_from_base64, get_image_from_url, @@ -95,7 +97,7 @@ def gemini_completion_with_backoff( temperature=1.2, api_key=None, api_base_url: str = None, - model_kwargs=None, + model_kwargs={}, deepthought=False, tracer={}, ) -> str: @@ -107,9 +109,12 @@ def gemini_completion_with_backoff( formatted_messages, system_instruction = format_messages_for_gemini(messages, system_prompt) response_thoughts: str | None = None - # format model response schema + # Configure structured output + tools = None response_schema = None - if model_kwargs and model_kwargs.get("response_schema"): + if model_kwargs.get("tools"): + tools = to_gemini_tools(model_kwargs["tools"]) + elif model_kwargs.get("response_schema"): response_schema = clean_response_schema(model_kwargs["response_schema"]) thinking_config = None @@ -127,8 +132,9 @@ def gemini_completion_with_backoff( thinking_config=thinking_config, max_output_tokens=max_output_tokens, safety_settings=SAFETY_SETTINGS, - response_mime_type=model_kwargs.get("response_mime_type", 
"text/plain") if model_kwargs else "text/plain", + response_mime_type=model_kwargs.get("response_mime_type", "text/plain"), response_schema=response_schema, + tools=tools, seed=seed, top_p=0.95, http_options=gtypes.HttpOptions(client_args={"timeout": httpx.Timeout(30.0, read=60.0)}), @@ -137,7 +143,14 @@ def gemini_completion_with_backoff( try: # Generate the response response = client.models.generate_content(model=model_name, config=config, contents=formatted_messages) - response_text = response.text + if response.function_calls: + function_calls = [ + {"name": function_call.name, "args": function_call.args} for function_call in response.function_calls + ] + response_text = json.dumps(function_calls) + else: + # If no function calls, use the text response + response_text = response.text except gerrors.ClientError as e: response = None response_text, _ = handle_gemini_response(e.args) @@ -404,3 +417,21 @@ def is_reasoning_model(model_name: str) -> bool: Check if the model is a reasoning model. """ return model_name.startswith("gemini-2.5") + + +def to_gemini_tools(tools: List[ToolDefinition]) -> List[gtypes.ToolDict] | None: + "Transform tool definitions from standard format to Gemini format." 
+ gemini_tools = [ + gtypes.ToolDict( + function_declarations=[ + gtypes.FunctionDeclarationDict( + name=tool.name, + description=tool.description, + parameters=tool.schema, + ) + for tool in tools + ] + ) + ] + + return gemini_tools or None diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index 56af7bde..f123ad63 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -1,21 +1,21 @@ import logging from datetime import datetime -from typing import AsyncGenerator, Dict, List, Optional - -from openai.lib._pydantic import _ensure_strict_json_schema -from pydantic import BaseModel +from typing import Any, AsyncGenerator, Dict, List, Optional from khoj.database.models import Agent, ChatMessageModel, ChatModel from khoj.processor.conversation import prompts from khoj.processor.conversation.openai.utils import ( chat_completion_with_backoff, + clean_response_schema, completion_with_backoff, - get_openai_api_json_support, + get_structured_output_support, + to_openai_tools, ) from khoj.processor.conversation.utils import ( - JsonSupport, OperatorRun, ResponseWithThought, + StructuredOutputSupport, + ToolDefinition, generate_chatml_messages_with_context, messages_to_print, ) @@ -32,6 +32,7 @@ def send_message_to_model( model, response_type="text", response_schema=None, + tools: list[ToolDefinition] = None, deepthought=False, api_base_url=None, tracer: dict = {}, @@ -40,9 +41,11 @@ def send_message_to_model( Send message to model """ - model_kwargs = {} - json_support = get_openai_api_json_support(model, api_base_url) - if response_schema and json_support == JsonSupport.SCHEMA: + model_kwargs: Dict[str, Any] = {} + json_support = get_structured_output_support(model, api_base_url) + if tools and json_support == StructuredOutputSupport.TOOL: + model_kwargs["tools"] = to_openai_tools(tools) + elif response_schema and json_support >= StructuredOutputSupport.SCHEMA: 
# Drop unsupported fields from schema passed to OpenAI APi cleaned_response_schema = clean_response_schema(response_schema) model_kwargs["response_format"] = { @@ -53,7 +56,7 @@ def send_message_to_model( "strict": True, }, } - elif response_type == "json_object" and json_support == JsonSupport.OBJECT: + elif response_type == "json_object" and json_support == StructuredOutputSupport.OBJECT: model_kwargs["response_format"] = {"type": response_type} # Get Response from GPT @@ -171,30 +174,3 @@ async def converse_openai( tracer=tracer, ): yield chunk - - -def clean_response_schema(schema: BaseModel | dict) -> dict: - """ - Format response schema to be compatible with OpenAI API. - - Clean the response schema by removing unsupported fields. - """ - # Normalize schema to OpenAI compatible JSON schema format - schema_json = schema if isinstance(schema, dict) else schema.model_json_schema() - schema_json = _ensure_strict_json_schema(schema_json, path=(), root=schema_json) - - # Recursively drop unsupported fields from schema passed to OpenAI API - # See https://platform.openai.com/docs/guides/structured-outputs#supported-schemas - fields_to_exclude = ["minItems", "maxItems"] - if isinstance(schema_json, dict) and isinstance(schema_json.get("properties"), dict): - for _, prop_value in schema_json["properties"].items(): - if isinstance(prop_value, dict): - # Remove specified fields from direct properties - for field in fields_to_exclude: - prop_value.pop(field, None) - # Recursively remove specified fields from child properties - if "items" in prop_value and isinstance(prop_value["items"], dict): - clean_response_schema(prop_value["items"]) - - # Return cleaned schema - return schema_json diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index 440d3286..b0c311b7 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -1,3 +1,4 @@ +import json import 
logging import os from copy import deepcopy @@ -9,6 +10,7 @@ from urllib.parse import urlparse import httpx import openai from langchain_core.messages.chat import ChatMessage +from openai.lib._pydantic import _ensure_strict_json_schema from openai.lib.streaming.chat import ( ChatCompletionStream, ChatCompletionStreamEvent, @@ -20,6 +22,7 @@ from openai.types.chat.chat_completion_chunk import ( Choice, ChoiceDelta, ) +from pydantic import BaseModel from tenacity import ( before_sleep_log, retry, @@ -30,8 +33,9 @@ from tenacity import ( ) from khoj.processor.conversation.utils import ( - JsonSupport, ResponseWithThought, + StructuredOutputSupport, + ToolDefinition, commit_conversation_trace, ) from khoj.utils.helpers import ( @@ -131,6 +135,8 @@ def completion_with_backoff( aggregated_response += chunk.delta elif chunk.type == "thought.delta": pass + elif chunk.type == "tool_calls.function.arguments.done": + aggregated_response = json.dumps([{"name": chunk.name, "args": chunk.arguments}]) else: # Non-streaming chat completion chunk = client.beta.chat.completions.parse( @@ -190,6 +196,7 @@ async def chat_completion_with_backoff( deepthought=False, model_kwargs: dict = {}, tracer: dict = {}, + tools=None, ) -> AsyncGenerator[ResponseWithThought, None]: client_key = f"{openai_api_key}--{api_base_url}" client = openai_async_clients.get(client_key) @@ -258,6 +265,8 @@ async def chat_completion_with_backoff( read_timeout = 300 if is_local_api(api_base_url) else 60 if os.getenv("KHOJ_LLM_SEED"): model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED")) + if tools: + model_kwargs["tools"] = tools aggregated_response = "" final_chunk = None @@ -327,16 +336,16 @@ async def chat_completion_with_backoff( commit_conversation_trace(messages, aggregated_response, tracer) -def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport: +def get_structured_output_support(model_name: str, api_base_url: str = None) -> StructuredOutputSupport: if 
model_name.startswith("deepseek-reasoner"): - return JsonSupport.NONE + return StructuredOutputSupport.NONE if api_base_url: host = urlparse(api_base_url).hostname if host and host.endswith(".ai.azure.com"): - return JsonSupport.OBJECT + return StructuredOutputSupport.OBJECT if host == "api.deepinfra.com": - return JsonSupport.OBJECT - return JsonSupport.SCHEMA + return StructuredOutputSupport.OBJECT + return StructuredOutputSupport.TOOL def format_message_for_api(messages: List[ChatMessage], api_base_url: str) -> List[dict]: @@ -708,3 +717,47 @@ def add_qwen_no_think_tag(formatted_messages: List[dict]) -> None: if isinstance(content_part, dict) and content_part.get("type") == "text": content_part["text"] += " /no_think" break + + +def to_openai_tools(tools: List[ToolDefinition]) -> List[Dict] | None: + "Transform tool definitions from standard format to OpenAI format." + openai_tools = [ + { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": clean_response_schema(tool.schema), + }, + } + for tool in tools + ] + + return openai_tools or None + + +def clean_response_schema(schema: BaseModel | dict) -> dict: + """ + Format response schema to be compatible with OpenAI API. + + Clean the response schema by removing unsupported fields. 
+ """ + # Normalize schema to OpenAI compatible JSON schema format + schema_json = schema if isinstance(schema, dict) else schema.model_json_schema() + schema_json = _ensure_strict_json_schema(schema_json, path=(), root=schema_json) + + # Recursively drop unsupported fields from schema passed to OpenAI API + # See https://platform.openai.com/docs/guides/structured-outputs#supported-schemas + fields_to_exclude = ["minItems", "maxItems"] + if isinstance(schema_json, dict) and isinstance(schema_json.get("properties"), dict): + for _, prop_value in schema_json["properties"].items(): + if isinstance(prop_value, dict): + # Remove specified fields from direct properties + for field in fields_to_exclude: + prop_value.pop(field, None) + # Recursively remove specified fields from child properties + if "items" in prop_value and isinstance(prop_value["items"], dict): + clean_response_schema(prop_value["items"]) + + # Return cleaned schema + return schema_json diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index a3006aae..4b5f0f38 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -6,11 +6,13 @@ import mimetypes import os import re import uuid +from copy import deepcopy from dataclasses import dataclass from datetime import datetime from enum import Enum from io import BytesIO -from typing import Any, Callable, Dict, List, Literal, Optional, Union +from textwrap import dedent +from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union import PIL.Image import pyjson5 @@ -1149,13 +1151,93 @@ def messages_to_print(messages: list[ChatMessage], max_length: int = 70) -> str: return "\n".join([f"{json.dumps(safe_serialize(message.content))[:max_length]}..." 
for message in messages]) -class JsonSupport(int, Enum): +class StructuredOutputSupport(int, Enum): NONE = 0 OBJECT = 1 SCHEMA = 2 + TOOL = 3 class ResponseWithThought: def __init__(self, response: str = None, thought: str = None): self.response = response self.thought = thought + + +class ToolDefinition: + def __init__(self, name: str, description: str, schema: dict): + self.name = name + self.description = description + self.schema = schema + + +def create_tool_definition( + schema: Type[BaseModel], + name: str = None, + description: Optional[str] = None, +) -> ToolDefinition: + """ + Converts a response schema BaseModel class into a normalized tool definition. + + A standard AI provider agnostic tool format to specify tools the model can use. + Common logic used across models is kept here. AI provider specific adaptations + should be handled in provider code. + + Args: + schema: The Pydantic BaseModel class to convert. + This class defines the response schema for the tool. + name: The name for the AI model tool (e.g., "get_weather", "plan_next_step"). + description: Optional description for the AI model tool. + If None, it attempts to use the Pydantic model's docstring. + If that's also missing, a fallback description is generated. + + Returns: + A normalized tool definition for AI model APIs. + """ + raw_schema_dict = schema.model_json_schema() + + name = name or schema.__name__.lower() + description = description + if description is None: + docstring = schema.__doc__ + if docstring: + description = dedent(docstring).strip() + else: + # Fallback description if no explicit one or docstring is provided + description = f"Tool named '{name}' accepts specified parameters."
+ + # Process properties to inline enums and remove $defs dependency + processed_properties = {} + original_properties = raw_schema_dict.get("properties", {}) + defs = raw_schema_dict.get("$defs", {}) + + for prop_name, prop_schema in original_properties.items(): + current_prop_schema = deepcopy(prop_schema) # Work on a copy + # Check for enums defined directly in the property for simpler direct enum definitions. + if "$ref" in current_prop_schema: + ref_path = current_prop_schema["$ref"] + if ref_path.startswith("#/$defs/"): + def_name = ref_path.split("/")[-1] + if def_name in defs and "enum" in defs[def_name]: + enum_def = defs[def_name] + current_prop_schema["enum"] = enum_def["enum"] + current_prop_schema["type"] = enum_def.get("type", "string") + if "description" not in current_prop_schema and "description" in enum_def: + current_prop_schema["description"] = enum_def["description"] + del current_prop_schema["$ref"] # Remove the $ref as it's been inlined + + processed_properties[prop_name] = current_prop_schema + + # Generate the compiled schema dictionary for the tool definition. 
+ compiled_schema = { + "type": "object", + "properties": processed_properties, + # Generate content in the order in which the schema properties were defined + "property_ordering": list(schema.model_fields.keys()), + } + + # Include 'required' fields if specified in the Pydantic model + if "required" in raw_schema_dict and raw_schema_dict["required"]: + compiled_schema["required"] = raw_schema_dict["required"] + + return ToolDefinition(name=name, description=description, schema=compiled_schema) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index bb75da13..5b0b83e3 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -101,6 +101,7 @@ from khoj.processor.conversation.utils import ( OperatorRun, ResearchIteration, ResponseWithThought, + ToolDefinition, clean_json, clean_mermaidjs, construct_chat_history, @@ -1439,6 +1440,7 @@ async def send_message_to_model_wrapper( system_message: str = "", response_type: str = "text", response_schema: BaseModel = None, + tools: List[ToolDefinition] = None, deepthought: bool = False, user: KhojUser = None, query_images: List[str] = None, @@ -1506,6 +1508,7 @@ async def send_message_to_model_wrapper( model=chat_model_name, response_type=response_type, response_schema=response_schema, + tools=tools, deepthought=deepthought, api_base_url=api_base_url, tracer=tracer, @@ -1517,6 +1520,7 @@ async def send_message_to_model_wrapper( model=chat_model_name, response_type=response_type, response_schema=response_schema, + tools=tools, deepthought=deepthought, api_base_url=api_base_url, tracer=tracer, @@ -1528,6 +1532,7 @@ async def send_message_to_model_wrapper( model=chat_model_name, response_type=response_type, response_schema=response_schema, + tools=tools, deepthought=deepthought, api_base_url=api_base_url, tracer=tracer,