mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Support Azure OpenAI API endpoint (#1048)
OpenAI chat models deployed on Azure are (ironically) not OpenAI-API-compatible endpoints. This change enables using OpenAI chat models deployed on Azure with Khoj.
This commit is contained in:
@@ -19,7 +19,11 @@ from khoj.processor.conversation.utils import (
|
||||
ThreadedGenerator,
|
||||
commit_conversation_trace,
|
||||
)
|
||||
from khoj.utils.helpers import get_chat_usage_metrics, is_promptrace_enabled
|
||||
from khoj.utils.helpers import (
|
||||
get_chat_usage_metrics,
|
||||
get_openai_client,
|
||||
is_promptrace_enabled,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -51,10 +55,7 @@ def completion_with_backoff(
|
||||
client_key = f"{openai_api_key}--{api_base_url}"
|
||||
client: openai.OpenAI | None = openai_clients.get(client_key)
|
||||
if not client:
|
||||
client = openai.OpenAI(
|
||||
api_key=openai_api_key,
|
||||
base_url=api_base_url,
|
||||
)
|
||||
client = get_openai_client(openai_api_key, api_base_url)
|
||||
openai_clients[client_key] = client
|
||||
|
||||
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
||||
@@ -161,14 +162,11 @@ def llm_thread(
|
||||
):
|
||||
try:
|
||||
client_key = f"{openai_api_key}--{api_base_url}"
|
||||
if client_key not in openai_clients:
|
||||
client = openai.OpenAI(
|
||||
api_key=openai_api_key,
|
||||
base_url=api_base_url,
|
||||
)
|
||||
openai_clients[client_key] = client
|
||||
else:
|
||||
if client_key in openai_clients:
|
||||
client = openai_clients[client_key]
|
||||
else:
|
||||
client = get_openai_client(openai_api_key, api_base_url)
|
||||
openai_clients[client_key] = client
|
||||
|
||||
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ from time import perf_counter
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import openai
|
||||
import psutil
|
||||
import requests
|
||||
import torch
|
||||
@@ -596,3 +597,20 @@ def get_chat_usage_metrics(
|
||||
"output_tokens": prev_usage["output_tokens"] + output_tokens,
|
||||
"cost": cost or get_cost_of_chat_message(model_name, input_tokens, output_tokens, prev_cost=prev_usage["cost"]),
|
||||
}
|
||||
|
||||
|
||||
def get_openai_client(api_key: str, api_base_url: str) -> Union[openai.OpenAI, openai.AzureOpenAI]:
    """Build the appropriate chat client for the given API endpoint.

    Returns an ``openai.AzureOpenAI`` client when the base URL's host is an
    Azure OpenAI deployment (ends in ``.openai.azure.com``); otherwise returns
    a plain ``openai.OpenAI`` client pointed at ``api_base_url``.
    """
    host = urlparse(api_base_url).hostname
    # Azure-hosted OpenAI deployments are not OpenAI-API-compatible endpoints,
    # so they need the dedicated Azure client with an explicit API version.
    if host and host.endswith(".openai.azure.com"):
        return openai.AzureOpenAI(
            api_key=api_key,
            azure_endpoint=api_base_url,
            # Pinned Azure OpenAI data-plane API version.
            api_version="2024-10-21",
        )
    return openai.OpenAI(api_key=api_key, base_url=api_base_url)
|
||||
|
||||
Reference in New Issue
Block a user