mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 05:39:12 +00:00
Support Azure OpenAI API endpoint (#1048)
OpenAI chat models deployed on Azure are (ironically) not OpenAI API compatible endpoints. This change enables using OpenAI chat models deployed on Azure with Khoj.
This commit is contained in:
@@ -19,7 +19,11 @@ from khoj.processor.conversation.utils import (
|
|||||||
ThreadedGenerator,
|
ThreadedGenerator,
|
||||||
commit_conversation_trace,
|
commit_conversation_trace,
|
||||||
)
|
)
|
||||||
from khoj.utils.helpers import get_chat_usage_metrics, is_promptrace_enabled
|
from khoj.utils.helpers import (
|
||||||
|
get_chat_usage_metrics,
|
||||||
|
get_openai_client,
|
||||||
|
is_promptrace_enabled,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -51,10 +55,7 @@ def completion_with_backoff(
|
|||||||
client_key = f"{openai_api_key}--{api_base_url}"
|
client_key = f"{openai_api_key}--{api_base_url}"
|
||||||
client: openai.OpenAI | None = openai_clients.get(client_key)
|
client: openai.OpenAI | None = openai_clients.get(client_key)
|
||||||
if not client:
|
if not client:
|
||||||
client = openai.OpenAI(
|
client = get_openai_client(openai_api_key, api_base_url)
|
||||||
api_key=openai_api_key,
|
|
||||||
base_url=api_base_url,
|
|
||||||
)
|
|
||||||
openai_clients[client_key] = client
|
openai_clients[client_key] = client
|
||||||
|
|
||||||
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
||||||
@@ -161,14 +162,11 @@ def llm_thread(
|
|||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
client_key = f"{openai_api_key}--{api_base_url}"
|
client_key = f"{openai_api_key}--{api_base_url}"
|
||||||
if client_key not in openai_clients:
|
if client_key in openai_clients:
|
||||||
client = openai.OpenAI(
|
|
||||||
api_key=openai_api_key,
|
|
||||||
base_url=api_base_url,
|
|
||||||
)
|
|
||||||
openai_clients[client_key] = client
|
|
||||||
else:
|
|
||||||
client = openai_clients[client_key]
|
client = openai_clients[client_key]
|
||||||
|
else:
|
||||||
|
client = get_openai_client(openai_api_key, api_base_url)
|
||||||
|
openai_clients[client_key] = client
|
||||||
|
|
||||||
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from time import perf_counter
|
|||||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import openai
|
||||||
import psutil
|
import psutil
|
||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
@@ -596,3 +597,20 @@ def get_chat_usage_metrics(
|
|||||||
"output_tokens": prev_usage["output_tokens"] + output_tokens,
|
"output_tokens": prev_usage["output_tokens"] + output_tokens,
|
||||||
"cost": cost or get_cost_of_chat_message(model_name, input_tokens, output_tokens, prev_cost=prev_usage["cost"]),
|
"cost": cost or get_cost_of_chat_message(model_name, input_tokens, output_tokens, prev_cost=prev_usage["cost"]),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_openai_client(
    api_key: str, api_base_url: Optional[str], api_version: str = "2024-10-21"
) -> Union[openai.OpenAI, openai.AzureOpenAI]:
    """Get an OpenAI or AzureOpenAI client based on the API base URL.

    An ``openai.AzureOpenAI`` client is returned when the hostname of
    ``api_base_url`` ends with ``.openai.azure.com``; otherwise a regular
    ``openai.OpenAI`` client is returned.

    Args:
        api_key: API key handed to whichever client is constructed.
        api_base_url: Base URL of the chat model endpoint. A falsy value
            (``None`` or ``""``) never matches Azure and is passed through
            unchanged as ``base_url`` of the regular client.
        api_version: Azure OpenAI API version used for Azure endpoints only.
            Defaults to the version this function previously hard-coded.

    Returns:
        The client instance matching the endpoint type.
    """
    # Skip URL parsing for a missing base URL; there is no hostname to inspect.
    hostname = urlparse(api_base_url).hostname if api_base_url else None

    # Azure deployments need their dedicated client, selected by hostname suffix.
    if hostname and hostname.endswith(".openai.azure.com"):
        return openai.AzureOpenAI(
            api_key=api_key,
            azure_endpoint=api_base_url,
            api_version=api_version,
        )

    return openai.OpenAI(
        api_key=api_key,
        base_url=api_base_url,
    )
|
||||||
|
|||||||
Reference in New Issue
Block a user