mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 13:25:11 +00:00
Add support for the newly released OpenAI O1 model series for preview
The O1 series currently doesn't seem to support streaming, stop words, temperature or response_format.
This commit is contained in:
@@ -45,15 +45,28 @@ def completion_with_backoff(
|
|||||||
openai_clients[client_key] = client
|
openai_clients[client_key] = client
|
||||||
|
|
||||||
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
||||||
|
stream = True
|
||||||
|
|
||||||
|
# Update request parameters for compatibility with o1 model series
|
||||||
|
# Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations
|
||||||
|
if model.startswith("o1"):
|
||||||
|
stream = False
|
||||||
|
temperature = 1
|
||||||
|
model_kwargs.pop("stop", None)
|
||||||
|
model_kwargs.pop("response_format", None)
|
||||||
|
|
||||||
chat = client.chat.completions.create(
|
chat = client.chat.completions.create(
|
||||||
stream=True,
|
stream=stream,
|
||||||
messages=formatted_messages, # type: ignore
|
messages=formatted_messages, # type: ignore
|
||||||
model=model, # type: ignore
|
model=model, # type: ignore
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
timeout=20,
|
timeout=20,
|
||||||
**(model_kwargs or dict()),
|
**(model_kwargs or dict()),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not stream:
|
||||||
|
return chat.choices[0].message.content
|
||||||
|
|
||||||
aggregated_response = ""
|
aggregated_response = ""
|
||||||
for chunk in chat:
|
for chunk in chat:
|
||||||
if len(chunk.choices) == 0:
|
if len(chunk.choices) == 0:
|
||||||
@@ -112,9 +125,18 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_ba
|
|||||||
client: openai.OpenAI = openai_clients[client_key]
|
client: openai.OpenAI = openai_clients[client_key]
|
||||||
|
|
||||||
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
||||||
|
stream = True
|
||||||
|
|
||||||
|
# Update request parameters for compatibility with o1 model series
|
||||||
|
# Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations
|
||||||
|
if model_name.startswith("o1"):
|
||||||
|
stream = False
|
||||||
|
temperature = 1
|
||||||
|
model_kwargs.pop("stop", None)
|
||||||
|
model_kwargs.pop("response_format", None)
|
||||||
|
|
||||||
chat = client.chat.completions.create(
|
chat = client.chat.completions.create(
|
||||||
stream=True,
|
stream=stream,
|
||||||
messages=formatted_messages,
|
messages=formatted_messages,
|
||||||
model=model_name, # type: ignore
|
model=model_name, # type: ignore
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
@@ -122,6 +144,9 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_ba
|
|||||||
**(model_kwargs or dict()),
|
**(model_kwargs or dict()),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not stream:
|
||||||
|
g.send(chat.choices[0].message.content)
|
||||||
|
else:
|
||||||
for chunk in chat:
|
for chunk in chat:
|
||||||
if len(chunk.choices) == 0:
|
if len(chunk.choices) == 0:
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import queue
|
import queue
|
||||||
@@ -24,6 +23,8 @@ model_to_prompt_size = {
|
|||||||
"gpt-4-0125-preview": 20000,
|
"gpt-4-0125-preview": 20000,
|
||||||
"gpt-4-turbo-preview": 20000,
|
"gpt-4-turbo-preview": 20000,
|
||||||
"gpt-4o-mini": 20000,
|
"gpt-4o-mini": 20000,
|
||||||
|
"o1-preview": 20000,
|
||||||
|
"o1-mini": 20000,
|
||||||
"TheBloke/Mistral-7B-Instruct-v0.2-GGUF": 3500,
|
"TheBloke/Mistral-7B-Instruct-v0.2-GGUF": 3500,
|
||||||
"NousResearch/Hermes-2-Pro-Mistral-7B-GGUF": 3500,
|
"NousResearch/Hermes-2-Pro-Mistral-7B-GGUF": 3500,
|
||||||
"bartowski/Meta-Llama-3.1-8B-Instruct-GGUF": 20000,
|
"bartowski/Meta-Llama-3.1-8B-Instruct-GGUF": 20000,
|
||||||
@@ -220,8 +221,9 @@ def truncate_messages(
|
|||||||
try:
|
try:
|
||||||
if loaded_model:
|
if loaded_model:
|
||||||
encoder = loaded_model.tokenizer()
|
encoder = loaded_model.tokenizer()
|
||||||
elif model_name.startswith("gpt-"):
|
elif model_name.startswith("gpt-") or model_name.startswith("o1"):
|
||||||
encoder = tiktoken.encoding_for_model(model_name)
|
# as tiktoken doesn't recognize o1 model series yet
|
||||||
|
encoder = tiktoken.encoding_for_model("gpt-4o" if model_name.startswith("o1") else model_name)
|
||||||
elif tokenizer_name:
|
elif tokenizer_name:
|
||||||
if tokenizer_name in state.pretrained_tokenizers:
|
if tokenizer_name in state.pretrained_tokenizers:
|
||||||
encoder = state.pretrained_tokenizers[tokenizer_name]
|
encoder = state.pretrained_tokenizers[tokenizer_name]
|
||||||
@@ -278,7 +280,9 @@ def truncate_messages(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if system_message:
|
if system_message:
|
||||||
system_message.role = "user" if "gemma-2" in model_name else "system"
|
# Default system message role is system.
|
||||||
|
# Fall back to the user role for models that do not support the system role, like gemma-2 and OpenAI's o1 model series.
|
||||||
|
system_message.role = "user" if "gemma-2" in model_name or model_name.startswith("o1") else "system"
|
||||||
return messages + [system_message] if system_message else messages
|
return messages + [system_message] if system_message else messages
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user