mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Remove unsupported tool schema fields minimum, maximum for Groq API
The Groq API has stopped supporting the minimum and maximum fields in tool schemas. This unexpectedly broke AI models served via the Groq API, such as Kimi K2 and GPT-OSS, in research mode. Also improve typing of the relevant fields.
This commit is contained in:
@@ -27,12 +27,12 @@ logger = logging.getLogger(__name__)
|
|||||||
def send_message_to_model(
|
def send_message_to_model(
|
||||||
messages,
|
messages,
|
||||||
api_key,
|
api_key,
|
||||||
model,
|
model: str,
|
||||||
response_type="text",
|
response_type="text",
|
||||||
response_schema=None,
|
response_schema=None,
|
||||||
tools: list[ToolDefinition] = None,
|
tools: list[ToolDefinition] = None,
|
||||||
deepthought=False,
|
deepthought=False,
|
||||||
api_base_url=None,
|
api_base_url: str | None = None,
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -43,9 +43,7 @@ def send_message_to_model(
|
|||||||
json_support = get_structured_output_support(model, api_base_url)
|
json_support = get_structured_output_support(model, api_base_url)
|
||||||
strict = not is_cerebras_api(api_base_url)
|
strict = not is_cerebras_api(api_base_url)
|
||||||
if tools and json_support == StructuredOutputSupport.TOOL:
|
if tools and json_support == StructuredOutputSupport.TOOL:
|
||||||
model_kwargs["tools"] = to_openai_tools(
|
model_kwargs["tools"] = to_openai_tools(tools, model=model, api_base_url=api_base_url)
|
||||||
tools, use_responses_api=supports_responses_api(model, api_base_url), strict=strict
|
|
||||||
)
|
|
||||||
elif response_schema and json_support >= StructuredOutputSupport.SCHEMA:
|
elif response_schema and json_support >= StructuredOutputSupport.SCHEMA:
|
||||||
# Drop unsupported fields from schema passed to OpenAI API
|
# Drop unsupported fields from schema passed to OpenAI API
|
||||||
cleaned_response_schema = clean_response_schema(response_schema)
|
cleaned_response_schema = clean_response_schema(response_schema)
|
||||||
|
|||||||
@@ -93,7 +93,7 @@ def completion_with_backoff(
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
temperature=0.6,
|
temperature=0.6,
|
||||||
openai_api_key=None,
|
openai_api_key=None,
|
||||||
api_base_url=None,
|
api_base_url: str | None = None,
|
||||||
deepthought: bool = False,
|
deepthought: bool = False,
|
||||||
model_kwargs: dict = {},
|
model_kwargs: dict = {},
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
@@ -882,21 +882,21 @@ def is_twitter_reasoning_model(model_name: str, api_base_url: str = None) -> boo
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def is_cerebras_api(api_base_url: str = None) -> bool:
|
def is_cerebras_api(api_base_url: str | None = None) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if the model is served over the Cerebras API
|
Check if the model is served over the Cerebras API
|
||||||
"""
|
"""
|
||||||
return api_base_url is not None and api_base_url.startswith("https://api.cerebras.ai/v1")
|
return api_base_url is not None and api_base_url.startswith("https://api.cerebras.ai/v1")
|
||||||
|
|
||||||
|
|
||||||
def is_groq_api(api_base_url: str = None) -> bool:
|
def is_groq_api(api_base_url: str | None = None) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if the model is served over the Groq API
|
Check if the model is served over the Groq API
|
||||||
"""
|
"""
|
||||||
return api_base_url is not None and api_base_url.startswith("https://api.groq.com")
|
return api_base_url is not None and api_base_url.startswith("https://api.groq.com")
|
||||||
|
|
||||||
|
|
||||||
def is_qwen_style_reasoning_model(model_name: str, api_base_url: str = None) -> bool:
|
def is_qwen_style_reasoning_model(model_name: str, api_base_url: str | None = None) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if the model is a Qwen style reasoning model
|
Check if the model is a Qwen style reasoning model
|
||||||
"""
|
"""
|
||||||
@@ -1225,15 +1225,18 @@ def add_qwen_no_think_tag(formatted_messages: List[dict]) -> None:
|
|||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
def to_openai_tools(tools: List[ToolDefinition], use_responses_api: bool, strict: bool) -> List[Dict] | None:
|
def to_openai_tools(tools: List[ToolDefinition], model: str, api_base_url: str | None = None) -> List[Dict] | None:
|
||||||
"Transform tool definitions from standard format to OpenAI format."
|
"Transform tool definitions from standard format to OpenAI format."
|
||||||
|
use_responses_api = supports_responses_api(model, api_base_url)
|
||||||
|
strict = not is_cerebras_api(api_base_url)
|
||||||
|
fields_to_exclude = ["minimum", "maximum"] if is_groq_api(api_base_url) else []
|
||||||
if use_responses_api:
|
if use_responses_api:
|
||||||
openai_tools = [
|
openai_tools = [
|
||||||
{
|
{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"name": tool.name,
|
"name": tool.name,
|
||||||
"description": tool.description,
|
"description": tool.description,
|
||||||
"parameters": clean_response_schema(tool.schema),
|
"parameters": clean_response_schema(tool.schema, fields_to_exclude=fields_to_exclude),
|
||||||
"strict": strict,
|
"strict": strict,
|
||||||
}
|
}
|
||||||
for tool in tools
|
for tool in tools
|
||||||
@@ -1245,7 +1248,7 @@ def to_openai_tools(tools: List[ToolDefinition], use_responses_api: bool, strict
|
|||||||
"function": {
|
"function": {
|
||||||
"name": tool.name,
|
"name": tool.name,
|
||||||
"description": tool.description,
|
"description": tool.description,
|
||||||
"parameters": clean_response_schema(tool.schema),
|
"parameters": clean_response_schema(tool.schema, fields_to_exclude=fields_to_exclude),
|
||||||
"strict": strict,
|
"strict": strict,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -1255,7 +1258,7 @@ def to_openai_tools(tools: List[ToolDefinition], use_responses_api: bool, strict
|
|||||||
return openai_tools or None
|
return openai_tools or None
|
||||||
|
|
||||||
|
|
||||||
def clean_response_schema(schema: BaseModel | dict) -> dict:
|
def clean_response_schema(schema: BaseModel | dict, fields_to_exclude: list[str] = []) -> dict:
|
||||||
"""
|
"""
|
||||||
Format response schema to be compatible with OpenAI API.
|
Format response schema to be compatible with OpenAI API.
|
||||||
|
|
||||||
@@ -1267,7 +1270,7 @@ def clean_response_schema(schema: BaseModel | dict) -> dict:
|
|||||||
|
|
||||||
# Recursively drop unsupported fields from schema passed to OpenAI API
|
# Recursively drop unsupported fields from schema passed to OpenAI API
|
||||||
# See https://platform.openai.com/docs/guides/structured-outputs#supported-schemas
|
# See https://platform.openai.com/docs/guides/structured-outputs#supported-schemas
|
||||||
fields_to_exclude = ["minItems", "maxItems"]
|
fields_to_exclude += ["minItems", "maxItems"]
|
||||||
if isinstance(schema_json, dict) and isinstance(schema_json.get("properties"), dict):
|
if isinstance(schema_json, dict) and isinstance(schema_json.get("properties"), dict):
|
||||||
for _, prop_value in schema_json["properties"].items():
|
for _, prop_value in schema_json["properties"].items():
|
||||||
if isinstance(prop_value, dict):
|
if isinstance(prop_value, dict):
|
||||||
|
|||||||
@@ -88,7 +88,7 @@ async def apick_next_tool(
|
|||||||
# Construct tool options for the agent to choose from
|
# Construct tool options for the agent to choose from
|
||||||
tools = []
|
tools = []
|
||||||
tool_options_str = ""
|
tool_options_str = ""
|
||||||
agent_input_tools = agent.input_tools if agent else []
|
agent_input_tools = agent.input_tools if agent and agent.input_tools else []
|
||||||
agent_tools = []
|
agent_tools = []
|
||||||
|
|
||||||
# Map agent user facing tools to research tools to include in agents toolbox
|
# Map agent user facing tools to research tools to include in agents toolbox
|
||||||
|
|||||||
@@ -664,13 +664,13 @@ tools_for_research_llm = {
|
|||||||
},
|
},
|
||||||
"lines_before": {
|
"lines_before": {
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"description": "Optional number of lines to show before each line match for context.",
|
"description": "Optional number of lines to show before each line match for context. It should be a positive number between 0 and 20.",
|
||||||
"minimum": 0,
|
"minimum": 0,
|
||||||
"maximum": 20,
|
"maximum": 20,
|
||||||
},
|
},
|
||||||
"lines_after": {
|
"lines_after": {
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"description": "Optional number of lines to show after each line match for context.",
|
"description": "Optional number of lines to show after each line match for context. It should be a positive number between 0 and 20.",
|
||||||
"minimum": 0,
|
"minimum": 0,
|
||||||
"maximum": 20,
|
"maximum": 20,
|
||||||
},
|
},
|
||||||
|
|||||||
Reference in New Issue
Block a user