mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-06 13:22:12 +00:00
Support enforcing json schema in supported AI model APIs (#1133)
- Trigger Gemini 2.0 Flash doesn't always follow JSON schema in research prompt - Details - Use json schema to enforce generate online queries format - Use json schema to enforce research mode tool pick format - Support constraining Gemini model output to specified response schema - Support constraining OpenAI model output to specified response schema - Only enforce json output in supported AI model APIs - Simplify OpenAI reasoning model specific arguments to OpenAI API
This commit is contained in:
@@ -121,6 +121,7 @@ def gemini_send_message_to_model(
|
||||
api_key,
|
||||
model,
|
||||
response_type="text",
|
||||
response_schema=None,
|
||||
temperature=0.6,
|
||||
model_kwargs=None,
|
||||
tracer={},
|
||||
@@ -135,6 +136,7 @@ def gemini_send_message_to_model(
|
||||
# This caused unwanted behavior and terminates response early for gemini 1.5 series. Monitor for flakiness with 2.0 series.
|
||||
if response_type == "json_object" and model in ["gemini-2.0-flash"]:
|
||||
model_kwargs["response_mime_type"] = "application/json"
|
||||
model_kwargs["response_schema"] = response_schema
|
||||
|
||||
# Get Response from Gemini
|
||||
return gemini_completion_with_backoff(
|
||||
|
||||
@@ -66,6 +66,7 @@ def gemini_completion_with_backoff(
|
||||
max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
|
||||
safety_settings=SAFETY_SETTINGS,
|
||||
response_mime_type=model_kwargs.get("response_mime_type", "text/plain") if model_kwargs else "text/plain",
|
||||
response_schema=model_kwargs.get("response_schema", None) if model_kwargs else None,
|
||||
)
|
||||
|
||||
formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages]
|
||||
|
||||
@@ -10,8 +10,10 @@ from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.openai.utils import (
|
||||
chat_completion_with_backoff,
|
||||
completion_with_backoff,
|
||||
get_openai_api_json_support,
|
||||
)
|
||||
from khoj.processor.conversation.utils import (
|
||||
JsonSupport,
|
||||
clean_json,
|
||||
construct_structured_message,
|
||||
generate_chatml_messages_with_context,
|
||||
@@ -119,12 +121,26 @@ def extract_questions(
|
||||
|
||||
|
||||
def send_message_to_model(
|
||||
messages, api_key, model, response_type="text", api_base_url=None, temperature=0, tracer: dict = {}
|
||||
messages,
|
||||
api_key,
|
||||
model,
|
||||
response_type="text",
|
||||
response_schema=None,
|
||||
api_base_url=None,
|
||||
temperature=0,
|
||||
tracer: dict = {},
|
||||
):
|
||||
"""
|
||||
Send message to model
|
||||
"""
|
||||
|
||||
model_kwargs = {}
|
||||
json_support = get_openai_api_json_support(model, api_base_url)
|
||||
if response_schema and json_support == JsonSupport.SCHEMA:
|
||||
model_kwargs["response_format"] = response_schema
|
||||
elif response_type == "json_object" and json_support == JsonSupport.OBJECT:
|
||||
model_kwargs["response_format"] = {"type": response_type}
|
||||
|
||||
# Get Response from GPT
|
||||
return completion_with_backoff(
|
||||
messages=messages,
|
||||
@@ -132,7 +148,7 @@ def send_message_to_model(
|
||||
openai_api_key=api_key,
|
||||
temperature=temperature,
|
||||
api_base_url=api_base_url,
|
||||
model_kwargs={"response_format": {"type": response_type}},
|
||||
model_kwargs=model_kwargs,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ import logging
|
||||
import os
|
||||
from threading import Thread
|
||||
from typing import Dict, List
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import openai
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
@@ -16,6 +17,7 @@ from tenacity import (
|
||||
)
|
||||
|
||||
from khoj.processor.conversation.utils import (
|
||||
JsonSupport,
|
||||
ThreadedGenerator,
|
||||
commit_conversation_trace,
|
||||
)
|
||||
@@ -60,45 +62,29 @@ def completion_with_backoff(
|
||||
|
||||
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
||||
|
||||
# Update request parameters for compatability with o1 model series
|
||||
# Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations
|
||||
stream = True
|
||||
model_kwargs["stream_options"] = {"include_usage": True}
|
||||
if model_name == "o1":
|
||||
temperature = 1
|
||||
stream = False
|
||||
model_kwargs.pop("stream_options", None)
|
||||
elif model_name.startswith("o1"):
|
||||
temperature = 1
|
||||
model_kwargs.pop("response_format", None)
|
||||
elif model_name.startswith("o3-"):
|
||||
# Tune reasoning models arguments
|
||||
if model_name.startswith("o1") or model_name.startswith("o3"):
|
||||
temperature = 1
|
||||
model_kwargs["reasoning_effort"] = "medium"
|
||||
|
||||
model_kwargs["stream_options"] = {"include_usage": True}
|
||||
if os.getenv("KHOJ_LLM_SEED"):
|
||||
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
|
||||
|
||||
chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create(
|
||||
aggregated_response = ""
|
||||
with client.beta.chat.completions.stream(
|
||||
messages=formatted_messages, # type: ignore
|
||||
model=model_name, # type: ignore
|
||||
stream=stream,
|
||||
model=model_name,
|
||||
temperature=temperature,
|
||||
timeout=20,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
aggregated_response = ""
|
||||
if not stream:
|
||||
chunk = chat
|
||||
aggregated_response = chunk.choices[0].message.content
|
||||
else:
|
||||
) as chat:
|
||||
for chunk in chat:
|
||||
if len(chunk.choices) == 0:
|
||||
if chunk.type == "error":
|
||||
logger.error(f"Openai api response error: {chunk.error}", exc_info=True)
|
||||
continue
|
||||
delta_chunk = chunk.choices[0].delta # type: ignore
|
||||
if isinstance(delta_chunk, str):
|
||||
aggregated_response += delta_chunk
|
||||
elif delta_chunk.content:
|
||||
aggregated_response += delta_chunk.content
|
||||
elif chunk.type == "content.delta":
|
||||
aggregated_response += chunk.delta
|
||||
|
||||
# Calculate cost of chat
|
||||
input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
|
||||
@@ -172,20 +158,13 @@ def llm_thread(
|
||||
|
||||
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
||||
|
||||
# Update request parameters for compatability with o1 model series
|
||||
# Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations
|
||||
stream = True
|
||||
model_kwargs["stream_options"] = {"include_usage": True}
|
||||
if model_name == "o1":
|
||||
# Tune reasoning models arguments
|
||||
if model_name.startswith("o1"):
|
||||
temperature = 1
|
||||
stream = False
|
||||
model_kwargs.pop("stream_options", None)
|
||||
elif model_name.startswith("o1-"):
|
||||
elif model_name.startswith("o3"):
|
||||
temperature = 1
|
||||
model_kwargs.pop("response_format", None)
|
||||
elif model_name.startswith("o3-"):
|
||||
temperature = 1
|
||||
# Get the first system message and add the string `Formatting re-enabled` to it. See https://platform.openai.com/docs/guides/reasoning-best-practices
|
||||
# Get the first system message and add the string `Formatting re-enabled` to it.
|
||||
# See https://platform.openai.com/docs/guides/reasoning-best-practices
|
||||
if len(formatted_messages) > 0:
|
||||
system_messages = [
|
||||
(i, message) for i, message in enumerate(formatted_messages) if message["role"] == "system"
|
||||
@@ -195,7 +174,6 @@ def llm_thread(
|
||||
formatted_messages[first_system_message_index][
|
||||
"content"
|
||||
] = f"{first_system_message} Formatting re-enabled"
|
||||
|
||||
elif model_name.startswith("deepseek-reasoner"):
|
||||
# Two successive messages cannot be from the same role. Should merge any back-to-back messages from the same role.
|
||||
# The first message should always be a user message (except system message).
|
||||
@@ -210,6 +188,8 @@ def llm_thread(
|
||||
|
||||
formatted_messages = updated_messages
|
||||
|
||||
stream = True
|
||||
model_kwargs["stream_options"] = {"include_usage": True}
|
||||
if os.getenv("KHOJ_LLM_SEED"):
|
||||
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
|
||||
|
||||
@@ -258,3 +238,13 @@ def llm_thread(
|
||||
logger.error(f"Error in llm_thread: {e}", exc_info=True)
|
||||
finally:
|
||||
g.close()
|
||||
|
||||
|
||||
def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport:
|
||||
if model_name.startswith("deepseek-reasoner"):
|
||||
return JsonSupport.NONE
|
||||
if api_base_url:
|
||||
host = urlparse(api_base_url).hostname
|
||||
if host and host.endswith(".ai.azure.com"):
|
||||
return JsonSupport.OBJECT
|
||||
return JsonSupport.SCHEMA
|
||||
|
||||
@@ -878,3 +878,9 @@ def messages_to_print(messages: list[ChatMessage], max_length: int = 70) -> str:
|
||||
return str(content)
|
||||
|
||||
return "\n".join([f"{json.dumps(safe_serialize(message.content))[:max_length]}..." for message in messages])
|
||||
|
||||
|
||||
class JsonSupport(int, Enum):
|
||||
NONE = 0
|
||||
OBJECT = 1
|
||||
SCHEMA = 2
|
||||
|
||||
@@ -540,11 +540,15 @@ async def generate_online_subqueries(
|
||||
|
||||
agent_chat_model = agent.chat_model if agent else None
|
||||
|
||||
class OnlineQueries(BaseModel):
|
||||
queries: List[str]
|
||||
|
||||
with timer("Chat actor: Generate online search subqueries", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
online_queries_prompt,
|
||||
query_images=query_images,
|
||||
response_type="json_object",
|
||||
response_schema=OnlineQueries,
|
||||
user=user,
|
||||
query_files=query_files,
|
||||
agent_chat_model=agent_chat_model,
|
||||
@@ -1129,6 +1133,7 @@ async def send_message_to_model_wrapper(
|
||||
query: str,
|
||||
system_message: str = "",
|
||||
response_type: str = "text",
|
||||
response_schema: BaseModel = None,
|
||||
deepthought: bool = False,
|
||||
user: KhojUser = None,
|
||||
query_images: List[str] = None,
|
||||
@@ -1209,6 +1214,7 @@ async def send_message_to_model_wrapper(
|
||||
api_key=api_key,
|
||||
model=chat_model_name,
|
||||
response_type=response_type,
|
||||
response_schema=response_schema,
|
||||
api_base_url=api_base_url,
|
||||
tracer=tracer,
|
||||
)
|
||||
@@ -1255,6 +1261,7 @@ async def send_message_to_model_wrapper(
|
||||
api_key=api_key,
|
||||
model=chat_model_name,
|
||||
response_type=response_type,
|
||||
response_schema=response_schema,
|
||||
tracer=tracer,
|
||||
)
|
||||
else:
|
||||
@@ -1265,6 +1272,7 @@ def send_message_to_model_wrapper_sync(
|
||||
message: str,
|
||||
system_message: str = "",
|
||||
response_type: str = "text",
|
||||
response_schema: BaseModel = None,
|
||||
user: KhojUser = None,
|
||||
query_images: List[str] = None,
|
||||
query_files: str = "",
|
||||
@@ -1326,6 +1334,7 @@ def send_message_to_model_wrapper_sync(
|
||||
api_base_url=api_base_url,
|
||||
model=chat_model_name,
|
||||
response_type=response_type,
|
||||
response_schema=response_schema,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
@@ -1370,6 +1379,7 @@ def send_message_to_model_wrapper_sync(
|
||||
api_key=api_key,
|
||||
model=chat_model_name,
|
||||
response_type=response_type,
|
||||
response_schema=response_schema,
|
||||
tracer=tracer,
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Callable, Dict, List, Optional
|
||||
from enum import Enum
|
||||
from typing import Callable, Dict, List, Optional, Type
|
||||
|
||||
import yaml
|
||||
from fastapi import Request
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from khoj.database.adapters import EntryAdapters
|
||||
from khoj.database.models import Agent, KhojUser
|
||||
@@ -36,6 +38,40 @@ from khoj.utils.rawconfig import LocationData
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PlanningResponse(BaseModel):
|
||||
"""
|
||||
Schema for the response from planning agent when deciding the next tool to pick.
|
||||
The tool field is dynamically validated based on available tools.
|
||||
"""
|
||||
|
||||
scratchpad: str = Field(..., description="Reasoning about which tool to use next")
|
||||
query: str = Field(..., description="Detailed query for the selected tool")
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@classmethod
|
||||
def create_model_with_enum(cls: Type["PlanningResponse"], tool_options: dict) -> Type["PlanningResponse"]:
|
||||
"""
|
||||
Factory method that creates a customized PlanningResponse model
|
||||
with a properly typed tool field based on available tools.
|
||||
|
||||
Args:
|
||||
tool_options: Dictionary mapping tool names to values
|
||||
|
||||
Returns:
|
||||
A customized PlanningResponse class
|
||||
"""
|
||||
# Create dynamic enum from tool options
|
||||
tool_enum = Enum("ToolEnum", tool_options) # type: ignore
|
||||
|
||||
# Create and return a customized response model with the enum
|
||||
class PlanningResponseWithTool(PlanningResponse):
|
||||
tool: tool_enum = Field(..., description="Name of the tool to use")
|
||||
|
||||
return PlanningResponseWithTool
|
||||
|
||||
|
||||
async def apick_next_tool(
|
||||
query: str,
|
||||
conversation_history: dict,
|
||||
@@ -61,10 +97,13 @@ async def apick_next_tool(
|
||||
# Skip showing Notes tool as an option if user has no entries
|
||||
if tool == ConversationCommand.Notes and not user_has_entries:
|
||||
continue
|
||||
tool_options[tool.value] = description
|
||||
if len(agent_tools) == 0 or tool.value in agent_tools:
|
||||
tool_options[tool.name] = tool.value
|
||||
tool_options_str += f'- "{tool.value}": "{description}"\n'
|
||||
|
||||
# Create planning reponse model with dynamically populated tool enum class
|
||||
planning_response_model = PlanningResponse.create_model_with_enum(tool_options)
|
||||
|
||||
# Construct chat history with user and iteration history with researcher agent for context
|
||||
chat_history = construct_chat_history(conversation_history, agent_name=agent.name if agent else "Khoj")
|
||||
previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)
|
||||
@@ -96,6 +135,7 @@ async def apick_next_tool(
|
||||
query=query,
|
||||
context=function_planning_prompt,
|
||||
response_type="json_object",
|
||||
response_schema=planning_response_model,
|
||||
deepthought=True,
|
||||
user=user,
|
||||
query_images=query_images,
|
||||
|
||||
Reference in New Issue
Block a user