From d1df9586ca3f0ae6e7abbfc06e95b83494c18cdc Mon Sep 17 00:00:00 2001
From: Debanjum
Date: Sun, 23 Mar 2025 18:04:21 +0530
Subject: [PATCH] Standardize AI model response temperature across
 provider-specific ranges

- Anthropic expects a 0-1 temperature range; Gemini & OpenAI expect 0-2
- Anneal temperature to explore reasoning trajectories but respond factually
- Default send_message_to_model and extract_questions temperatures to the same value
---
 src/khoj/processor/conversation/anthropic/anthropic_chat.py | 4 +---
 src/khoj/processor/conversation/anthropic/utils.py          | 2 +-
 src/khoj/processor/conversation/google/gemini_chat.py       | 6 +-----
 src/khoj/processor/conversation/google/utils.py             | 2 +-
 src/khoj/processor/conversation/openai/gpt.py               | 6 +-----
 src/khoj/processor/conversation/openai/utils.py             | 2 +-
 6 files changed, 6 insertions(+), 16 deletions(-)

diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py
index 01de4b16..b6aad2c0 100644
--- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py
+++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py
@@ -35,7 +35,6 @@ def extract_questions_anthropic(
     conversation_log={},
     api_key=None,
     api_base_url=None,
-    temperature=0.7,
     location_data: LocationData = None,
     user: KhojUser = None,
     query_images: Optional[list[str]] = None,
@@ -101,7 +100,6 @@ def extract_questions_anthropic(
         messages=messages,
         system_prompt=system_prompt,
         model_name=model,
-        temperature=temperature,
         api_key=api_key,
         api_base_url=api_base_url,
         response_type="json_object",
@@ -242,7 +240,7 @@ def converse_anthropic(
         compiled_references=references,
         online_results=online_results,
         model_name=model,
-        temperature=0,
+        temperature=0.2,
         api_key=api_key,
         api_base_url=api_base_url,
         system_prompt=system_prompt,
diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py
index fac99e04..986724be 100644
--- a/src/khoj/processor/conversation/anthropic/utils.py
+++ b/src/khoj/processor/conversation/anthropic/utils.py
@@ -56,7 +56,7 @@ def anthropic_completion_with_backoff(
     messages,
     system_prompt,
     model_name: str,
-    temperature=0,
+    temperature=0.4,
     api_key=None,
     api_base_url: str = None,
     model_kwargs=None,
diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py
index f8df542f..6f518c04 100644
--- a/src/khoj/processor/conversation/google/gemini_chat.py
+++ b/src/khoj/processor/conversation/google/gemini_chat.py
@@ -35,7 +35,6 @@ def extract_questions_gemini(
     conversation_log={},
     api_key=None,
     api_base_url=None,
-    temperature=0.6,
     max_tokens=None,
     location_data: LocationData = None,
     user: KhojUser = None,
@@ -103,7 +102,6 @@ def extract_questions_gemini(
         model,
         api_base_url=api_base_url,
         response_type="json_object",
-        temperature=temperature,
         tracer=tracer,
     )

@@ -130,7 +128,6 @@ def gemini_send_message_to_model(
     api_base_url=None,
     response_type="text",
     response_schema=None,
-    temperature=0.6,
     model_kwargs=None,
     tracer={},
 ):
@@ -153,7 +150,6 @@ def gemini_send_message_to_model(
         model_name=model,
         api_key=api_key,
         api_base_url=api_base_url,
-        temperature=temperature,
         model_kwargs=model_kwargs,
         tracer=tracer,
     )
@@ -168,7 +164,7 @@ def converse_gemini(
     model: Optional[str] = "gemini-2.0-flash",
     api_key: Optional[str] = None,
     api_base_url: Optional[str] = None,
-    temperature: float = 0.6,
+    temperature: float = 0.4,
     completion_func=None,
     conversation_commands=[ConversationCommand.Default],
     max_prompt_size=None,
diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py
index b3bdd5a3..ff141c75 100644
--- a/src/khoj/processor/conversation/google/utils.py
+++ b/src/khoj/processor/conversation/google/utils.py
@@ -73,7 +73,7 @@ def get_gemini_client(api_key, api_base_url=None) -> genai.Client:
     reraise=True,
 )
 def gemini_completion_with_backoff(
-    messages, system_prompt, model_name, temperature=0, api_key=None, api_base_url=None, model_kwargs=None, tracer={}
+    messages, system_prompt, model_name, temperature=0.8, api_key=None, api_base_url=None, model_kwargs=None, tracer={}
 ) -> str:
     client = gemini_clients.get(api_key)
     if not client:
diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py
index f087fc93..eea83c45 100644
--- a/src/khoj/processor/conversation/openai/gpt.py
+++ b/src/khoj/processor/conversation/openai/gpt.py
@@ -63,7 +63,6 @@ def extract_questions(
     today = datetime.today()
     current_new_year = today.replace(month=1, day=1)
     last_new_year = current_new_year.replace(year=today.year - 1)
-    temperature = 0.7

     prompt = prompts.extract_questions.format(
         current_date=today.strftime("%Y-%m-%d"),
@@ -99,7 +98,6 @@ def extract_questions(
         model,
         response_type="json_object",
         api_base_url=api_base_url,
-        temperature=temperature,
         tracer=tracer,
     )

@@ -127,7 +125,6 @@ def send_message_to_model(
     response_type="text",
     response_schema=None,
     api_base_url=None,
-    temperature=0,
     tracer: dict = {},
 ):
     """
@@ -146,7 +143,6 @@ def send_message_to_model(
         messages=messages,
         model_name=model,
         openai_api_key=api_key,
-        temperature=temperature,
         api_base_url=api_base_url,
         model_kwargs=model_kwargs,
         tracer=tracer,
@@ -162,7 +158,7 @@ def converse_openai(
     model: str = "gpt-4o-mini",
     api_key: Optional[str] = None,
     api_base_url: Optional[str] = None,
-    temperature: float = 0.2,
+    temperature: float = 0.4,
     completion_func=None,
     conversation_commands=[ConversationCommand.Default],
     max_prompt_size=None,
diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py
index c664d882..3037270e 100644
--- a/src/khoj/processor/conversation/openai/utils.py
+++ b/src/khoj/processor/conversation/openai/utils.py
@@ -48,7 +48,7 @@ openai_clients: Dict[str, openai.OpenAI] = {}
 def completion_with_backoff(
     messages,
     model_name: str,
-    temperature=0,
+    temperature=0.8,
     openai_api_key=None,
     api_base_url=None,
     model_kwargs: dict = {},
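
Note on how the new defaults line up: since Anthropic accepts temperature in [0, 1] while Gemini and OpenAI accept [0, 2], the per-provider defaults above encode the same normalized value twice over (0.2 for the converse_* functions, 0.4 for the *_completion_with_backoff helpers). A minimal sketch of that mapping follows; the scale_temperature helper and PROVIDER_TEMPERATURE_RANGES table are hypothetical names for illustration and are not part of this patch:

# Hypothetical helper, not part of this patch: map a normalized
# creativity value in [0, 1] onto each provider's temperature range.
PROVIDER_TEMPERATURE_RANGES = {
    "anthropic": (0.0, 1.0),  # Anthropic API rejects temperatures above 1
    "gemini": (0.0, 2.0),
    "openai": (0.0, 2.0),
}


def scale_temperature(provider: str, creativity: float) -> float:
    """Scale a normalized creativity value in [0, 1] to the provider's temperature range."""
    low, high = PROVIDER_TEMPERATURE_RANGES[provider]
    creativity = min(max(creativity, 0.0), 1.0)  # clamp out-of-range input
    return low + creativity * (high - low)


# The defaults in this patch correspond to two normalized settings:
# scale_temperature("anthropic", 0.2) == 0.2  # converse_anthropic
# scale_temperature("openai", 0.2)    == 0.4  # converse_openai, converse_gemini
# scale_temperature("anthropic", 0.4) == 0.4  # anthropic_completion_with_backoff
# scale_temperature("openai", 0.4)    == 0.8  # completion_with_backoff, gemini_completion_with_backoff

Centralizing this mapping would let callers pass one normalized value and keep the provider-specific scaling in a single place, rather than hard-coding a different literal default in each module.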