Improve Gemini Response Reliability (#1148)

- Specify min, max number of list items expected in AI response via JSON schema enforcement. Used by Gemini models
- Warn and drop invalid/empty messages when formatting messages for Gemini models
- Make Gemini response adhere to the order of the schema property definitions
- Improve agent creation safety checker by using response schema, better prompt
This commit is contained in:
Debanjum
2025-04-03 13:42:35 +05:30
committed by GitHub
5 changed files with 92 additions and 14 deletions

View File

@@ -4,12 +4,11 @@ from typing import Dict, List, Optional
import pyjson5
from langchain.schema import ChatMessage
from pydantic import BaseModel
from pydantic import BaseModel, Field
from khoj.database.models import Agent, ChatModel, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.google.utils import (
format_messages_for_gemini,
gemini_chat_completion_with_backoff,
gemini_completion_with_backoff,
)
@@ -98,7 +97,7 @@ def extract_questions_gemini(
messages.append(ChatMessage(content=system_prompt, role="system"))
class DocumentQueries(BaseModel):
queries: List[str]
queries: List[str] = Field(..., min_items=1)
response = gemini_send_message_to_model(
messages,

View File

@@ -9,6 +9,7 @@ from google import genai
from google.genai import errors as gerrors
from google.genai import types as gtypes
from langchain.schema import ChatMessage
from pydantic import BaseModel
from tenacity import (
before_sleep_log,
retry,
@@ -86,6 +87,11 @@ def gemini_completion_with_backoff(
formatted_messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
# format model response schema
response_schema = None
if model_kwargs and "response_schema" in model_kwargs:
response_schema = clean_response_schema(model_kwargs["response_schema"])
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
config = gtypes.GenerateContentConfig(
system_instruction=system_prompt,
@@ -93,7 +99,7 @@ def gemini_completion_with_backoff(
max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
safety_settings=SAFETY_SETTINGS,
response_mime_type=model_kwargs.get("response_mime_type", "text/plain") if model_kwargs else "text/plain",
response_schema=model_kwargs.get("response_schema", None) if model_kwargs else None,
response_schema=response_schema,
seed=seed,
)
@@ -294,11 +300,21 @@ def format_messages_for_gemini(
else:
image = get_image_from_base64(image_data, type="bytes")
message_content += [gtypes.Part.from_bytes(data=image.content, mime_type=image.type)]
elif not is_none_or_empty(item.get("text")):
message_content += [gtypes.Part.from_text(text=item["text"])]
else:
message_content += [gtypes.Part.from_text(text=item.get("text", ""))]
logger.error(f"Dropping invalid message content part: {item}")
if not message_content:
logger.error(f"Dropping empty message content")
messages.remove(message)
continue
message.content = message_content
elif isinstance(message.content, str):
message.content = [gtypes.Part.from_text(text=message.content)]
else:
logger.error(f"Dropping invalid type: {type(message.content)} of message content: {message.content}")
messages.remove(message)
continue
if message.role == "assistant":
message.role = "model"
@@ -308,3 +324,18 @@ def format_messages_for_gemini(
formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages]
return formatted_messages, system_prompt
def clean_response_schema(response_schema: "type[BaseModel]") -> dict:
    """
    Convert a Pydantic model class into a dict usable as a Gemini response schema.

    Ensure the response adheres to the order of the original property definitions:
    Gemini does not guarantee output ordering from a plain JSON schema, so the
    declaration order of the model's fields is attached via the
    `property_ordering` key.

    Args:
        response_schema: The Pydantic model class (not an instance) describing
            the expected response structure.

    Returns:
        The JSON schema dict for the model, augmented with `property_ordering`.
    """
    # Serialize the model class to a standard JSON schema dict
    response_schema_dict = response_schema.model_json_schema()
    # model_fields preserves the declaration order of the model's fields
    field_names = list(response_schema.model_fields.keys())
    # Tell Gemini to emit properties in their original definition order
    response_schema_dict["property_ordering"] = field_names
    return response_schema_dict

View File

@@ -4,6 +4,8 @@ from typing import Dict, List, Optional
import pyjson5
from langchain.schema import ChatMessage
from openai.lib._pydantic import _ensure_strict_json_schema
from pydantic import BaseModel
from khoj.database.models import Agent, ChatModel, KhojUser
from khoj.processor.conversation import prompts
@@ -135,7 +137,16 @@ def send_message_to_model(
model_kwargs = {}
json_support = get_openai_api_json_support(model, api_base_url)
if response_schema and json_support == JsonSupport.SCHEMA:
model_kwargs["response_format"] = response_schema
# Drop unsupported fields from schema passed to OpenAI API
cleaned_response_schema = clean_response_schema(response_schema)
model_kwargs["response_format"] = {
"type": "json_schema",
"json_schema": {
"schema": cleaned_response_schema,
"name": response_schema.__name__,
"strict": True,
},
}
elif response_type == "json_object" and json_support == JsonSupport.OBJECT:
model_kwargs["response_format"] = {"type": response_type}
@@ -257,3 +268,30 @@ def converse_openai(
model_kwargs={"stop": ["Notes:\n["]},
tracer=tracer,
)
def clean_response_schema(schema: "BaseModel | dict") -> dict:
    """
    Format a response schema to be compatible with the OpenAI API.

    Normalize the schema to OpenAI's strict JSON schema form, then recursively
    remove fields the structured outputs API rejects.
    See https://platform.openai.com/docs/guides/structured-outputs#supported-schemas

    Args:
        schema: A Pydantic model class or an already-serialized JSON schema dict.

    Returns:
        The cleaned JSON schema dict, safe to pass as an OpenAI response schema.
    """
    # Normalize schema to OpenAI compatible strict JSON schema format
    schema_json = schema if isinstance(schema, dict) else schema.model_json_schema()
    schema_json = _ensure_strict_json_schema(schema_json, path=(), root=schema_json)
    # Strip unsupported fields from the whole schema tree in one pass, rather
    # than re-running strict normalization on every nested sub-schema
    _strip_unsupported_schema_fields(schema_json)
    return schema_json


def _strip_unsupported_schema_fields(node: dict) -> None:
    """Recursively remove JSON schema fields unsupported by the OpenAI API, in place."""
    if not isinstance(node, dict):
        return
    # Fields OpenAI's structured outputs API does not accept
    for field in ("minItems", "maxItems"):
        node.pop(field, None)
    # Recurse into every location a sub-schema can appear
    properties = node.get("properties")
    if isinstance(properties, dict):
        for child in properties.values():
            _strip_unsupported_schema_fields(child)
    if isinstance(node.get("items"), dict):
        _strip_unsupported_schema_fields(node["items"])
    defs = node.get("$defs")
    if isinstance(defs, dict):
        for child in defs.values():
            _strip_unsupported_schema_fields(child)
    for child in node.get("anyOf", []):
        _strip_unsupported_schema_fields(child)

View File

@@ -900,7 +900,7 @@ Khoj:
online_search_conversation_subqueries = PromptTemplate.from_template(
"""
You are Khoj, an advanced web search assistant. You are tasked with constructing **up to three** google search queries to answer the user's question.
You are Khoj, an advanced web search assistant. You are tasked with constructing **up to {max_queries}** google search queries to answer the user's question.
- You will receive the actual chat history as context.
- Add as much context from the chat history as required into your search queries.
- Break messages into multiple search queries when required to retrieve the relevant information.
@@ -1252,6 +1252,7 @@ A: {{ "safe": "False", "reason": "The prompt contains sexual content that could
Q: You are an astute financial analyst. Assess my financial situation and provide advice.
A: {{ "safe": "True" }}
# Actual:
Q: {prompt}
A:
""".strip()
@@ -1287,6 +1288,7 @@ A: {{ "safe": "False", "reason": "The prompt contains content that could be cons
Q: You are a great analyst. Assess my financial situation and provide advice.
A: {{ "safe": "True" }}
# Actual:
Q: {prompt}
A:
""".strip()

View File

@@ -34,7 +34,7 @@ from apscheduler.job import Job
from apscheduler.triggers.cron import CronTrigger
from asgiref.sync import sync_to_async
from fastapi import Depends, Header, HTTPException, Request, UploadFile
from pydantic import BaseModel
from pydantic import BaseModel, Field
from starlette.authentication import has_required_scope
from starlette.requests import URL
@@ -321,13 +321,19 @@ async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None, lax:
is_safe = True
reason = ""
class SafetyCheck(BaseModel):
safe: bool
reason: str
with timer("Chat actor: Check if safe prompt", logger):
response = await send_message_to_model_wrapper(safe_prompt_check, user=user)
response = await send_message_to_model_wrapper(
safe_prompt_check, user=user, response_type="json_object", response_schema=SafetyCheck
)
response = response.strip()
try:
response = json.loads(clean_json(response))
is_safe = response.get("safe", "True") == "True"
is_safe = str(response.get("safe", "true")).lower() == "true"
if not is_safe:
reason = response.get("reason", "")
except Exception:
@@ -400,7 +406,7 @@ async def aget_data_sources_and_output_format(
agent_chat_model = agent.chat_model if agent else None
class PickTools(BaseModel):
source: List[str]
source: List[str] = Field(..., min_items=1)
output: str
with timer("Chat actor: Infer information sources to refer", logger):
@@ -489,7 +495,7 @@ async def infer_webpage_urls(
agent_chat_model = agent.chat_model if agent else None
class WebpageUrls(BaseModel):
links: List[str]
links: List[str] = Field(..., min_items=1, max_items=max_webpages)
with timer("Chat actor: Infer webpage urls to read", logger):
response = await send_message_to_model_wrapper(
@@ -535,15 +541,17 @@ async def generate_online_subqueries(
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
chat_history = construct_chat_history(conversation_history)
max_queries = 3
utc_date = datetime.utcnow().strftime("%Y-%m-%d")
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)
online_queries_prompt = prompts.online_search_conversation_subqueries.format(
current_date=utc_date,
query=q,
chat_history=chat_history,
max_queries=max_queries,
current_date=utc_date,
location=location,
username=username,
personality_context=personality_context,
@@ -552,7 +560,7 @@ async def generate_online_subqueries(
agent_chat_model = agent.chat_model if agent else None
class OnlineQueries(BaseModel):
queries: List[str]
queries: List[str] = Field(..., min_items=1, max_items=max_queries)
with timer("Chat actor: Generate online search subqueries", logger):
response = await send_message_to_model_wrapper(