mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Support constraining Gemini model output to specified response schema
If the response_schema argument is passed to send_message_to_model_wrapper it is used to constrain output by Gemini models
This commit is contained in:
@@ -121,6 +121,7 @@ def gemini_send_message_to_model(
|
||||
api_key,
|
||||
model,
|
||||
response_type="text",
|
||||
response_schema=None,
|
||||
temperature=0.6,
|
||||
model_kwargs=None,
|
||||
tracer={},
|
||||
@@ -135,6 +136,7 @@ def gemini_send_message_to_model(
|
||||
# This caused unwanted behavior and terminates response early for gemini 1.5 series. Monitor for flakiness with 2.0 series.
|
||||
if response_type == "json_object" and model in ["gemini-2.0-flash"]:
|
||||
model_kwargs["response_mime_type"] = "application/json"
|
||||
model_kwargs["response_schema"] = response_schema
|
||||
|
||||
# Get Response from Gemini
|
||||
return gemini_completion_with_backoff(
|
||||
|
||||
@@ -66,6 +66,7 @@ def gemini_completion_with_backoff(
|
||||
max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
|
||||
safety_settings=SAFETY_SETTINGS,
|
||||
response_mime_type=model_kwargs.get("response_mime_type", "text/plain") if model_kwargs else "text/plain",
|
||||
response_schema=model_kwargs.get("response_schema", None) if model_kwargs else None,
|
||||
)
|
||||
|
||||
formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages]
|
||||
|
||||
@@ -1129,6 +1129,7 @@ async def send_message_to_model_wrapper(
|
||||
query: str,
|
||||
system_message: str = "",
|
||||
response_type: str = "text",
|
||||
response_schema: BaseModel = None,
|
||||
deepthought: bool = False,
|
||||
user: KhojUser = None,
|
||||
query_images: List[str] = None,
|
||||
@@ -1256,6 +1257,7 @@ async def send_message_to_model_wrapper(
|
||||
api_key=api_key,
|
||||
model=chat_model_name,
|
||||
response_type=response_type,
|
||||
response_schema=response_schema,
|
||||
tracer=tracer,
|
||||
)
|
||||
else:
|
||||
@@ -1266,6 +1268,7 @@ def send_message_to_model_wrapper_sync(
|
||||
message: str,
|
||||
system_message: str = "",
|
||||
response_type: str = "text",
|
||||
response_schema: BaseModel = None,
|
||||
user: KhojUser = None,
|
||||
query_images: List[str] = None,
|
||||
query_files: str = "",
|
||||
@@ -1372,6 +1375,7 @@ def send_message_to_model_wrapper_sync(
|
||||
api_key=api_key,
|
||||
model=chat_model_name,
|
||||
response_type=response_type,
|
||||
response_schema=response_schema,
|
||||
tracer=tracer,
|
||||
)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user