mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Use json schema to enforce research mode tool pick format
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Callable, Dict, List, Optional
|
||||
from enum import Enum
|
||||
from typing import Callable, Dict, List, Optional, Type
|
||||
|
||||
import yaml
|
||||
from fastapi import Request
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from khoj.database.adapters import EntryAdapters
|
||||
from khoj.database.models import Agent, KhojUser
|
||||
@@ -36,6 +38,40 @@ from khoj.utils.rawconfig import LocationData
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PlanningResponse(BaseModel):
|
||||
"""
|
||||
Schema for the response from planning agent when deciding the next tool to pick.
|
||||
The tool field is dynamically validated based on available tools.
|
||||
"""
|
||||
|
||||
scratchpad: str = Field(..., description="Reasoning about which tool to use next")
|
||||
query: str = Field(..., description="Detailed query for the selected tool")
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@classmethod
|
||||
def create_model_with_enum(cls: Type["PlanningResponse"], tool_options: dict) -> Type["PlanningResponse"]:
|
||||
"""
|
||||
Factory method that creates a customized PlanningResponse model
|
||||
with a properly typed tool field based on available tools.
|
||||
|
||||
Args:
|
||||
tool_options: Dictionary mapping tool names to values
|
||||
|
||||
Returns:
|
||||
A customized PlanningResponse class
|
||||
"""
|
||||
# Create dynamic enum from tool options
|
||||
tool_enum = Enum("ToolEnum", tool_options) # type: ignore
|
||||
|
||||
# Create and return a customized response model with the enum
|
||||
class PlanningResponseWithTool(PlanningResponse):
|
||||
tool: tool_enum = Field(..., description="Name of the tool to use")
|
||||
|
||||
return PlanningResponseWithTool
|
||||
|
||||
|
||||
async def apick_next_tool(
|
||||
query: str,
|
||||
conversation_history: dict,
|
||||
@@ -61,10 +97,13 @@ async def apick_next_tool(
|
||||
# Skip showing Notes tool as an option if user has no entries
|
||||
if tool == ConversationCommand.Notes and not user_has_entries:
|
||||
continue
|
||||
tool_options[tool.value] = description
|
||||
if len(agent_tools) == 0 or tool.value in agent_tools:
|
||||
tool_options[tool.name] = tool.value
|
||||
tool_options_str += f'- "{tool.value}": "{description}"\n'
|
||||
|
||||
# Create planning reponse model with dynamically populated tool enum class
|
||||
planning_response_model = PlanningResponse.create_model_with_enum(tool_options)
|
||||
|
||||
# Construct chat history with user and iteration history with researcher agent for context
|
||||
chat_history = construct_chat_history(conversation_history, agent_name=agent.name if agent else "Khoj")
|
||||
previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)
|
||||
@@ -96,6 +135,7 @@ async def apick_next_tool(
|
||||
query=query,
|
||||
context=function_planning_prompt,
|
||||
response_type="json_object",
|
||||
response_schema=planning_response_model,
|
||||
deepthought=True,
|
||||
user=user,
|
||||
query_images=query_images,
|
||||
|
||||
Reference in New Issue
Block a user