From 272eae5d66824cc243bd177e44273b4906c02ec4 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Thu, 12 Sep 2024 15:31:11 -0700 Subject: [PATCH] Add support for the newly released OpenAI O1 model series for preview The O1 series doesn't seem to support streaming, stop words or temperature, response_format currently. --- .../processor/conversation/openai/utils.py | 45 ++++++++++++++----- src/khoj/processor/conversation/utils.py | 12 +++-- 2 files changed, 43 insertions(+), 14 deletions(-) diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index 1a42113e..878dbb9c 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -45,15 +45,28 @@ def completion_with_backoff( openai_clients[client_key] = client formatted_messages = [{"role": message.role, "content": message.content} for message in messages] + stream = True + + # Update request parameters for compatability with o1 model series + # Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations + if model.startswith("o1"): + stream = False + temperature = 1 + model_kwargs.pop("stop", None) + model_kwargs.pop("response_format", None) chat = client.chat.completions.create( - stream=True, + stream=stream, messages=formatted_messages, # type: ignore model=model, # type: ignore temperature=temperature, timeout=20, **(model_kwargs or dict()), ) + + if not stream: + return chat.choices[0].message.content + aggregated_response = "" for chunk in chat: if len(chunk.choices) == 0: @@ -112,9 +125,18 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_ba client: openai.OpenAI = openai_clients[client_key] formatted_messages = [{"role": message.role, "content": message.content} for message in messages] + stream = True + + # Update request parameters for compatability with o1 model series + # Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations + if model_name.startswith("o1"): + stream = False + temperature = 1 + model_kwargs.pop("stop", None) + model_kwargs.pop("response_format", None) chat = client.chat.completions.create( - stream=True, + stream=stream, messages=formatted_messages, model=model_name, # type: ignore temperature=temperature, @@ -122,14 +144,17 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_ba **(model_kwargs or dict()), ) - for chunk in chat: - if len(chunk.choices) == 0: - continue - delta_chunk = chunk.choices[0].delta - if isinstance(delta_chunk, str): - g.send(delta_chunk) - elif delta_chunk.content: - g.send(delta_chunk.content) + if not stream: + g.send(chat.choices[0].message.content) + else: + for chunk in chat: + if len(chunk.choices) == 0: + continue + delta_chunk = chunk.choices[0].delta + if isinstance(delta_chunk, str): + g.send(delta_chunk) + elif delta_chunk.content: + g.send(delta_chunk.content) except Exception as e: logger.error(f"Error in llm_thread: {e}", exc_info=True) finally: diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 999473e2..6444b14d 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -1,4 +1,3 @@ -import json import logging import math import queue @@ -24,6 +23,8 @@ model_to_prompt_size = { "gpt-4-0125-preview": 20000, "gpt-4-turbo-preview": 20000, "gpt-4o-mini": 20000, + "o1-preview": 20000, + "o1-mini": 20000, "TheBloke/Mistral-7B-Instruct-v0.2-GGUF": 3500, "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF": 3500, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF": 20000, @@ -220,8 +221,9 @@ def truncate_messages( try: if loaded_model: encoder = loaded_model.tokenizer() - elif model_name.startswith("gpt-"): - encoder = tiktoken.encoding_for_model(model_name) + elif model_name.startswith("gpt-") or model_name.startswith("o1"): + # as tiktoken doesn't recognize o1 model series yet + encoder = tiktoken.encoding_for_model("gpt-4o" if model_name.startswith("o1") else model_name) elif tokenizer_name: if tokenizer_name in state.pretrained_tokenizers: encoder = state.pretrained_tokenizers[tokenizer_name] @@ -278,7 +280,9 @@ def truncate_messages( ) if system_message: - system_message.role = "user" if "gemma-2" in model_name else "system" + # Default system message role is system. + # Fallback to system message role of user for models that do not support this role like gemma-2 and openai's o1 model series. + system_message.role = "user" if "gemma-2" in model_name or model_name.startswith("o1") else "system" return messages + [system_message] if system_message else messages