From 6980014838c76d4ff6248c2c73df8582b1622f8c Mon Sep 17 00:00:00 2001 From: Debanjum Date: Wed, 19 Mar 2025 16:10:24 +0530 Subject: [PATCH] Support constraining Gemini model output to specified response schema If the response_schema argument is passed to send_message_to_model_wrapper, it is used to constrain the output of Gemini models --- src/khoj/processor/conversation/google/gemini_chat.py | 2 ++ src/khoj/processor/conversation/google/utils.py | 1 + src/khoj/routers/helpers.py | 4 ++++ 3 files changed, 7 insertions(+) diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index 77cff325..7f18b079 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -121,6 +121,7 @@ def gemini_send_message_to_model( api_key, model, response_type="text", + response_schema=None, temperature=0.6, model_kwargs=None, tracer={}, @@ -135,6 +136,7 @@ def gemini_send_message_to_model( # This caused unwanted behavior and terminates response early for gemini 1.5 series. Monitor for flakiness with 2.0 series. 
if response_type == "json_object" and model in ["gemini-2.0-flash"]: model_kwargs["response_mime_type"] = "application/json" + model_kwargs["response_schema"] = response_schema # Get Response from Gemini return gemini_completion_with_backoff( diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py index ebe91527..b1a5fe77 100644 --- a/src/khoj/processor/conversation/google/utils.py +++ b/src/khoj/processor/conversation/google/utils.py @@ -66,6 +66,7 @@ def gemini_completion_with_backoff( max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI, safety_settings=SAFETY_SETTINGS, response_mime_type=model_kwargs.get("response_mime_type", "text/plain") if model_kwargs else "text/plain", + response_schema=model_kwargs.get("response_schema", None) if model_kwargs else None, ) formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages] diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 4f3a0fcc..a339642a 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1129,6 +1129,7 @@ async def send_message_to_model_wrapper( query: str, system_message: str = "", response_type: str = "text", + response_schema: BaseModel = None, deepthought: bool = False, user: KhojUser = None, query_images: List[str] = None, @@ -1256,6 +1257,7 @@ async def send_message_to_model_wrapper( api_key=api_key, model=chat_model_name, response_type=response_type, + response_schema=response_schema, tracer=tracer, ) else: @@ -1266,6 +1268,7 @@ def send_message_to_model_wrapper_sync( message: str, system_message: str = "", response_type: str = "text", + response_schema: BaseModel = None, user: KhojUser = None, query_images: List[str] = None, query_files: str = "", @@ -1372,6 +1375,7 @@ def send_message_to_model_wrapper_sync( api_key=api_key, model=chat_model_name, response_type=response_type, + response_schema=response_schema, tracer=tracer, ) else: