From aab010723cf27895681ceca0fc95de1539a5ee95 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Thu, 3 Apr 2025 00:04:57 +0530 Subject: [PATCH] Make Gemini response adhere to the order of the schema property definitions Without explicitly using the property ordering field, gemini returns responses in alphabetically sorted property order. We want the model to respect the schema property definition order. This ensures control during development to maintain response quality. For example in CoT make it fill scratchpad before answers. --- .../processor/conversation/google/utils.py | 23 ++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py index 4d9b42d9..f66d7c68 100644 --- a/src/khoj/processor/conversation/google/utils.py +++ b/src/khoj/processor/conversation/google/utils.py @@ -9,6 +9,7 @@ from google import genai from google.genai import errors as gerrors from google.genai import types as gtypes from langchain.schema import ChatMessage +from pydantic import BaseModel from tenacity import ( before_sleep_log, retry, @@ -86,6 +87,11 @@ def gemini_completion_with_backoff( formatted_messages, system_prompt = format_messages_for_gemini(messages, system_prompt) + # format model response schema + response_schema = None + if model_kwargs and "response_schema" in model_kwargs: + response_schema = clean_response_schema(model_kwargs["response_schema"]) + seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None config = gtypes.GenerateContentConfig( system_instruction=system_prompt, @@ -93,7 +99,7 @@ def gemini_completion_with_backoff( max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI, safety_settings=SAFETY_SETTINGS, response_mime_type=model_kwargs.get("response_mime_type", "text/plain") if model_kwargs else "text/plain", - response_schema=model_kwargs.get("response_schema", None) if model_kwargs else None, + response_schema=response_schema, seed=seed, ) @@ -318,3 +324,18 @@ def format_messages_for_gemini( formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages] return formatted_messages, system_prompt + + +def clean_response_schema(response_schema: BaseModel) -> dict: + """ + Convert Pydantic model to dict for Gemini response schema. + + Ensure response schema adheres to the order of the original property definition. + """ + # Convert Pydantic model to dict + response_schema_dict = response_schema.model_json_schema() + # Get field names in original definition order + field_names = list(response_schema.model_fields.keys()) + # Generate content in the order in which the schema properties were defined + response_schema_dict["property_ordering"] = field_names + return response_schema_dict