Support Azure OpenAI API endpoint (#1048)

OpenAI chat models deployed on Azure are (ironically) not OpenAI API compatible endpoints.
This change enables using OpenAI chat models deployed on Azure with Khoj.
This commit is contained in:
Debanjum
2025-01-10 23:35:03 +07:00
committed by GitHub
parent bac90ad69d
commit 3cc6597b49
2 changed files with 28 additions and 12 deletions

View File

@@ -19,7 +19,11 @@ from khoj.processor.conversation.utils import (
ThreadedGenerator,
commit_conversation_trace,
)
from khoj.utils.helpers import get_chat_usage_metrics, is_promptrace_enabled
from khoj.utils.helpers import (
get_chat_usage_metrics,
get_openai_client,
is_promptrace_enabled,
)
logger = logging.getLogger(__name__)
@@ -51,10 +55,7 @@ def completion_with_backoff(
client_key = f"{openai_api_key}--{api_base_url}"
client: openai.OpenAI | None = openai_clients.get(client_key)
if not client:
client = openai.OpenAI(
api_key=openai_api_key,
base_url=api_base_url,
)
client = get_openai_client(openai_api_key, api_base_url)
openai_clients[client_key] = client
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
@@ -161,14 +162,11 @@ def llm_thread(
):
try:
client_key = f"{openai_api_key}--{api_base_url}"
if client_key not in openai_clients:
client = openai.OpenAI(
api_key=openai_api_key,
base_url=api_base_url,
)
openai_clients[client_key] = client
else:
if client_key in openai_clients:
client = openai_clients[client_key]
else:
client = get_openai_client(openai_api_key, api_base_url)
openai_clients[client_key] = client
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]

View File

@@ -22,6 +22,7 @@ from time import perf_counter
from typing import TYPE_CHECKING, Any, Optional, Union
from urllib.parse import urlparse
import openai
import psutil
import requests
import torch
@@ -596,3 +597,20 @@ def get_chat_usage_metrics(
"output_tokens": prev_usage["output_tokens"] + output_tokens,
"cost": cost or get_cost_of_chat_message(model_name, input_tokens, output_tokens, prev_cost=prev_usage["cost"]),
}
def get_openai_client(api_key: str, api_base_url: str) -> Union[openai.OpenAI, openai.AzureOpenAI]:
"""Get OpenAI or AzureOpenAI client based on the API Base URL"""
parsed_url = urlparse(api_base_url)
if parsed_url.hostname and parsed_url.hostname.endswith(".openai.azure.com"):
client = openai.AzureOpenAI(
api_key=api_key,
azure_endpoint=api_base_url,
api_version="2024-10-21",
)
else:
client = openai.OpenAI(
api_key=api_key,
base_url=api_base_url,
)
return client