diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index 655a5baa..3d34573e 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -117,7 +117,7 @@ def extract_questions_anthropic( def anthropic_send_message_to_model( - messages, api_key, api_base_url, model, response_type="text", deepthought=False, tracer={} + messages, api_key, api_base_url, model, response_type="text", response_schema=None, deepthought=False, tracer={} ): """ Send message to model @@ -130,6 +130,7 @@ def anthropic_send_message_to_model( api_key=api_key, api_base_url=api_base_url, response_type=response_type, + response_schema=response_schema, deepthought=deepthought, tracer=tracer, ) diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py index 02658edc..ed525eeb 100644 --- a/src/khoj/processor/conversation/anthropic/utils.py +++ b/src/khoj/processor/conversation/anthropic/utils.py @@ -1,9 +1,13 @@ +import json import logging +from copy import deepcopy +from textwrap import dedent from time import perf_counter -from typing import AsyncGenerator, Dict, List +from typing import AsyncGenerator, Dict, List, Optional, Type import anthropic from langchain_core.messages.chat import ChatMessage +from pydantic import BaseModel from tenacity import ( before_sleep_log, retry, @@ -52,6 +56,7 @@ def anthropic_completion_with_backoff( model_kwargs: dict | None = None, max_tokens: int | None = None, response_type: str = "text", + response_schema: BaseModel | None = None, deepthought: bool = False, tracer: dict = {}, ) -> str: @@ -63,13 +68,16 @@ def anthropic_completion_with_backoff( formatted_messages, system = format_messages_for_anthropic(messages, system_prompt) aggregated_response = "" - if response_type == "json_object" and not deepthought: + final_message = None + 
+ model_kwargs = model_kwargs or dict() + if response_schema: + tool = create_anthropic_tool_definition(response_schema=response_schema) + model_kwargs["tools"] = [tool] + elif response_type == "json_object" and not deepthought: # Prefill model response with '{' to make it output a valid JSON object. Not supported with extended thinking. formatted_messages.append(anthropic.types.MessageParam(role="assistant", content="{")) aggregated_response += "{" - final_message = None - model_kwargs = model_kwargs or dict() if system: model_kwargs["system"] = system @@ -92,6 +100,12 @@ def anthropic_completion_with_backoff( aggregated_response += text final_message = stream.get_final_message() + # Extract first tool call from final message + for item in final_message.content: + if item.type == "tool_use": + aggregated_response = json.dumps(item.input) + break + # Calculate cost of chat input_tokens = final_message.usage.input_tokens output_tokens = final_message.usage.output_tokens @@ -305,5 +319,74 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: st return formatted_messages, system +def create_anthropic_tool_definition( + response_schema: Type[BaseModel], + tool_name: str = None, + tool_description: Optional[str] = None, +) -> anthropic.types.ToolParam: + """ + Converts a response schema BaseModel class into an Anthropic tool definition dictionary. + + This format is expected by Anthropic's API when defining tools the model can use. + + Args: + response_schema: The Pydantic BaseModel class to convert. + This class defines the response schema for the tool. + tool_name: The name for the Anthropic tool (e.g., "get_weather", "plan_next_step"). + tool_description: Optional description for the Anthropic tool. + If None, it attempts to use the Pydantic model's docstring. + If that's also missing, a fallback description is generated. + + Returns: + A tool definition for Anthropic's API. 
+ """ + model_schema = response_schema.model_json_schema() + + name = tool_name or response_schema.__name__.lower() + description = tool_description + if description is None: + docstring = response_schema.__doc__ + if docstring: + description = dedent(docstring).strip() + else: + # Fallback description if no explicit one or docstring is provided + description = f"Tool named '{name}' accepts specified parameters." + + # Process properties to inline enums and remove $defs dependency + processed_properties = {} + original_properties = model_schema.get("properties", {}) + defs = model_schema.get("$defs", {}) + + for prop_name, prop_schema in original_properties.items(): + current_prop_schema = deepcopy(prop_schema) # Work on a copy + # Check for enums defined directly in the property for simpler direct enum definitions. + if "$ref" in current_prop_schema: + ref_path = current_prop_schema["$ref"] + if ref_path.startswith("#/$defs/"): + def_name = ref_path.split("/")[-1] + if def_name in defs and "enum" in defs[def_name]: + enum_def = defs[def_name] + current_prop_schema["enum"] = enum_def["enum"] + current_prop_schema["type"] = enum_def.get("type", "string") + if "description" not in current_prop_schema and "description" in enum_def: + current_prop_schema["description"] = enum_def["description"] + del current_prop_schema["$ref"] # Remove the $ref as it's been inlined + + processed_properties[prop_name] = current_prop_schema + + # The input_schema for Anthropic tools is a JSON Schema object. + # Pydantic's model_json_schema() provides most of what's needed. 
+ input_schema = { + "type": "object", + "properties": processed_properties, + } + + # Include 'required' fields if specified in the Pydantic model + if "required" in model_schema and model_schema["required"]: + input_schema["required"] = model_schema["required"] + + return anthropic.types.ToolParam(name=name, description=description, input_schema=input_schema) + + def is_reasoning_model(model_name: str) -> bool: return any(model_name.startswith(model) for model in REASONING_MODELS) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 6d0d0064..392f5025 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1232,6 +1232,7 @@ async def send_message_to_model_wrapper( api_key=api_key, model=chat_model_name, response_type=response_type, + response_schema=response_schema, deepthought=deepthought, api_base_url=api_base_url, tracer=tracer, diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 83ec141a..494a01b6 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -71,6 +71,11 @@ class PlanningResponse(BaseModel): # Create and return a customized response model with the enum class PlanningResponseWithTool(PlanningResponse): + """ + Use the scratchpad to reason about which tool to use next and the query to send to the tool. + Pick tool from provided options and your query to send to the tool. + """ + tool: tool_enum = Field(..., description="Name of the tool to use") query: str = Field(..., description="Detailed query for the selected tool")