mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 13:25:11 +00:00
Support constraining Gemini model output to a specified response schema
If the response_schema argument is passed to send_message_to_model_wrapper, it is used to constrain the output of Gemini models.
This commit is contained in:
@@ -121,6 +121,7 @@ def gemini_send_message_to_model(
|
|||||||
api_key,
|
api_key,
|
||||||
model,
|
model,
|
||||||
response_type="text",
|
response_type="text",
|
||||||
|
response_schema=None,
|
||||||
temperature=0.6,
|
temperature=0.6,
|
||||||
model_kwargs=None,
|
model_kwargs=None,
|
||||||
tracer={},
|
tracer={},
|
||||||
@@ -135,6 +136,7 @@ def gemini_send_message_to_model(
|
|||||||
# This caused unwanted behavior and terminates response early for gemini 1.5 series. Monitor for flakiness with 2.0 series.
|
# This caused unwanted behavior and terminates response early for gemini 1.5 series. Monitor for flakiness with 2.0 series.
|
||||||
if response_type == "json_object" and model in ["gemini-2.0-flash"]:
|
if response_type == "json_object" and model in ["gemini-2.0-flash"]:
|
||||||
model_kwargs["response_mime_type"] = "application/json"
|
model_kwargs["response_mime_type"] = "application/json"
|
||||||
|
model_kwargs["response_schema"] = response_schema
|
||||||
|
|
||||||
# Get Response from Gemini
|
# Get Response from Gemini
|
||||||
return gemini_completion_with_backoff(
|
return gemini_completion_with_backoff(
|
||||||
|
|||||||
@@ -66,6 +66,7 @@ def gemini_completion_with_backoff(
|
|||||||
max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
|
max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
|
||||||
safety_settings=SAFETY_SETTINGS,
|
safety_settings=SAFETY_SETTINGS,
|
||||||
response_mime_type=model_kwargs.get("response_mime_type", "text/plain") if model_kwargs else "text/plain",
|
response_mime_type=model_kwargs.get("response_mime_type", "text/plain") if model_kwargs else "text/plain",
|
||||||
|
response_schema=model_kwargs.get("response_schema", None) if model_kwargs else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages]
|
formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages]
|
||||||
|
|||||||
@@ -1129,6 +1129,7 @@ async def send_message_to_model_wrapper(
|
|||||||
query: str,
|
query: str,
|
||||||
system_message: str = "",
|
system_message: str = "",
|
||||||
response_type: str = "text",
|
response_type: str = "text",
|
||||||
|
response_schema: BaseModel = None,
|
||||||
deepthought: bool = False,
|
deepthought: bool = False,
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
query_images: List[str] = None,
|
query_images: List[str] = None,
|
||||||
@@ -1256,6 +1257,7 @@ async def send_message_to_model_wrapper(
|
|||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
model=chat_model_name,
|
model=chat_model_name,
|
||||||
response_type=response_type,
|
response_type=response_type,
|
||||||
|
response_schema=response_schema,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -1266,6 +1268,7 @@ def send_message_to_model_wrapper_sync(
|
|||||||
message: str,
|
message: str,
|
||||||
system_message: str = "",
|
system_message: str = "",
|
||||||
response_type: str = "text",
|
response_type: str = "text",
|
||||||
|
response_schema: BaseModel = None,
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
query_images: List[str] = None,
|
query_images: List[str] = None,
|
||||||
query_files: str = "",
|
query_files: str = "",
|
||||||
@@ -1372,6 +1375,7 @@ def send_message_to_model_wrapper_sync(
|
|||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
model=chat_model_name,
|
model=chat_model_name,
|
||||||
response_type=response_type,
|
response_type=response_type,
|
||||||
|
response_schema=response_schema,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user