def clean_response_schema(response_schema: "type[BaseModel] | dict | None") -> "dict | None":
    """
    Convert a Pydantic model to a dict usable as a Gemini response schema.

    Adds a "property_ordering" entry so the response follows the order in
    which the schema properties were originally defined.

    Args:
        response_schema: A Pydantic model class exposing `model_json_schema()`,
            an already-built schema dict, or None.

    Returns:
        A new schema dict with "property_ordering" set, or None when no schema
        was supplied. A dict input is shallow-copied, never mutated.
    """
    # Callers may pass an explicit None for "no schema"; hand it back unchanged
    # rather than failing on the attribute lookup below.
    if response_schema is None:
        return None

    # An already-built schema dict has no Pydantic helpers. Copy it and only
    # fill in the ordering when the caller has not chosen one.
    if isinstance(response_schema, dict):
        schema_dict = dict(response_schema)
        declared_properties = schema_dict.get("properties")
        if declared_properties:
            schema_dict.setdefault("property_ordering", list(declared_properties))
        return schema_dict

    # Convert Pydantic model to dict
    schema_dict = response_schema.model_json_schema()
    # Take the ordering from the generated schema, not from `model_fields`:
    # the schema keys properties by their serialized name (alias), so the
    # Python attribute names may not match any property in the schema.
    # Pydantic emits "properties" in field definition order.
    schema_dict["property_ordering"] = list(schema_dict.get("properties") or {})
    return schema_dict