mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Use better, standard default temp, top_p for openai model providers
This commit is contained in:
@@ -85,10 +85,10 @@ async def converse_openai(
|
||||
program_execution_context: List[str] = None,
|
||||
location_data: LocationData = None,
|
||||
chat_history: list[ChatMessageModel] = [],
|
||||
model: str = "gpt-4o-mini",
|
||||
model: str = "gpt-4.1-mini",
|
||||
api_key: Optional[str] = None,
|
||||
api_base_url: Optional[str] = None,
|
||||
temperature: float = 0.4,
|
||||
temperature: float = 0.6,
|
||||
max_prompt_size=None,
|
||||
tokenizer_name=None,
|
||||
user_name: str = None,
|
||||
|
||||
@@ -71,7 +71,7 @@ openai_async_clients: Dict[str, openai.AsyncOpenAI] = {}
|
||||
def completion_with_backoff(
|
||||
messages: List[ChatMessage],
|
||||
model_name: str,
|
||||
temperature=0.8,
|
||||
temperature=0.6,
|
||||
openai_api_key=None,
|
||||
api_base_url=None,
|
||||
deepthought: bool = False,
|
||||
@@ -89,14 +89,19 @@ def completion_with_backoff(
|
||||
if stream:
|
||||
model_kwargs["stream_options"] = {"include_usage": True}
|
||||
|
||||
model_kwargs["temperature"] = temperature
|
||||
model_kwargs["top_p"] = model_kwargs.get("top_p", 0.95)
|
||||
|
||||
formatted_messages = format_message_for_api(messages, api_base_url)
|
||||
|
||||
# Tune reasoning models arguments
|
||||
if is_openai_reasoning_model(model_name, api_base_url):
|
||||
temperature = 1
|
||||
model_kwargs["temperature"] = 1
|
||||
reasoning_effort = "medium" if deepthought else "low"
|
||||
model_kwargs["reasoning_effort"] = reasoning_effort
|
||||
model_kwargs.pop("top_p", None)
|
||||
elif is_twitter_reasoning_model(model_name, api_base_url):
|
||||
model_kwargs.pop("temperature", None)
|
||||
reasoning_effort = "high" if deepthought else "low"
|
||||
model_kwargs["reasoning_effort"] = reasoning_effort
|
||||
elif model_name.startswith("deepseek-reasoner"):
|
||||
@@ -131,7 +136,6 @@ def completion_with_backoff(
|
||||
with client.beta.chat.completions.stream(
|
||||
messages=formatted_messages, # type: ignore
|
||||
model=model_name,
|
||||
temperature=temperature,
|
||||
timeout=httpx.Timeout(30, read=read_timeout),
|
||||
**model_kwargs,
|
||||
) as chat:
|
||||
@@ -233,9 +237,7 @@ async def chat_completion_with_backoff(
|
||||
openai_api_key=None,
|
||||
api_base_url=None,
|
||||
deepthought=False,
|
||||
model_kwargs: dict = {},
|
||||
tracer: dict = {},
|
||||
tools=None,
|
||||
) -> AsyncGenerator[ResponseWithThought, None]:
|
||||
client_key = f"{openai_api_key}--{api_base_url}"
|
||||
client = openai_async_clients.get(client_key)
|
||||
@@ -243,6 +245,7 @@ async def chat_completion_with_backoff(
|
||||
client = get_openai_async_client(openai_api_key, api_base_url)
|
||||
openai_async_clients[client_key] = client
|
||||
|
||||
model_kwargs: dict = {}
|
||||
stream = not is_non_streaming_model(model_name, api_base_url)
|
||||
stream_processor = astream_thought_processor
|
||||
if stream:
|
||||
@@ -250,6 +253,8 @@ async def chat_completion_with_backoff(
|
||||
else:
|
||||
model_kwargs.pop("stream_options", None)
|
||||
|
||||
model_kwargs["top_p"] = model_kwargs.get("top_p", 0.95)
|
||||
|
||||
formatted_messages = format_message_for_api(messages, api_base_url)
|
||||
|
||||
# Configure thinking for openai reasoning models
|
||||
@@ -257,7 +262,9 @@ async def chat_completion_with_backoff(
|
||||
temperature = 1
|
||||
reasoning_effort = "medium" if deepthought else "low"
|
||||
model_kwargs["reasoning_effort"] = reasoning_effort
|
||||
model_kwargs.pop("stop", None) # Remove unsupported stop param for reasoning models
|
||||
# Remove unsupported params for reasoning models
|
||||
model_kwargs.pop("top_p", None)
|
||||
model_kwargs.pop("stop", None)
|
||||
|
||||
# Get the first system message and add the string `Formatting re-enabled` to it.
|
||||
# See https://platform.openai.com/docs/guides/reasoning-best-practices
|
||||
@@ -304,8 +311,6 @@ async def chat_completion_with_backoff(
|
||||
read_timeout = 300 if is_local_api(api_base_url) else 60
|
||||
if os.getenv("KHOJ_LLM_SEED"):
|
||||
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
|
||||
if tools:
|
||||
model_kwargs["tools"] = tools
|
||||
|
||||
aggregated_response = ""
|
||||
final_chunk = None
|
||||
|
||||
Reference in New Issue
Block a user