Use json schema to enforce research mode tool pick format

This commit is contained in:
Debanjum
2025-03-19 01:07:35 +05:30
parent 6980014838
commit 2c53eb9de1

View File

@@ -1,10 +1,12 @@
import logging
import os
from datetime import datetime
from typing import Callable, Dict, List, Optional
from enum import Enum
from typing import Callable, Dict, List, Optional, Type
import yaml
from fastapi import Request
from pydantic import BaseModel, Field
from khoj.database.adapters import EntryAdapters
from khoj.database.models import Agent, KhojUser
@@ -36,6 +38,40 @@ from khoj.utils.rawconfig import LocationData
logger = logging.getLogger(__name__)
class PlanningResponse(BaseModel):
"""
Schema for the response from planning agent when deciding the next tool to pick.
The tool field is dynamically validated based on available tools.
"""
scratchpad: str = Field(..., description="Reasoning about which tool to use next")
query: str = Field(..., description="Detailed query for the selected tool")
class Config:
arbitrary_types_allowed = True
@classmethod
def create_model_with_enum(cls: Type["PlanningResponse"], tool_options: dict) -> Type["PlanningResponse"]:
"""
Factory method that creates a customized PlanningResponse model
with a properly typed tool field based on available tools.
Args:
tool_options: Dictionary mapping tool names to values
Returns:
A customized PlanningResponse class
"""
# Create dynamic enum from tool options
tool_enum = Enum("ToolEnum", tool_options) # type: ignore
# Create and return a customized response model with the enum
class PlanningResponseWithTool(PlanningResponse):
tool: tool_enum = Field(..., description="Name of the tool to use")
return PlanningResponseWithTool
async def apick_next_tool(
query: str,
conversation_history: dict,
@@ -61,10 +97,13 @@ async def apick_next_tool(
# Skip showing Notes tool as an option if user has no entries
if tool == ConversationCommand.Notes and not user_has_entries:
continue
tool_options[tool.value] = description
if len(agent_tools) == 0 or tool.value in agent_tools:
tool_options[tool.name] = tool.value
tool_options_str += f'- "{tool.value}": "{description}"\n'
# Create planning reponse model with dynamically populated tool enum class
planning_response_model = PlanningResponse.create_model_with_enum(tool_options)
# Construct chat history with user and iteration history with researcher agent for context
chat_history = construct_chat_history(conversation_history, agent_name=agent.name if agent else "Khoj")
previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)
@@ -96,6 +135,7 @@ async def apick_next_tool(
query=query,
context=function_planning_prompt,
response_type="json_object",
response_schema=planning_response_model,
deepthought=True,
user=user,
query_images=query_images,