Support constraining Gemini model output to specified response schema

If the response_schema argument is passed to
send_message_to_model_wrapper, it is forwarded to Gemini models and
used to constrain their output to the specified schema.
This commit is contained in:
Debanjum
2025-03-19 16:10:24 +05:30
parent ac4b36b9fd
commit 6980014838
3 changed files with 7 additions and 0 deletions

View File

@@ -121,6 +121,7 @@ def gemini_send_message_to_model(
api_key,
model,
response_type="text",
response_schema=None,
temperature=0.6,
model_kwargs=None,
tracer={},
@@ -135,6 +136,7 @@ def gemini_send_message_to_model(
# This caused unwanted behavior and terminates response early for gemini 1.5 series. Monitor for flakiness with 2.0 series.
if response_type == "json_object" and model in ["gemini-2.0-flash"]:
model_kwargs["response_mime_type"] = "application/json"
model_kwargs["response_schema"] = response_schema
# Get Response from Gemini
return gemini_completion_with_backoff(

View File

@@ -66,6 +66,7 @@ def gemini_completion_with_backoff(
max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
safety_settings=SAFETY_SETTINGS,
response_mime_type=model_kwargs.get("response_mime_type", "text/plain") if model_kwargs else "text/plain",
response_schema=model_kwargs.get("response_schema", None) if model_kwargs else None,
)
formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages]

View File

@@ -1129,6 +1129,7 @@ async def send_message_to_model_wrapper(
query: str,
system_message: str = "",
response_type: str = "text",
response_schema: BaseModel = None,
deepthought: bool = False,
user: KhojUser = None,
query_images: List[str] = None,
@@ -1256,6 +1257,7 @@ async def send_message_to_model_wrapper(
api_key=api_key,
model=chat_model_name,
response_type=response_type,
response_schema=response_schema,
tracer=tracer,
)
else:
@@ -1266,6 +1268,7 @@ def send_message_to_model_wrapper_sync(
message: str,
system_message: str = "",
response_type: str = "text",
response_schema: BaseModel = None,
user: KhojUser = None,
query_images: List[str] = None,
query_files: str = "",
@@ -1372,6 +1375,7 @@ def send_message_to_model_wrapper_sync(
api_key=api_key,
model=chat_model_name,
response_type=response_type,
response_schema=response_schema,
tracer=tracer,
)
else: