From ae9ca58ab91eaa481cfde1ee04efc6826c9193eb Mon Sep 17 00:00:00 2001 From: Debanjum Date: Mon, 31 Mar 2025 00:46:11 +0530 Subject: [PATCH] Specify min, max items expected in ai response via schema enforcement Require at least 1 item in lists. Otherwise gemini flash will sometimes return an empty list. For chat actors where max items is known, set that as well. OpenAI API does not support specifying min, max items in response schema lists, so drop those properties when response schema is passed. Add other enforcements to response schema to comply with response schema format expected by OpenAI API. --- .../conversation/google/gemini_chat.py | 5 +-- src/khoj/processor/conversation/openai/gpt.py | 40 ++++++++++++++++++- src/khoj/processor/conversation/prompts.py | 2 +- src/khoj/routers/helpers.py | 12 +++--- 4 files changed, 49 insertions(+), 10 deletions(-) diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index eec24d92..75017d63 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -4,12 +4,11 @@ from typing import Dict, List, Optional import pyjson5 from langchain.schema import ChatMessage -from pydantic import BaseModel +from pydantic import BaseModel, Field from khoj.database.models import Agent, ChatModel, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.google.utils import ( - format_messages_for_gemini, gemini_chat_completion_with_backoff, gemini_completion_with_backoff, ) @@ -98,7 +97,7 @@ def extract_questions_gemini( messages.append(ChatMessage(content=system_prompt, role="system")) class DocumentQueries(BaseModel): - queries: List[str] + queries: List[str] = Field(..., min_items=1) response = gemini_send_message_to_model( messages, diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index cc8ec027..11e6a03d 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -4,6 +4,8 @@ from typing import Dict, List, Optional import pyjson5 from langchain.schema import ChatMessage +from openai.lib._pydantic import _ensure_strict_json_schema +from pydantic import BaseModel from khoj.database.models import Agent, ChatModel, KhojUser from khoj.processor.conversation import prompts @@ -135,7 +137,16 @@ def send_message_to_model( model_kwargs = {} json_support = get_openai_api_json_support(model, api_base_url) if response_schema and json_support == JsonSupport.SCHEMA: - model_kwargs["response_format"] = response_schema + # Drop unsupported fields from schema passed to OpenAI APi + cleaned_response_schema = clean_response_schema(response_schema) + model_kwargs["response_format"] = { + "type": "json_schema", + "json_schema": { + "schema": cleaned_response_schema, + "name": response_schema.__name__, + "strict": True, + }, + } elif response_type == "json_object" and json_support == JsonSupport.OBJECT: model_kwargs["response_format"] = {"type": response_type} @@ -257,3 +268,30 @@ def converse_openai( model_kwargs={"stop": ["Notes:\n["]}, tracer=tracer, ) + + +def clean_response_schema(schema: BaseModel | dict) -> dict: + """ + Format response schema to be compatible with OpenAI API. + + Clean the response schema by removing unsupported fields. + """ + # Normalize schema to OpenAI compatible JSON schema format + schema_json = schema if isinstance(schema, dict) else schema.model_json_schema() + schema_json = _ensure_strict_json_schema(schema_json, path=(), root=schema_json) + + # Recursively drop unsupported fields from schema passed to OpenAI API + # See https://platform.openai.com/docs/guides/structured-outputs#supported-schemas + fields_to_exclude = ["minItems", "maxItems"] + if isinstance(schema_json, dict) and isinstance(schema_json.get("properties"), dict): + for _, prop_value in schema_json["properties"].items(): + if isinstance(prop_value, dict): + # Remove specified fields from direct properties + for field in fields_to_exclude: + prop_value.pop(field, None) + # Recursively remove specified fields from child properties + if "items" in prop_value and isinstance(prop_value["items"], dict): + clean_response_schema(prop_value["items"]) + + # Return cleaned schema + return schema_json diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index e59ab304..ceb8093e 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -900,7 +900,7 @@ Khoj: online_search_conversation_subqueries = PromptTemplate.from_template( """ -You are Khoj, an advanced web search assistant. You are tasked with constructing **up to three** google search queries to answer the user's question. +You are Khoj, an advanced web search assistant. You are tasked with constructing **up to {max_queries}** google search queries to answer the user's question. - You will receive the actual chat history as context. - Add as much context from the chat history as required into your search queries. - Break messages into multiple search queries when required to retrieve the relevant information. diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 28808ecd..e1d69a1e 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -34,7 +34,7 @@ from apscheduler.job import Job from apscheduler.triggers.cron import CronTrigger from asgiref.sync import sync_to_async from fastapi import Depends, Header, HTTPException, Request, UploadFile -from pydantic import BaseModel +from pydantic import BaseModel, Field from starlette.authentication import has_required_scope from starlette.requests import URL @@ -400,7 +400,7 @@ async def aget_data_sources_and_output_format( agent_chat_model = agent.chat_model if agent else None class PickTools(BaseModel): - source: List[str] + source: List[str] = Field(..., min_items=1) output: str with timer("Chat actor: Infer information sources to refer", logger): @@ -489,7 +489,7 @@ async def infer_webpage_urls( agent_chat_model = agent.chat_model if agent else None class WebpageUrls(BaseModel): - links: List[str] + links: List[str] = Field(..., min_items=1, max_items=max_webpages) with timer("Chat actor: Infer webpage urls to read", logger): response = await send_message_to_model_wrapper( @@ -535,15 +535,17 @@ async def generate_online_subqueries( username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else "" chat_history = construct_chat_history(conversation_history) + max_queries = 3 utc_date = datetime.utcnow().strftime("%Y-%m-%d") personality_context = ( prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else "" ) online_queries_prompt = prompts.online_search_conversation_subqueries.format( - current_date=utc_date, query=q, chat_history=chat_history, + max_queries=max_queries, + current_date=utc_date, location=location, username=username, personality_context=personality_context, @@ -552,7 +554,7 @@ async def generate_online_subqueries( agent_chat_model = agent.chat_model if agent else None class OnlineQueries(BaseModel): - queries: List[str] + queries: List[str] = Field(..., min_items=1, max_items=max_queries) with timer("Chat actor: Generate online search subqueries", logger): response = await send_message_to_model_wrapper(