Support enforcing json schema in supported AI model APIs (#1133)

- Trigger
  Gemini 2.0 Flash doesn't always follow JSON schema in research prompt

- Details
  - Use json schema to enforce generate online queries format
  - Use json schema to enforce research mode tool pick format
  - Support constraining Gemini model output to specified response schema
  - Support constraining OpenAI model output to specified response schema
  - Only enforce json output in supported AI model APIs
  - Simplify OpenAI reasoning model specific arguments to OpenAI API
This commit is contained in:
Debanjum
2025-03-19 22:59:23 +05:30
committed by GitHub
7 changed files with 110 additions and 45 deletions

View File

@@ -121,6 +121,7 @@ def gemini_send_message_to_model(
api_key,
model,
response_type="text",
response_schema=None,
temperature=0.6,
model_kwargs=None,
tracer={},
@@ -135,6 +136,7 @@ def gemini_send_message_to_model(
# This caused unwanted behavior and terminates response early for gemini 1.5 series. Monitor for flakiness with 2.0 series.
if response_type == "json_object" and model in ["gemini-2.0-flash"]:
model_kwargs["response_mime_type"] = "application/json"
model_kwargs["response_schema"] = response_schema
# Get Response from Gemini
return gemini_completion_with_backoff(

View File

@@ -66,6 +66,7 @@ def gemini_completion_with_backoff(
max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
safety_settings=SAFETY_SETTINGS,
response_mime_type=model_kwargs.get("response_mime_type", "text/plain") if model_kwargs else "text/plain",
response_schema=model_kwargs.get("response_schema", None) if model_kwargs else None,
)
formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages]

View File

@@ -10,8 +10,10 @@ from khoj.processor.conversation import prompts
from khoj.processor.conversation.openai.utils import (
chat_completion_with_backoff,
completion_with_backoff,
get_openai_api_json_support,
)
from khoj.processor.conversation.utils import (
JsonSupport,
clean_json,
construct_structured_message,
generate_chatml_messages_with_context,
@@ -119,12 +121,26 @@ def extract_questions(
def send_message_to_model(
messages, api_key, model, response_type="text", api_base_url=None, temperature=0, tracer: dict = {}
messages,
api_key,
model,
response_type="text",
response_schema=None,
api_base_url=None,
temperature=0,
tracer: dict = {},
):
"""
Send message to model
"""
model_kwargs = {}
json_support = get_openai_api_json_support(model, api_base_url)
if response_schema and json_support == JsonSupport.SCHEMA:
model_kwargs["response_format"] = response_schema
elif response_type == "json_object" and json_support == JsonSupport.OBJECT:
model_kwargs["response_format"] = {"type": response_type}
# Get Response from GPT
return completion_with_backoff(
messages=messages,
@@ -132,7 +148,7 @@ def send_message_to_model(
openai_api_key=api_key,
temperature=temperature,
api_base_url=api_base_url,
model_kwargs={"response_format": {"type": response_type}},
model_kwargs=model_kwargs,
tracer=tracer,
)

View File

@@ -2,6 +2,7 @@ import logging
import os
from threading import Thread
from typing import Dict, List
from urllib.parse import urlparse
import openai
from openai.types.chat.chat_completion import ChatCompletion
@@ -16,6 +17,7 @@ from tenacity import (
)
from khoj.processor.conversation.utils import (
JsonSupport,
ThreadedGenerator,
commit_conversation_trace,
)
@@ -60,45 +62,29 @@ def completion_with_backoff(
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
    # Update request parameters for compatibility with o1 model series
# Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations
stream = True
model_kwargs["stream_options"] = {"include_usage": True}
if model_name == "o1":
temperature = 1
stream = False
model_kwargs.pop("stream_options", None)
elif model_name.startswith("o1"):
temperature = 1
model_kwargs.pop("response_format", None)
elif model_name.startswith("o3-"):
# Tune reasoning models arguments
if model_name.startswith("o1") or model_name.startswith("o3"):
temperature = 1
model_kwargs["reasoning_effort"] = "medium"
model_kwargs["stream_options"] = {"include_usage": True}
if os.getenv("KHOJ_LLM_SEED"):
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create(
aggregated_response = ""
with client.beta.chat.completions.stream(
messages=formatted_messages, # type: ignore
model=model_name, # type: ignore
stream=stream,
model=model_name,
temperature=temperature,
timeout=20,
**model_kwargs,
)
aggregated_response = ""
if not stream:
chunk = chat
aggregated_response = chunk.choices[0].message.content
else:
) as chat:
for chunk in chat:
if len(chunk.choices) == 0:
if chunk.type == "error":
logger.error(f"Openai api response error: {chunk.error}", exc_info=True)
continue
delta_chunk = chunk.choices[0].delta # type: ignore
if isinstance(delta_chunk, str):
aggregated_response += delta_chunk
elif delta_chunk.content:
aggregated_response += delta_chunk.content
elif chunk.type == "content.delta":
aggregated_response += chunk.delta
# Calculate cost of chat
input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
@@ -172,20 +158,13 @@ def llm_thread(
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
    # Update request parameters for compatibility with o1 model series
# Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations
stream = True
model_kwargs["stream_options"] = {"include_usage": True}
if model_name == "o1":
# Tune reasoning models arguments
if model_name.startswith("o1"):
temperature = 1
stream = False
model_kwargs.pop("stream_options", None)
elif model_name.startswith("o1-"):
elif model_name.startswith("o3"):
temperature = 1
model_kwargs.pop("response_format", None)
elif model_name.startswith("o3-"):
temperature = 1
# Get the first system message and add the string `Formatting re-enabled` to it. See https://platform.openai.com/docs/guides/reasoning-best-practices
# Get the first system message and add the string `Formatting re-enabled` to it.
# See https://platform.openai.com/docs/guides/reasoning-best-practices
if len(formatted_messages) > 0:
system_messages = [
(i, message) for i, message in enumerate(formatted_messages) if message["role"] == "system"
@@ -195,7 +174,6 @@ def llm_thread(
formatted_messages[first_system_message_index][
"content"
] = f"{first_system_message} Formatting re-enabled"
elif model_name.startswith("deepseek-reasoner"):
# Two successive messages cannot be from the same role. Should merge any back-to-back messages from the same role.
# The first message should always be a user message (except system message).
@@ -210,6 +188,8 @@ def llm_thread(
formatted_messages = updated_messages
stream = True
model_kwargs["stream_options"] = {"include_usage": True}
if os.getenv("KHOJ_LLM_SEED"):
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
@@ -258,3 +238,13 @@ def llm_thread(
logger.error(f"Error in llm_thread: {e}", exc_info=True)
finally:
g.close()
def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport:
    """Report the level of JSON output enforcement the target model API supports.

    Args:
        model_name: Name of the OpenAI-compatible chat model.
        api_base_url: Optional custom API endpoint. Used to detect hosts with
            limited JSON support (e.g. Azure AI endpoints).

    Returns:
        JsonSupport.NONE, JsonSupport.OBJECT, or JsonSupport.SCHEMA.
    """
    # Deepseek reasoner models cannot be constrained to JSON output at all
    if model_name.startswith("deepseek-reasoner"):
        return JsonSupport.NONE
    # Azure AI hosted endpoints only support basic JSON object mode, not full schemas
    if api_base_url and (host := urlparse(api_base_url).hostname):
        if host.endswith(".ai.azure.com"):
            return JsonSupport.OBJECT
    # Default: assume full response-schema constraint support
    return JsonSupport.SCHEMA

View File

@@ -878,3 +878,9 @@ def messages_to_print(messages: list[ChatMessage], max_length: int = 70) -> str:
return str(content)
return "\n".join([f"{json.dumps(safe_serialize(message.content))[:max_length]}..." for message in messages])
class JsonSupport(int, Enum):
    """Degree of JSON output enforcement an AI model API supports.

    Int-valued so levels can be ordered: NONE < OBJECT < SCHEMA.
    """

    NONE = 0  # API output cannot be constrained to JSON
    OBJECT = 1  # API can be asked for a JSON object, but not a specific schema
    SCHEMA = 2  # API can constrain output to a caller-supplied response schema

View File

@@ -540,11 +540,15 @@ async def generate_online_subqueries(
agent_chat_model = agent.chat_model if agent else None
class OnlineQueries(BaseModel):
    """Schema used to constrain model output to a list of online search subqueries."""

    # Web search queries generated from the user's request
    queries: List[str]
with timer("Chat actor: Generate online search subqueries", logger):
response = await send_message_to_model_wrapper(
online_queries_prompt,
query_images=query_images,
response_type="json_object",
response_schema=OnlineQueries,
user=user,
query_files=query_files,
agent_chat_model=agent_chat_model,
@@ -1129,6 +1133,7 @@ async def send_message_to_model_wrapper(
query: str,
system_message: str = "",
response_type: str = "text",
response_schema: BaseModel = None,
deepthought: bool = False,
user: KhojUser = None,
query_images: List[str] = None,
@@ -1209,6 +1214,7 @@ async def send_message_to_model_wrapper(
api_key=api_key,
model=chat_model_name,
response_type=response_type,
response_schema=response_schema,
api_base_url=api_base_url,
tracer=tracer,
)
@@ -1255,6 +1261,7 @@ async def send_message_to_model_wrapper(
api_key=api_key,
model=chat_model_name,
response_type=response_type,
response_schema=response_schema,
tracer=tracer,
)
else:
@@ -1265,6 +1272,7 @@ def send_message_to_model_wrapper_sync(
message: str,
system_message: str = "",
response_type: str = "text",
response_schema: BaseModel = None,
user: KhojUser = None,
query_images: List[str] = None,
query_files: str = "",
@@ -1326,6 +1334,7 @@ def send_message_to_model_wrapper_sync(
api_base_url=api_base_url,
model=chat_model_name,
response_type=response_type,
response_schema=response_schema,
tracer=tracer,
)
@@ -1370,6 +1379,7 @@ def send_message_to_model_wrapper_sync(
api_key=api_key,
model=chat_model_name,
response_type=response_type,
response_schema=response_schema,
tracer=tracer,
)
else:

View File

@@ -1,10 +1,12 @@
import logging
import os
from datetime import datetime
from typing import Callable, Dict, List, Optional
from enum import Enum
from typing import Callable, Dict, List, Optional, Type
import yaml
from fastapi import Request
from pydantic import BaseModel, Field
from khoj.database.adapters import EntryAdapters
from khoj.database.models import Agent, KhojUser
@@ -36,6 +38,40 @@ from khoj.utils.rawconfig import LocationData
logger = logging.getLogger(__name__)
class PlanningResponse(BaseModel):
    """
    Structured response from the planning agent when deciding the next tool to pick.

    The `tool` field is attached dynamically via `create_model_with_enum`, so that
    only the tools currently available to the agent validate.
    """

    # The model's reasoning trace for why the chosen tool fits the next step
    scratchpad: str = Field(..., description="Reasoning about which tool to use next")
    # The concrete query to pass to the selected tool
    query: str = Field(..., description="Detailed query for the selected tool")

    class Config:
        arbitrary_types_allowed = True

    @classmethod
    def create_model_with_enum(cls: Type["PlanningResponse"], tool_options: dict) -> Type["PlanningResponse"]:
        """
        Build a PlanningResponse subclass whose `tool` field is typed as an enum
        of the currently available tools.

        Args:
            tool_options: Mapping of tool names to tool values.

        Returns:
            A customized PlanningResponse subclass with a validated `tool` field.
        """
        # Construct an Enum type on the fly from the available tool options
        available_tools = Enum("ToolEnum", tool_options)  # type: ignore

        # Subclass adds the dynamically typed tool field to the base response model
        class PlanningResponseWithTool(PlanningResponse):
            tool: available_tools = Field(..., description="Name of the tool to use")

        return PlanningResponseWithTool
async def apick_next_tool(
query: str,
conversation_history: dict,
@@ -61,10 +97,13 @@ async def apick_next_tool(
# Skip showing Notes tool as an option if user has no entries
if tool == ConversationCommand.Notes and not user_has_entries:
continue
tool_options[tool.value] = description
if len(agent_tools) == 0 or tool.value in agent_tools:
tool_options[tool.name] = tool.value
tool_options_str += f'- "{tool.value}": "{description}"\n'
    # Create planning response model with dynamically populated tool enum class
planning_response_model = PlanningResponse.create_model_with_enum(tool_options)
# Construct chat history with user and iteration history with researcher agent for context
chat_history = construct_chat_history(conversation_history, agent_name=agent.name if agent else "Khoj")
previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)
@@ -96,6 +135,7 @@ async def apick_next_tool(
query=query,
context=function_planning_prompt,
response_type="json_object",
response_schema=planning_response_model,
deepthought=True,
user=user,
query_images=query_images,