mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Improve Gemini Response Reliability (#1148)
- Specify min, max number of list items expected in AI response via JSON schema enforcement. Used by Gemini models
- Warn and drop invalid/empty messages when formatting messages for Gemini models
- Make Gemini response adhere to the order of the schema property definitions
- Improve agent creation safety checker by using response schema, better prompt
This commit is contained in:
@@ -4,12 +4,11 @@ from typing import Dict, List, Optional
|
||||
|
||||
import pyjson5
|
||||
from langchain.schema import ChatMessage
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from khoj.database.models import Agent, ChatModel, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.google.utils import (
|
||||
format_messages_for_gemini,
|
||||
gemini_chat_completion_with_backoff,
|
||||
gemini_completion_with_backoff,
|
||||
)
|
||||
@@ -98,7 +97,7 @@ def extract_questions_gemini(
|
||||
messages.append(ChatMessage(content=system_prompt, role="system"))
|
||||
|
||||
class DocumentQueries(BaseModel):
|
||||
queries: List[str]
|
||||
queries: List[str] = Field(..., min_items=1)
|
||||
|
||||
response = gemini_send_message_to_model(
|
||||
messages,
|
||||
|
||||
@@ -9,6 +9,7 @@ from google import genai
|
||||
from google.genai import errors as gerrors
|
||||
from google.genai import types as gtypes
|
||||
from langchain.schema import ChatMessage
|
||||
from pydantic import BaseModel
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
@@ -86,6 +87,11 @@ def gemini_completion_with_backoff(
|
||||
|
||||
formatted_messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
|
||||
|
||||
# format model response schema
|
||||
response_schema = None
|
||||
if model_kwargs and "response_schema" in model_kwargs:
|
||||
response_schema = clean_response_schema(model_kwargs["response_schema"])
|
||||
|
||||
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
|
||||
config = gtypes.GenerateContentConfig(
|
||||
system_instruction=system_prompt,
|
||||
@@ -93,7 +99,7 @@ def gemini_completion_with_backoff(
|
||||
max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
|
||||
safety_settings=SAFETY_SETTINGS,
|
||||
response_mime_type=model_kwargs.get("response_mime_type", "text/plain") if model_kwargs else "text/plain",
|
||||
response_schema=model_kwargs.get("response_schema", None) if model_kwargs else None,
|
||||
response_schema=response_schema,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
@@ -294,11 +300,21 @@ def format_messages_for_gemini(
|
||||
else:
|
||||
image = get_image_from_base64(image_data, type="bytes")
|
||||
message_content += [gtypes.Part.from_bytes(data=image.content, mime_type=image.type)]
|
||||
elif not is_none_or_empty(item.get("text")):
|
||||
message_content += [gtypes.Part.from_text(text=item["text"])]
|
||||
else:
|
||||
message_content += [gtypes.Part.from_text(text=item.get("text", ""))]
|
||||
logger.error(f"Dropping invalid message content part: {item}")
|
||||
if not message_content:
|
||||
logger.error(f"Dropping empty message content")
|
||||
messages.remove(message)
|
||||
continue
|
||||
message.content = message_content
|
||||
elif isinstance(message.content, str):
|
||||
message.content = [gtypes.Part.from_text(text=message.content)]
|
||||
else:
|
||||
logger.error(f"Dropping invalid type: {type(message.content)} of message content: {message.content}")
|
||||
messages.remove(message)
|
||||
continue
|
||||
|
||||
if message.role == "assistant":
|
||||
message.role = "model"
|
||||
@@ -308,3 +324,18 @@ def format_messages_for_gemini(
|
||||
|
||||
formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages]
|
||||
return formatted_messages, system_prompt
|
||||
|
||||
|
||||
def clean_response_schema(response_schema: BaseModel) -> dict:
    """
    Build the Gemini response schema dict from a Pydantic model.

    The JSON schema generated by the model is tagged with a `property_ordering`
    entry listing the model fields in the order they were defined, so the
    response schema adheres to the original property definition order.
    """
    # Field names, in the order they appear in the model definition
    ordered_field_names = [field_name for field_name in response_schema.model_fields]

    # JSON schema dict generated from the Pydantic model
    schema = response_schema.model_json_schema()

    # Ask for content in the order in which the schema properties were defined
    schema["property_ordering"] = ordered_field_names
    return schema
|
||||
|
||||
@@ -4,6 +4,8 @@ from typing import Dict, List, Optional
|
||||
|
||||
import pyjson5
|
||||
from langchain.schema import ChatMessage
|
||||
from openai.lib._pydantic import _ensure_strict_json_schema
|
||||
from pydantic import BaseModel
|
||||
|
||||
from khoj.database.models import Agent, ChatModel, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
@@ -135,7 +137,16 @@ def send_message_to_model(
|
||||
model_kwargs = {}
|
||||
json_support = get_openai_api_json_support(model, api_base_url)
|
||||
if response_schema and json_support == JsonSupport.SCHEMA:
|
||||
model_kwargs["response_format"] = response_schema
|
||||
# Drop unsupported fields from schema passed to OpenAI API
|
||||
cleaned_response_schema = clean_response_schema(response_schema)
|
||||
model_kwargs["response_format"] = {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"schema": cleaned_response_schema,
|
||||
"name": response_schema.__name__,
|
||||
"strict": True,
|
||||
},
|
||||
}
|
||||
elif response_type == "json_object" and json_support == JsonSupport.OBJECT:
|
||||
model_kwargs["response_format"] = {"type": response_type}
|
||||
|
||||
@@ -257,3 +268,30 @@ def converse_openai(
|
||||
model_kwargs={"stop": ["Notes:\n["]},
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
|
||||
def clean_response_schema(schema: BaseModel | dict) -> dict:
    """
    Format response schema to be compatible with OpenAI API.

    Normalize the schema with `_ensure_strict_json_schema`, then remove the
    unsupported fields from the schema and from every sub-schema nested in it.

    Args:
        schema: Pydantic model exposing `model_json_schema()`, or an already
            generated JSON schema dict. A dict is passed to the normalizer and
            cleaned as is, not copied first.

    Returns:
        The cleaned JSON schema dict.
    """
    # Normalize schema to OpenAI compatible JSON schema format.
    # Done once, on the whole schema, rather than again on each nested sub-schema
    schema_json = schema if isinstance(schema, dict) else schema.model_json_schema()
    schema_json = _ensure_strict_json_schema(schema_json, path=(), root=schema_json)

    # Recursively drop unsupported fields from schema passed to OpenAI API
    # See https://platform.openai.com/docs/guides/structured-outputs#supported-schemas
    fields_to_exclude = ("minItems", "maxItems")
    _drop_schema_fields(schema_json, fields_to_exclude)

    # Return cleaned schema
    return schema_json


def _drop_schema_fields(node, fields_to_exclude) -> None:
    """Remove the given keywords, in place, from a JSON schema node and all its sub-schemas."""
    if not isinstance(node, dict):
        return

    for field in fields_to_exclude:
        node.pop(field, None)

    # Keywords that map names to sub-schemas. Their keys are user chosen names, not
    # schema keywords, so only the values are cleaned. This keeps a property that
    # happens to be named "minItems" or "maxItems" intact.
    for keyword in ("properties", "$defs", "definitions"):
        named_schemas = node.get(keyword)
        if isinstance(named_schemas, dict):
            for sub_schema in named_schemas.values():
                _drop_schema_fields(sub_schema, fields_to_exclude)

    # Keywords that hold a single sub-schema
    for keyword in ("items", "additionalProperties"):
        _drop_schema_fields(node.get(keyword), fields_to_exclude)

    # Keywords that hold a list of sub-schemas
    for keyword in ("anyOf", "allOf", "oneOf", "prefixItems"):
        sub_schemas = node.get(keyword)
        if isinstance(sub_schemas, list):
            for sub_schema in sub_schemas:
                _drop_schema_fields(sub_schema, fields_to_exclude)
|
||||
|
||||
@@ -900,7 +900,7 @@ Khoj:
|
||||
|
||||
online_search_conversation_subqueries = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, an advanced web search assistant. You are tasked with constructing **up to three** google search queries to answer the user's question.
|
||||
You are Khoj, an advanced web search assistant. You are tasked with constructing **up to {max_queries}** google search queries to answer the user's question.
|
||||
- You will receive the actual chat history as context.
|
||||
- Add as much context from the chat history as required into your search queries.
|
||||
- Break messages into multiple search queries when required to retrieve the relevant information.
|
||||
@@ -1252,6 +1252,7 @@ A: {{ "safe": "False", "reason": "The prompt contains sexual content that could
|
||||
Q: You are an astute financial analyst. Assess my financial situation and provide advice.
|
||||
A: {{ "safe": "True" }}
|
||||
|
||||
# Actual:
|
||||
Q: {prompt}
|
||||
A:
|
||||
""".strip()
|
||||
@@ -1287,6 +1288,7 @@ A: {{ "safe": "False", "reason": "The prompt contains content that could be cons
|
||||
Q: You are a great analyst. Assess my financial situation and provide advice.
|
||||
A: {{ "safe": "True" }}
|
||||
|
||||
# Actual:
|
||||
Q: {prompt}
|
||||
A:
|
||||
""".strip()
|
||||
|
||||
@@ -34,7 +34,7 @@ from apscheduler.job import Job
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from asgiref.sync import sync_to_async
|
||||
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
from starlette.authentication import has_required_scope
|
||||
from starlette.requests import URL
|
||||
|
||||
@@ -321,13 +321,19 @@ async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None, lax:
|
||||
is_safe = True
|
||||
reason = ""
|
||||
|
||||
class SafetyCheck(BaseModel):
|
||||
safe: bool
|
||||
reason: str
|
||||
|
||||
with timer("Chat actor: Check if safe prompt", logger):
|
||||
response = await send_message_to_model_wrapper(safe_prompt_check, user=user)
|
||||
response = await send_message_to_model_wrapper(
|
||||
safe_prompt_check, user=user, response_type="json_object", response_schema=SafetyCheck
|
||||
)
|
||||
|
||||
response = response.strip()
|
||||
try:
|
||||
response = json.loads(clean_json(response))
|
||||
is_safe = response.get("safe", "True") == "True"
|
||||
is_safe = str(response.get("safe", "true")).lower() == "true"
|
||||
if not is_safe:
|
||||
reason = response.get("reason", "")
|
||||
except Exception:
|
||||
@@ -400,7 +406,7 @@ async def aget_data_sources_and_output_format(
|
||||
agent_chat_model = agent.chat_model if agent else None
|
||||
|
||||
class PickTools(BaseModel):
|
||||
source: List[str]
|
||||
source: List[str] = Field(..., min_items=1)
|
||||
output: str
|
||||
|
||||
with timer("Chat actor: Infer information sources to refer", logger):
|
||||
@@ -489,7 +495,7 @@ async def infer_webpage_urls(
|
||||
agent_chat_model = agent.chat_model if agent else None
|
||||
|
||||
class WebpageUrls(BaseModel):
|
||||
links: List[str]
|
||||
links: List[str] = Field(..., min_items=1, max_items=max_webpages)
|
||||
|
||||
with timer("Chat actor: Infer webpage urls to read", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
@@ -535,15 +541,17 @@ async def generate_online_subqueries(
|
||||
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
|
||||
chat_history = construct_chat_history(conversation_history)
|
||||
|
||||
max_queries = 3
|
||||
utc_date = datetime.utcnow().strftime("%Y-%m-%d")
|
||||
personality_context = (
|
||||
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
|
||||
)
|
||||
|
||||
online_queries_prompt = prompts.online_search_conversation_subqueries.format(
|
||||
current_date=utc_date,
|
||||
query=q,
|
||||
chat_history=chat_history,
|
||||
max_queries=max_queries,
|
||||
current_date=utc_date,
|
||||
location=location,
|
||||
username=username,
|
||||
personality_context=personality_context,
|
||||
@@ -552,7 +560,7 @@ async def generate_online_subqueries(
|
||||
agent_chat_model = agent.chat_model if agent else None
|
||||
|
||||
class OnlineQueries(BaseModel):
|
||||
queries: List[str]
|
||||
queries: List[str] = Field(..., min_items=1, max_items=max_queries)
|
||||
|
||||
with timer("Chat actor: Generate online search subqueries", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
|
||||
Reference in New Issue
Block a user