Specify min, max items expected in ai response via schema enforcement

Require at least 1 item in lists. Otherwise gemini flash will
sometimes return an empty list. For chat actors where max items is
known, set that as well.

OpenAI API does not support specifying min, max items in response
schema lists, so drop those properties when response schema is
passed. Add other enforcements to response schema to comply with
response schema format expected by OpenAI API.
This commit is contained in:
Debanjum
2025-03-31 00:46:11 +05:30
parent 0eb2d17771
commit ae9ca58ab9
4 changed files with 49 additions and 10 deletions

View File

@@ -4,12 +4,11 @@ from typing import Dict, List, Optional
import pyjson5
from langchain.schema import ChatMessage
from pydantic import BaseModel
from pydantic import BaseModel, Field
from khoj.database.models import Agent, ChatModel, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.google.utils import (
format_messages_for_gemini,
gemini_chat_completion_with_backoff,
gemini_completion_with_backoff,
)
@@ -98,7 +97,7 @@ def extract_questions_gemini(
messages.append(ChatMessage(content=system_prompt, role="system"))
class DocumentQueries(BaseModel):
queries: List[str]
queries: List[str] = Field(..., min_items=1)
response = gemini_send_message_to_model(
messages,

View File

@@ -4,6 +4,8 @@ from typing import Dict, List, Optional
import pyjson5
from langchain.schema import ChatMessage
from openai.lib._pydantic import _ensure_strict_json_schema
from pydantic import BaseModel
from khoj.database.models import Agent, ChatModel, KhojUser
from khoj.processor.conversation import prompts
@@ -135,7 +137,16 @@ def send_message_to_model(
model_kwargs = {}
json_support = get_openai_api_json_support(model, api_base_url)
if response_schema and json_support == JsonSupport.SCHEMA:
model_kwargs["response_format"] = response_schema
# Drop unsupported fields from schema passed to OpenAI API
cleaned_response_schema = clean_response_schema(response_schema)
model_kwargs["response_format"] = {
"type": "json_schema",
"json_schema": {
"schema": cleaned_response_schema,
"name": response_schema.__name__,
"strict": True,
},
}
elif response_type == "json_object" and json_support == JsonSupport.OBJECT:
model_kwargs["response_format"] = {"type": response_type}
@@ -257,3 +268,30 @@ def converse_openai(
model_kwargs={"stop": ["Notes:\n["]},
tracer=tracer,
)
def clean_response_schema(schema: BaseModel | dict) -> dict:
    """
    Format a response schema to be compatible with the OpenAI API.

    Normalize the schema into OpenAI's strict JSON schema format, then
    recursively drop list constraint fields (minItems, maxItems) that the
    OpenAI structured output API does not support.
    See https://platform.openai.com/docs/guides/structured-outputs#supported-schemas

    Args:
        schema: A pydantic model class or an already-generated JSON schema dict.

    Returns:
        The cleaned JSON schema dict, modified in place and returned.
    """
    # Normalize schema to OpenAI compatible strict JSON schema format.
    # This is done once at the root; the recursive cleanup below must NOT
    # re-run it on subtrees, as its `root` argument is used for $ref resolution.
    schema_json = schema if isinstance(schema, dict) else schema.model_json_schema()
    schema_json = _ensure_strict_json_schema(schema_json, path=(), root=schema_json)
    # Recursively drop unsupported fields everywhere in the schema tree,
    # including nested model definitions ($defs), union branches (anyOf),
    # array item schemas, and the top level itself.
    _drop_unsupported_schema_fields(schema_json)
    return schema_json


def _drop_unsupported_schema_fields(node: dict) -> None:
    """Recursively remove OpenAI-unsupported fields from a JSON schema dict in place."""
    if not isinstance(node, dict):
        return
    # Fields the OpenAI structured output API rejects on array schemas.
    for field in ("minItems", "maxItems"):
        node.pop(field, None)
    # Recurse into every schema container that can hold sub-schemas.
    for prop_value in (node.get("properties") or {}).values():
        _drop_unsupported_schema_fields(prop_value)
    items = node.get("items")
    if isinstance(items, dict):
        _drop_unsupported_schema_fields(items)
    for definition in (node.get("$defs") or {}).values():
        _drop_unsupported_schema_fields(definition)
    for branch in node.get("anyOf") or []:
        _drop_unsupported_schema_fields(branch)

View File

@@ -900,7 +900,7 @@ Khoj:
online_search_conversation_subqueries = PromptTemplate.from_template(
"""
You are Khoj, an advanced web search assistant. You are tasked with constructing **up to three** google search queries to answer the user's question.
You are Khoj, an advanced web search assistant. You are tasked with constructing **up to {max_queries}** google search queries to answer the user's question.
- You will receive the actual chat history as context.
- Add as much context from the chat history as required into your search queries.
- Break messages into multiple search queries when required to retrieve the relevant information.

View File

@@ -34,7 +34,7 @@ from apscheduler.job import Job
from apscheduler.triggers.cron import CronTrigger
from asgiref.sync import sync_to_async
from fastapi import Depends, Header, HTTPException, Request, UploadFile
from pydantic import BaseModel
from pydantic import BaseModel, Field
from starlette.authentication import has_required_scope
from starlette.requests import URL
@@ -400,7 +400,7 @@ async def aget_data_sources_and_output_format(
agent_chat_model = agent.chat_model if agent else None
class PickTools(BaseModel):
source: List[str]
source: List[str] = Field(..., min_items=1)
output: str
with timer("Chat actor: Infer information sources to refer", logger):
@@ -489,7 +489,7 @@ async def infer_webpage_urls(
agent_chat_model = agent.chat_model if agent else None
class WebpageUrls(BaseModel):
links: List[str]
links: List[str] = Field(..., min_items=1, max_items=max_webpages)
with timer("Chat actor: Infer webpage urls to read", logger):
response = await send_message_to_model_wrapper(
@@ -535,15 +535,17 @@ async def generate_online_subqueries(
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
chat_history = construct_chat_history(conversation_history)
max_queries = 3
utc_date = datetime.utcnow().strftime("%Y-%m-%d")
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)
online_queries_prompt = prompts.online_search_conversation_subqueries.format(
current_date=utc_date,
query=q,
chat_history=chat_history,
max_queries=max_queries,
current_date=utc_date,
location=location,
username=username,
personality_context=personality_context,
@@ -552,7 +554,7 @@ async def generate_online_subqueries(
agent_chat_model = agent.chat_model if agent else None
class OnlineQueries(BaseModel):
queries: List[str]
queries: List[str] = Field(..., min_items=1, max_items=max_queries)
with timer("Chat actor: Generate online search subqueries", logger):
response = await send_message_to_model_wrapper(