mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-06 13:22:12 +00:00
Use json schema to enforce research mode tool pick format
This commit is contained in:
@@ -1,10 +1,12 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Callable, Dict, List, Optional
|
from enum import Enum
|
||||||
|
from typing import Callable, Dict, List, Optional, Type
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from khoj.database.adapters import EntryAdapters
|
from khoj.database.adapters import EntryAdapters
|
||||||
from khoj.database.models import Agent, KhojUser
|
from khoj.database.models import Agent, KhojUser
|
||||||
@@ -36,6 +38,40 @@ from khoj.utils.rawconfig import LocationData
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PlanningResponse(BaseModel):
|
||||||
|
"""
|
||||||
|
Schema for the response from planning agent when deciding the next tool to pick.
|
||||||
|
The tool field is dynamically validated based on available tools.
|
||||||
|
"""
|
||||||
|
|
||||||
|
scratchpad: str = Field(..., description="Reasoning about which tool to use next")
|
||||||
|
query: str = Field(..., description="Detailed query for the selected tool")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_model_with_enum(cls: Type["PlanningResponse"], tool_options: dict) -> Type["PlanningResponse"]:
|
||||||
|
"""
|
||||||
|
Factory method that creates a customized PlanningResponse model
|
||||||
|
with a properly typed tool field based on available tools.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_options: Dictionary mapping tool names to values
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A customized PlanningResponse class
|
||||||
|
"""
|
||||||
|
# Create dynamic enum from tool options
|
||||||
|
tool_enum = Enum("ToolEnum", tool_options) # type: ignore
|
||||||
|
|
||||||
|
# Create and return a customized response model with the enum
|
||||||
|
class PlanningResponseWithTool(PlanningResponse):
|
||||||
|
tool: tool_enum = Field(..., description="Name of the tool to use")
|
||||||
|
|
||||||
|
return PlanningResponseWithTool
|
||||||
|
|
||||||
|
|
||||||
async def apick_next_tool(
|
async def apick_next_tool(
|
||||||
query: str,
|
query: str,
|
||||||
conversation_history: dict,
|
conversation_history: dict,
|
||||||
@@ -61,10 +97,13 @@ async def apick_next_tool(
|
|||||||
# Skip showing Notes tool as an option if user has no entries
|
# Skip showing Notes tool as an option if user has no entries
|
||||||
if tool == ConversationCommand.Notes and not user_has_entries:
|
if tool == ConversationCommand.Notes and not user_has_entries:
|
||||||
continue
|
continue
|
||||||
tool_options[tool.value] = description
|
|
||||||
if len(agent_tools) == 0 or tool.value in agent_tools:
|
if len(agent_tools) == 0 or tool.value in agent_tools:
|
||||||
|
tool_options[tool.name] = tool.value
|
||||||
tool_options_str += f'- "{tool.value}": "{description}"\n'
|
tool_options_str += f'- "{tool.value}": "{description}"\n'
|
||||||
|
|
||||||
|
# Create planning reponse model with dynamically populated tool enum class
|
||||||
|
planning_response_model = PlanningResponse.create_model_with_enum(tool_options)
|
||||||
|
|
||||||
# Construct chat history with user and iteration history with researcher agent for context
|
# Construct chat history with user and iteration history with researcher agent for context
|
||||||
chat_history = construct_chat_history(conversation_history, agent_name=agent.name if agent else "Khoj")
|
chat_history = construct_chat_history(conversation_history, agent_name=agent.name if agent else "Khoj")
|
||||||
previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)
|
previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)
|
||||||
@@ -96,6 +135,7 @@ async def apick_next_tool(
|
|||||||
query=query,
|
query=query,
|
||||||
context=function_planning_prompt,
|
context=function_planning_prompt,
|
||||||
response_type="json_object",
|
response_type="json_object",
|
||||||
|
response_schema=planning_response_model,
|
||||||
deepthought=True,
|
deepthought=True,
|
||||||
user=user,
|
user=user,
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
|
|||||||
Reference in New Issue
Block a user