mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-07 05:40:17 +00:00
Simplify OpenAI reasoning model specific arguments to OpenAI API
Previously, OpenAI reasoning models didn't support stream_options and response_format. Add reasoning_effort arg for calls to OpenAI reasoning models via API. Right now it defaults to medium but can be changed to low or high.
This commit is contained in:
@@ -60,20 +60,13 @@ def completion_with_backoff(
|
|||||||
|
|
||||||
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
||||||
|
|
||||||
# Update request parameters for compatability with o1 model series
|
# Tune reasoning models arguments
|
||||||
# Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations
|
if model_name.startswith("o1") or model_name.startswith("o3"):
|
||||||
|
temperature = 1
|
||||||
|
model_kwargs["reasoning_effort"] = "medium"
|
||||||
|
|
||||||
stream = True
|
stream = True
|
||||||
model_kwargs["stream_options"] = {"include_usage": True}
|
model_kwargs["stream_options"] = {"include_usage": True}
|
||||||
if model_name == "o1":
|
|
||||||
temperature = 1
|
|
||||||
stream = False
|
|
||||||
model_kwargs.pop("stream_options", None)
|
|
||||||
elif model_name.startswith("o1"):
|
|
||||||
temperature = 1
|
|
||||||
model_kwargs.pop("response_format", None)
|
|
||||||
elif model_name.startswith("o3-"):
|
|
||||||
temperature = 1
|
|
||||||
|
|
||||||
if os.getenv("KHOJ_LLM_SEED"):
|
if os.getenv("KHOJ_LLM_SEED"):
|
||||||
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
|
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
|
||||||
|
|
||||||
@@ -172,20 +165,13 @@ def llm_thread(
|
|||||||
|
|
||||||
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
||||||
|
|
||||||
# Update request parameters for compatability with o1 model series
|
# Tune reasoning models arguments
|
||||||
# Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations
|
if model_name.startswith("o1"):
|
||||||
stream = True
|
|
||||||
model_kwargs["stream_options"] = {"include_usage": True}
|
|
||||||
if model_name == "o1":
|
|
||||||
temperature = 1
|
temperature = 1
|
||||||
stream = False
|
elif model_name.startswith("o3"):
|
||||||
model_kwargs.pop("stream_options", None)
|
|
||||||
elif model_name.startswith("o1-"):
|
|
||||||
temperature = 1
|
temperature = 1
|
||||||
model_kwargs.pop("response_format", None)
|
# Get the first system message and add the string `Formatting re-enabled` to it.
|
||||||
elif model_name.startswith("o3-"):
|
# See https://platform.openai.com/docs/guides/reasoning-best-practices
|
||||||
temperature = 1
|
|
||||||
# Get the first system message and add the string `Formatting re-enabled` to it. See https://platform.openai.com/docs/guides/reasoning-best-practices
|
|
||||||
if len(formatted_messages) > 0:
|
if len(formatted_messages) > 0:
|
||||||
system_messages = [
|
system_messages = [
|
||||||
(i, message) for i, message in enumerate(formatted_messages) if message["role"] == "system"
|
(i, message) for i, message in enumerate(formatted_messages) if message["role"] == "system"
|
||||||
@@ -195,7 +181,6 @@ def llm_thread(
|
|||||||
formatted_messages[first_system_message_index][
|
formatted_messages[first_system_message_index][
|
||||||
"content"
|
"content"
|
||||||
] = f"{first_system_message} Formatting re-enabled"
|
] = f"{first_system_message} Formatting re-enabled"
|
||||||
|
|
||||||
elif model_name.startswith("deepseek-reasoner"):
|
elif model_name.startswith("deepseek-reasoner"):
|
||||||
# Two successive messages cannot be from the same role. Should merge any back-to-back messages from the same role.
|
# Two successive messages cannot be from the same role. Should merge any back-to-back messages from the same role.
|
||||||
# The first message should always be a user message (except system message).
|
# The first message should always be a user message (except system message).
|
||||||
@@ -210,6 +195,8 @@ def llm_thread(
|
|||||||
|
|
||||||
formatted_messages = updated_messages
|
formatted_messages = updated_messages
|
||||||
|
|
||||||
|
stream = True
|
||||||
|
model_kwargs["stream_options"] = {"include_usage": True}
|
||||||
if os.getenv("KHOJ_LLM_SEED"):
|
if os.getenv("KHOJ_LLM_SEED"):
|
||||||
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
|
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user