Standardize AI model response temperature across provider-specific ranges
- Anthropic expects a 0-1 range. Gemini & OpenAI expect a 0-2 range
- Anneal temperature to explore reasoning trajectories but respond factually
- Default send_message_to_model and extract_question temps to the same
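For illustration only (not part of this commit), a minimal Python sketch of keeping a requested temperature within each provider's documented range; the PROVIDER_TEMPERATURE_RANGES table and clamp_temperature helper are hypothetical names introduced for this example, not khoj APIs:

```python
# Hypothetical helper illustrating the provider-specific temperature ranges
# referenced in the commit message; not part of the khoj codebase.
PROVIDER_TEMPERATURE_RANGES = {
    "anthropic": (0.0, 1.0),  # Anthropic expects temperatures in 0-1
    "gemini": (0.0, 2.0),     # Gemini accepts 0-2
    "openai": (0.0, 2.0),     # OpenAI accepts 0-2
}


def clamp_temperature(provider: str, temperature: float) -> float:
    """Clamp a requested temperature into the provider's supported range."""
    low, high = PROVIDER_TEMPERATURE_RANGES[provider]
    return max(low, min(high, temperature))


# Example: 1.5 is valid for OpenAI/Gemini but gets capped for Anthropic.
assert clamp_temperature("anthropic", 1.5) == 1.0
assert clamp_temperature("openai", 1.5) == 1.5
```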
@@ -35,7 +35,6 @@ def extract_questions_anthropic(
     conversation_log={},
     api_key=None,
     api_base_url=None,
-    temperature=0.7,
     location_data: LocationData = None,
     user: KhojUser = None,
     query_images: Optional[list[str]] = None,
@@ -101,7 +100,6 @@ def extract_questions_anthropic(
         messages=messages,
         system_prompt=system_prompt,
         model_name=model,
-        temperature=temperature,
         api_key=api_key,
         api_base_url=api_base_url,
         response_type="json_object",
@@ -242,7 +240,7 @@ def converse_anthropic(
         compiled_references=references,
         online_results=online_results,
         model_name=model,
-        temperature=0,
+        temperature=0.2,
         api_key=api_key,
         api_base_url=api_base_url,
         system_prompt=system_prompt,
@@ -56,7 +56,7 @@ def anthropic_completion_with_backoff(
     messages,
     system_prompt,
     model_name: str,
-    temperature=0,
+    temperature=0.4,
     api_key=None,
     api_base_url: str = None,
     model_kwargs=None,
@@ -35,7 +35,6 @@ def extract_questions_gemini(
     conversation_log={},
     api_key=None,
     api_base_url=None,
-    temperature=0.6,
     max_tokens=None,
     location_data: LocationData = None,
     user: KhojUser = None,
@@ -103,7 +102,6 @@ def extract_questions_gemini(
         model,
         api_base_url=api_base_url,
         response_type="json_object",
-        temperature=temperature,
         tracer=tracer,
     )
 
@@ -130,7 +128,6 @@ def gemini_send_message_to_model(
     api_base_url=None,
     response_type="text",
     response_schema=None,
-    temperature=0.6,
     model_kwargs=None,
     tracer={},
 ):
@@ -153,7 +150,6 @@ def gemini_send_message_to_model(
         model_name=model,
         api_key=api_key,
         api_base_url=api_base_url,
-        temperature=temperature,
         model_kwargs=model_kwargs,
         tracer=tracer,
     )
@@ -168,7 +164,7 @@ def converse_gemini(
     model: Optional[str] = "gemini-2.0-flash",
     api_key: Optional[str] = None,
     api_base_url: Optional[str] = None,
-    temperature: float = 0.6,
+    temperature: float = 0.4,
     completion_func=None,
     conversation_commands=[ConversationCommand.Default],
     max_prompt_size=None,
@@ -73,7 +73,7 @@ def get_gemini_client(api_key, api_base_url=None) -> genai.Client:
     reraise=True,
 )
 def gemini_completion_with_backoff(
-    messages, system_prompt, model_name, temperature=0, api_key=None, api_base_url=None, model_kwargs=None, tracer={}
+    messages, system_prompt, model_name, temperature=0.8, api_key=None, api_base_url=None, model_kwargs=None, tracer={}
 ) -> str:
     client = gemini_clients.get(api_key)
     if not client:
@@ -63,7 +63,6 @@ def extract_questions(
     today = datetime.today()
     current_new_year = today.replace(month=1, day=1)
     last_new_year = current_new_year.replace(year=today.year - 1)
-    temperature = 0.7
 
     prompt = prompts.extract_questions.format(
         current_date=today.strftime("%Y-%m-%d"),
@@ -99,7 +98,6 @@ def extract_questions(
         model,
         response_type="json_object",
         api_base_url=api_base_url,
-        temperature=temperature,
         tracer=tracer,
     )
 
@@ -127,7 +125,6 @@ def send_message_to_model(
     response_type="text",
     response_schema=None,
     api_base_url=None,
-    temperature=0,
     tracer: dict = {},
 ):
     """
@@ -146,7 +143,6 @@ def send_message_to_model(
         messages=messages,
         model_name=model,
         openai_api_key=api_key,
-        temperature=temperature,
         api_base_url=api_base_url,
         model_kwargs=model_kwargs,
         tracer=tracer,
@@ -162,7 +158,7 @@ def converse_openai(
     model: str = "gpt-4o-mini",
     api_key: Optional[str] = None,
     api_base_url: Optional[str] = None,
-    temperature: float = 0.2,
+    temperature: float = 0.4,
     completion_func=None,
     conversation_commands=[ConversationCommand.Default],
     max_prompt_size=None,
@@ -48,7 +48,7 @@ openai_clients: Dict[str, openai.OpenAI] = {}
 def completion_with_backoff(
     messages,
     model_name: str,
-    temperature=0,
+    temperature=0.8,
     openai_api_key=None,
     api_base_url=None,
     model_kwargs: dict = {},
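The changed defaults reflect the annealing idea from the commit message: intermediate calls (question extraction, send_message_to_model) now fall back to the higher wrapper defaults, while the final converse_* response runs at a lower temperature. A minimal sketch of that two-step pattern, using hypothetical names (call_model, answer) rather than the actual khoj call chain, with values borrowed from the new OpenAI/Gemini defaults:

```python
# Hypothetical illustration of the "anneal temperature" pattern from the
# commit message: explore reasoning at a higher temperature, then respond
# factually at a lower one. Names and flow are examples, not khoj APIs.
EXPLORE_TEMPERATURE = 0.8  # mirrors the new completion_with_backoff default
RESPOND_TEMPERATURE = 0.4  # mirrors the new converse_* default for OpenAI/Gemini


def call_model(prompt: str, temperature: float) -> str:
    # Stand-in for a provider call; echoes the prompt for demonstration.
    return f"[T={temperature}] {prompt}"


def answer(question: str) -> str:
    # Intermediate step: extract search questions at the higher temperature.
    plan = call_model(f"Extract search questions for: {question}", EXPLORE_TEMPERATURE)
    # Final step: compose the user-facing answer at the lower temperature.
    return call_model(f"Answer using plan: {plan}", RESPOND_TEMPERATURE)


print(answer("What changed in this commit?"))
```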