Make Gemini response adhere to the order of the schema property definitions

Without explicitly using the property ordering field, gemini returns
responses in alphabetically sorted property order.

We want the model to respect the schema property definition order.
This ensures control during development to maintain response quality.

For example in CoT make it fill scratchpad before answers.
This commit is contained in:
Debanjum
2025-04-03 00:04:57 +05:30
parent ae9ca58ab9
commit aab010723c

View File

@@ -9,6 +9,7 @@ from google import genai
from google.genai import errors as gerrors
from google.genai import types as gtypes
from langchain.schema import ChatMessage
from pydantic import BaseModel
from tenacity import (
before_sleep_log,
retry,
@@ -86,6 +87,11 @@ def gemini_completion_with_backoff(
formatted_messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
# format model response schema
response_schema = None
if model_kwargs and "response_schema" in model_kwargs:
response_schema = clean_response_schema(model_kwargs["response_schema"])
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
config = gtypes.GenerateContentConfig(
system_instruction=system_prompt,
@@ -93,7 +99,7 @@ def gemini_completion_with_backoff(
max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
safety_settings=SAFETY_SETTINGS,
response_mime_type=model_kwargs.get("response_mime_type", "text/plain") if model_kwargs else "text/plain",
response_schema=model_kwargs.get("response_schema", None) if model_kwargs else None,
response_schema=response_schema,
seed=seed,
)
@@ -318,3 +324,18 @@ def format_messages_for_gemini(
formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages]
return formatted_messages, system_prompt
def clean_response_schema(response_schema: BaseModel) -> dict:
"""
Convert Pydantic model to dict for Gemini response schema.
Ensure response schema adheres to the order of the original property definition.
"""
# Convert Pydantic model to dict
response_schema_dict = response_schema.model_json_schema()
# Get field names in original definition order
field_names = list(response_schema.model_fields.keys())
# Generate content in the order in which the schema properties were defined
response_schema_dict["property_ordering"] = field_names
return response_schema_dict