mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Specify min, max items expected in AI response via schema enforcement
Require at least 1 item in lists. Otherwise Gemini Flash will sometimes return an empty list. For chat actors where max items is known, set that as well. OpenAI API does not support specifying min, max items in response schema lists, so drop those properties when response schema is passed. Add other enforcements to response schema to comply with response schema format expected by OpenAI API.
This commit is contained in:
@@ -4,12 +4,11 @@ from typing import Dict, List, Optional
|
||||
|
||||
import pyjson5
|
||||
from langchain.schema import ChatMessage
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from khoj.database.models import Agent, ChatModel, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.google.utils import (
|
||||
format_messages_for_gemini,
|
||||
gemini_chat_completion_with_backoff,
|
||||
gemini_completion_with_backoff,
|
||||
)
|
||||
@@ -98,7 +97,7 @@ def extract_questions_gemini(
|
||||
messages.append(ChatMessage(content=system_prompt, role="system"))
|
||||
|
||||
class DocumentQueries(BaseModel):
|
||||
queries: List[str]
|
||||
queries: List[str] = Field(..., min_items=1)
|
||||
|
||||
response = gemini_send_message_to_model(
|
||||
messages,
|
||||
|
||||
@@ -4,6 +4,8 @@ from typing import Dict, List, Optional
|
||||
|
||||
import pyjson5
|
||||
from langchain.schema import ChatMessage
|
||||
from openai.lib._pydantic import _ensure_strict_json_schema
|
||||
from pydantic import BaseModel
|
||||
|
||||
from khoj.database.models import Agent, ChatModel, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
@@ -135,7 +137,16 @@ def send_message_to_model(
|
||||
model_kwargs = {}
|
||||
json_support = get_openai_api_json_support(model, api_base_url)
|
||||
if response_schema and json_support == JsonSupport.SCHEMA:
|
||||
model_kwargs["response_format"] = response_schema
|
||||
# Drop unsupported fields from schema passed to OpenAI API
|
||||
cleaned_response_schema = clean_response_schema(response_schema)
|
||||
model_kwargs["response_format"] = {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"schema": cleaned_response_schema,
|
||||
"name": response_schema.__name__,
|
||||
"strict": True,
|
||||
},
|
||||
}
|
||||
elif response_type == "json_object" and json_support == JsonSupport.OBJECT:
|
||||
model_kwargs["response_format"] = {"type": response_type}
|
||||
|
||||
@@ -257,3 +268,30 @@ def converse_openai(
|
||||
model_kwargs={"stop": ["Notes:\n["]},
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
|
||||
def clean_response_schema(schema: BaseModel | dict) -> dict:
|
||||
"""
|
||||
Format response schema to be compatible with OpenAI API.
|
||||
|
||||
Clean the response schema by removing unsupported fields.
|
||||
"""
|
||||
# Normalize schema to OpenAI compatible JSON schema format
|
||||
schema_json = schema if isinstance(schema, dict) else schema.model_json_schema()
|
||||
schema_json = _ensure_strict_json_schema(schema_json, path=(), root=schema_json)
|
||||
|
||||
# Recursively drop unsupported fields from schema passed to OpenAI API
|
||||
# See https://platform.openai.com/docs/guides/structured-outputs#supported-schemas
|
||||
fields_to_exclude = ["minItems", "maxItems"]
|
||||
if isinstance(schema_json, dict) and isinstance(schema_json.get("properties"), dict):
|
||||
for _, prop_value in schema_json["properties"].items():
|
||||
if isinstance(prop_value, dict):
|
||||
# Remove specified fields from direct properties
|
||||
for field in fields_to_exclude:
|
||||
prop_value.pop(field, None)
|
||||
# Recursively remove specified fields from child properties
|
||||
if "items" in prop_value and isinstance(prop_value["items"], dict):
|
||||
clean_response_schema(prop_value["items"])
|
||||
|
||||
# Return cleaned schema
|
||||
return schema_json
|
||||
|
||||
@@ -900,7 +900,7 @@ Khoj:
|
||||
|
||||
online_search_conversation_subqueries = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, an advanced web search assistant. You are tasked with constructing **up to three** google search queries to answer the user's question.
|
||||
You are Khoj, an advanced web search assistant. You are tasked with constructing **up to {max_queries}** google search queries to answer the user's question.
|
||||
- You will receive the actual chat history as context.
|
||||
- Add as much context from the chat history as required into your search queries.
|
||||
- Break messages into multiple search queries when required to retrieve the relevant information.
|
||||
|
||||
@@ -34,7 +34,7 @@ from apscheduler.job import Job
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from asgiref.sync import sync_to_async
|
||||
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
from starlette.authentication import has_required_scope
|
||||
from starlette.requests import URL
|
||||
|
||||
@@ -400,7 +400,7 @@ async def aget_data_sources_and_output_format(
|
||||
agent_chat_model = agent.chat_model if agent else None
|
||||
|
||||
class PickTools(BaseModel):
|
||||
source: List[str]
|
||||
source: List[str] = Field(..., min_items=1)
|
||||
output: str
|
||||
|
||||
with timer("Chat actor: Infer information sources to refer", logger):
|
||||
@@ -489,7 +489,7 @@ async def infer_webpage_urls(
|
||||
agent_chat_model = agent.chat_model if agent else None
|
||||
|
||||
class WebpageUrls(BaseModel):
|
||||
links: List[str]
|
||||
links: List[str] = Field(..., min_items=1, max_items=max_webpages)
|
||||
|
||||
with timer("Chat actor: Infer webpage urls to read", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
@@ -535,15 +535,17 @@ async def generate_online_subqueries(
|
||||
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
|
||||
chat_history = construct_chat_history(conversation_history)
|
||||
|
||||
max_queries = 3
|
||||
utc_date = datetime.utcnow().strftime("%Y-%m-%d")
|
||||
personality_context = (
|
||||
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
|
||||
)
|
||||
|
||||
online_queries_prompt = prompts.online_search_conversation_subqueries.format(
|
||||
current_date=utc_date,
|
||||
query=q,
|
||||
chat_history=chat_history,
|
||||
max_queries=max_queries,
|
||||
current_date=utc_date,
|
||||
location=location,
|
||||
username=username,
|
||||
personality_context=personality_context,
|
||||
@@ -552,7 +554,7 @@ async def generate_online_subqueries(
|
||||
agent_chat_model = agent.chat_model if agent else None
|
||||
|
||||
class OnlineQueries(BaseModel):
|
||||
queries: List[str]
|
||||
queries: List[str] = Field(..., min_items=1, max_items=max_queries)
|
||||
|
||||
with timer("Chat actor: Generate online search subqueries", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
|
||||
Reference in New Issue
Block a user