diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 07a66002..6e2493c5 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -1,10 +1,12 @@ import logging import os from datetime import datetime -from typing import Callable, Dict, List, Optional +from enum import Enum +from typing import Callable, Dict, List, Optional, Type import yaml from fastapi import Request +from pydantic import BaseModel, Field from khoj.database.adapters import EntryAdapters from khoj.database.models import Agent, KhojUser @@ -36,6 +38,40 @@ from khoj.utils.rawconfig import LocationData logger = logging.getLogger(__name__) +class PlanningResponse(BaseModel): + """ + Schema for the response from planning agent when deciding the next tool to pick. + The tool field is dynamically validated based on available tools. + """ + + scratchpad: str = Field(..., description="Reasoning about which tool to use next") + query: str = Field(..., description="Detailed query for the selected tool") + + class Config: + arbitrary_types_allowed = True + + @classmethod + def create_model_with_enum(cls: Type["PlanningResponse"], tool_options: dict) -> Type["PlanningResponse"]: + """ + Factory method that creates a customized PlanningResponse model + with a properly typed tool field based on available tools. 
+ + Args: + tool_options: Dictionary mapping tool names to values + + Returns: + A customized PlanningResponse class + """ + # Create dynamic enum from tool options + tool_enum = Enum("ToolEnum", tool_options) # type: ignore + + # Create and return a customized response model with the enum + class PlanningResponseWithTool(PlanningResponse): + tool: tool_enum = Field(..., description="Name of the tool to use") + + return PlanningResponseWithTool + + async def apick_next_tool( query: str, conversation_history: dict, @@ -61,10 +97,13 @@ async def apick_next_tool( # Skip showing Notes tool as an option if user has no entries if tool == ConversationCommand.Notes and not user_has_entries: continue - tool_options[tool.value] = description if len(agent_tools) == 0 or tool.value in agent_tools: + tool_options[tool.name] = tool.value tool_options_str += f'- "{tool.value}": "{description}"\n' + # Create planning response model with dynamically populated tool enum class + planning_response_model = PlanningResponse.create_model_with_enum(tool_options) + # Construct chat history with user and iteration history with researcher agent for context chat_history = construct_chat_history(conversation_history, agent_name=agent.name if agent else "Khoj") previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration) @@ -96,6 +135,7 @@ async def apick_next_tool( query=query, context=function_planning_prompt, response_type="json_object", + response_schema=planning_response_model, deepthought=True, user=user, query_images=query_images,