diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py
index eea83c45..cc8ec027 100644
--- a/src/khoj/processor/conversation/openai/gpt.py
+++ b/src/khoj/processor/conversation/openai/gpt.py
@@ -124,6 +124,7 @@ def send_message_to_model(
     model,
     response_type="text",
     response_schema=None,
+    deepthought=False,
     api_base_url=None,
     tracer: dict = {},
 ):
@@ -144,6 +145,7 @@ def send_message_to_model(
         model_name=model,
         openai_api_key=api_key,
         api_base_url=api_base_url,
+        deepthought=deepthought,
         model_kwargs=model_kwargs,
         tracer=tracer,
     )
@@ -172,6 +174,7 @@ def converse_openai(
     generated_files: List[FileAttachment] = None,
     generated_asset_results: Dict[str, Dict] = {},
     program_execution_context: List[str] = None,
+    deepthought: Optional[bool] = False,
     tracer: dict = {},
 ):
     """
@@ -250,6 +253,7 @@ def converse_openai(
         openai_api_key=api_key,
         api_base_url=api_base_url,
         completion_func=completion_func,
+        deepthought=deepthought,
         model_kwargs={"stop": ["Notes:\n["]},
         tracer=tracer,
     )
diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py
index 7d1f114d..50e2b73e 100644
--- a/src/khoj/processor/conversation/openai/utils.py
+++ b/src/khoj/processor/conversation/openai/utils.py
@@ -51,6 +51,7 @@ def completion_with_backoff(
     temperature=0.8,
     openai_api_key=None,
     api_base_url=None,
+    deepthought: bool = False,
     model_kwargs: dict = {},
     tracer: dict = {},
 ) -> str:
@@ -128,13 +129,14 @@ def chat_completion_with_backoff(
     openai_api_key=None,
     api_base_url=None,
     completion_func=None,
+    deepthought=False,
     model_kwargs=None,
     tracer: dict = {},
 ):
     g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
     t = Thread(
         target=llm_thread,
-        args=(g, messages, model_name, temperature, openai_api_key, api_base_url, model_kwargs, tracer),
+        args=(g, messages, model_name, temperature, openai_api_key, api_base_url, deepthought, model_kwargs, tracer),
     )
     t.start()
     return g
@@ -147,6 +149,7 @@ def llm_thread(
     temperature,
     openai_api_key=None,
     api_base_url=None,
+    deepthought=False,
     model_kwargs: dict = {},
     tracer: dict = {},
 ):
@@ -160,10 +163,11 @@ def llm_thread(
     formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
 
     # Tune reasoning models arguments
-    if model_name.startswith("o1"):
-        temperature = 1
-    elif model_name.startswith("o3"):
+    if model_name.startswith("o1") or model_name.startswith("o3"):
         temperature = 1
+        model_kwargs["reasoning_effort"] = "medium"
+
+    if model_name.startswith("o3"):
         # Get the first system message and add the string `Formatting re-enabled` to it.
         # See https://platform.openai.com/docs/guides/reasoning-best-practices
         if len(formatted_messages) > 0:
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index cf7dd582..c7e71b09 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -1215,6 +1215,7 @@ async def send_message_to_model_wrapper(
             model=chat_model_name,
             response_type=response_type,
             response_schema=response_schema,
+            deepthought=deepthought,
             api_base_url=api_base_url,
             tracer=tracer,
         )
@@ -1511,6 +1512,7 @@ def generate_chat_response(
                 generated_files=raw_generated_files,
                 generated_asset_results=generated_asset_results,
                 program_execution_context=program_execution_context,
+                deepthought=deepthought,
                 tracer=tracer,
             )
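Note: the hunks above only thread the `deepthought` flag through `send_message_to_model_wrapper`, `generate_chat_response`, `converse_openai`, and down to `llm_thread`, and set `reasoning_effort` to `"medium"` for o1/o3 models; the part of the change that consumes `deepthought` is not visible here. Below is a minimal, hypothetical sketch of how such a flag could drive the OpenAI `reasoning_effort` request parameter. The helper name `pick_reasoning_kwargs` and the medium/high mapping are assumptions for illustration, not code from this patch.

```python
# Hypothetical sketch, not Khoj code: map a deepthought flag onto OpenAI
# reasoning-model arguments, mirroring the o1/o3 tuning shown in llm_thread.
from openai import OpenAI


def pick_reasoning_kwargs(model_name: str, deepthought: bool) -> dict:
    """Return extra kwargs for reasoning models; empty dict for standard chat models."""
    if model_name.startswith("o1") or model_name.startswith("o3"):
        # Assumption: deepthought requests higher reasoning effort.
        return {"reasoning_effort": "high" if deepthought else "medium"}
    return {}


def chat(model_name: str, messages: list, deepthought: bool = False,
         api_key: str | None = None, api_base_url: str | None = None):
    client = OpenAI(api_key=api_key, base_url=api_base_url)
    extra_kwargs = pick_reasoning_kwargs(model_name, deepthought)
    # Reasoning models require temperature=1; other models keep the default used in khoj.
    temperature = 1 if extra_kwargs else 0.8
    return client.chat.completions.create(
        model=model_name,
        messages=messages,
        temperature=temperature,
        **extra_kwargs,
    )
```

Passing the extra arguments through a kwargs dict keeps the call compatible with models that reject `reasoning_effort`, which matches how the patch stores it in `model_kwargs` rather than hard-coding it at the call site.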