mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 21:29:11 +00:00
Use better, standard default temperature and top_p for OpenAI model providers
This commit is contained in:
@@ -85,10 +85,10 @@ async def converse_openai(
|
|||||||
program_execution_context: List[str] = None,
|
program_execution_context: List[str] = None,
|
||||||
location_data: LocationData = None,
|
location_data: LocationData = None,
|
||||||
chat_history: list[ChatMessageModel] = [],
|
chat_history: list[ChatMessageModel] = [],
|
||||||
model: str = "gpt-4o-mini",
|
model: str = "gpt-4.1-mini",
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
api_base_url: Optional[str] = None,
|
api_base_url: Optional[str] = None,
|
||||||
temperature: float = 0.4,
|
temperature: float = 0.6,
|
||||||
max_prompt_size=None,
|
max_prompt_size=None,
|
||||||
tokenizer_name=None,
|
tokenizer_name=None,
|
||||||
user_name: str = None,
|
user_name: str = None,
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ openai_async_clients: Dict[str, openai.AsyncOpenAI] = {}
|
|||||||
def completion_with_backoff(
|
def completion_with_backoff(
|
||||||
messages: List[ChatMessage],
|
messages: List[ChatMessage],
|
||||||
model_name: str,
|
model_name: str,
|
||||||
temperature=0.8,
|
temperature=0.6,
|
||||||
openai_api_key=None,
|
openai_api_key=None,
|
||||||
api_base_url=None,
|
api_base_url=None,
|
||||||
deepthought: bool = False,
|
deepthought: bool = False,
|
||||||
@@ -89,14 +89,19 @@ def completion_with_backoff(
|
|||||||
if stream:
|
if stream:
|
||||||
model_kwargs["stream_options"] = {"include_usage": True}
|
model_kwargs["stream_options"] = {"include_usage": True}
|
||||||
|
|
||||||
|
model_kwargs["temperature"] = temperature
|
||||||
|
model_kwargs["top_p"] = model_kwargs.get("top_p", 0.95)
|
||||||
|
|
||||||
formatted_messages = format_message_for_api(messages, api_base_url)
|
formatted_messages = format_message_for_api(messages, api_base_url)
|
||||||
|
|
||||||
# Tune reasoning models arguments
|
# Tune reasoning models arguments
|
||||||
if is_openai_reasoning_model(model_name, api_base_url):
|
if is_openai_reasoning_model(model_name, api_base_url):
|
||||||
temperature = 1
|
model_kwargs["temperature"] = 1
|
||||||
reasoning_effort = "medium" if deepthought else "low"
|
reasoning_effort = "medium" if deepthought else "low"
|
||||||
model_kwargs["reasoning_effort"] = reasoning_effort
|
model_kwargs["reasoning_effort"] = reasoning_effort
|
||||||
|
model_kwargs.pop("top_p", None)
|
||||||
elif is_twitter_reasoning_model(model_name, api_base_url):
|
elif is_twitter_reasoning_model(model_name, api_base_url):
|
||||||
|
model_kwargs.pop("temperature", None)
|
||||||
reasoning_effort = "high" if deepthought else "low"
|
reasoning_effort = "high" if deepthought else "low"
|
||||||
model_kwargs["reasoning_effort"] = reasoning_effort
|
model_kwargs["reasoning_effort"] = reasoning_effort
|
||||||
elif model_name.startswith("deepseek-reasoner"):
|
elif model_name.startswith("deepseek-reasoner"):
|
||||||
@@ -131,7 +136,6 @@ def completion_with_backoff(
|
|||||||
with client.beta.chat.completions.stream(
|
with client.beta.chat.completions.stream(
|
||||||
messages=formatted_messages, # type: ignore
|
messages=formatted_messages, # type: ignore
|
||||||
model=model_name,
|
model=model_name,
|
||||||
temperature=temperature,
|
|
||||||
timeout=httpx.Timeout(30, read=read_timeout),
|
timeout=httpx.Timeout(30, read=read_timeout),
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
) as chat:
|
) as chat:
|
||||||
@@ -233,9 +237,7 @@ async def chat_completion_with_backoff(
|
|||||||
openai_api_key=None,
|
openai_api_key=None,
|
||||||
api_base_url=None,
|
api_base_url=None,
|
||||||
deepthought=False,
|
deepthought=False,
|
||||||
model_kwargs: dict = {},
|
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
tools=None,
|
|
||||||
) -> AsyncGenerator[ResponseWithThought, None]:
|
) -> AsyncGenerator[ResponseWithThought, None]:
|
||||||
client_key = f"{openai_api_key}--{api_base_url}"
|
client_key = f"{openai_api_key}--{api_base_url}"
|
||||||
client = openai_async_clients.get(client_key)
|
client = openai_async_clients.get(client_key)
|
||||||
@@ -243,6 +245,7 @@ async def chat_completion_with_backoff(
|
|||||||
client = get_openai_async_client(openai_api_key, api_base_url)
|
client = get_openai_async_client(openai_api_key, api_base_url)
|
||||||
openai_async_clients[client_key] = client
|
openai_async_clients[client_key] = client
|
||||||
|
|
||||||
|
model_kwargs: dict = {}
|
||||||
stream = not is_non_streaming_model(model_name, api_base_url)
|
stream = not is_non_streaming_model(model_name, api_base_url)
|
||||||
stream_processor = astream_thought_processor
|
stream_processor = astream_thought_processor
|
||||||
if stream:
|
if stream:
|
||||||
@@ -250,6 +253,8 @@ async def chat_completion_with_backoff(
|
|||||||
else:
|
else:
|
||||||
model_kwargs.pop("stream_options", None)
|
model_kwargs.pop("stream_options", None)
|
||||||
|
|
||||||
|
model_kwargs["top_p"] = model_kwargs.get("top_p", 0.95)
|
||||||
|
|
||||||
formatted_messages = format_message_for_api(messages, api_base_url)
|
formatted_messages = format_message_for_api(messages, api_base_url)
|
||||||
|
|
||||||
# Configure thinking for openai reasoning models
|
# Configure thinking for openai reasoning models
|
||||||
@@ -257,7 +262,9 @@ async def chat_completion_with_backoff(
|
|||||||
temperature = 1
|
temperature = 1
|
||||||
reasoning_effort = "medium" if deepthought else "low"
|
reasoning_effort = "medium" if deepthought else "low"
|
||||||
model_kwargs["reasoning_effort"] = reasoning_effort
|
model_kwargs["reasoning_effort"] = reasoning_effort
|
||||||
model_kwargs.pop("stop", None) # Remove unsupported stop param for reasoning models
|
# Remove unsupported params for reasoning models
|
||||||
|
model_kwargs.pop("top_p", None)
|
||||||
|
model_kwargs.pop("stop", None)
|
||||||
|
|
||||||
# Get the first system message and add the string `Formatting re-enabled` to it.
|
# Get the first system message and add the string `Formatting re-enabled` to it.
|
||||||
# See https://platform.openai.com/docs/guides/reasoning-best-practices
|
# See https://platform.openai.com/docs/guides/reasoning-best-practices
|
||||||
@@ -304,8 +311,6 @@ async def chat_completion_with_backoff(
|
|||||||
read_timeout = 300 if is_local_api(api_base_url) else 60
|
read_timeout = 300 if is_local_api(api_base_url) else 60
|
||||||
if os.getenv("KHOJ_LLM_SEED"):
|
if os.getenv("KHOJ_LLM_SEED"):
|
||||||
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
|
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
|
||||||
if tools:
|
|
||||||
model_kwargs["tools"] = tools
|
|
||||||
|
|
||||||
aggregated_response = ""
|
aggregated_response = ""
|
||||||
final_chunk = None
|
final_chunk = None
|
||||||
|
|||||||
Reference in New Issue
Block a user