Create async get anthropic, openai client funcs, move to reusable package

This package is where the get openai client functions also reside.
This commit is contained in:
Debanjum
2025-03-24 15:14:19 +05:30
parent 973aded6c5
commit c93c0d982e
3 changed files with 58 additions and 26 deletions

View File

@@ -19,7 +19,7 @@ from khoj.processor.conversation.utils import (
get_image_from_url,
)
from khoj.utils.helpers import (
get_ai_api_info,
get_anthropic_client,
get_chat_usage_metrics,
is_none_or_empty,
is_promptrace_enabled,
@@ -33,19 +33,6 @@ DEFAULT_MAX_TOKENS_ANTHROPIC = 8000
MAX_REASONING_TOKENS_ANTHROPIC = 12000
def get_anthropic_client(api_key, api_base_url=None) -> anthropic.Anthropic | anthropic.AnthropicVertex:
api_info = get_ai_api_info(api_key, api_base_url)
if api_info.api_key:
client = anthropic.Anthropic(api_key=api_info.api_key)
else:
client = anthropic.AnthropicVertex(
region=api_info.region,
project_id=api_info.project,
credentials=api_info.credentials,
)
return client
@retry(
wait=wait_random_exponential(min=1, max=10),
stop=stop_after_attempt(2),

View File

@@ -25,8 +25,8 @@ from khoj.processor.conversation.utils import (
get_image_from_url,
)
from khoj.utils.helpers import (
get_ai_api_info,
get_chat_usage_metrics,
get_gemini_client,
is_none_or_empty,
is_promptrace_enabled,
)
@@ -62,17 +62,6 @@ SAFETY_SETTINGS = [
]
def get_gemini_client(api_key, api_base_url=None) -> genai.Client:
api_info = get_ai_api_info(api_key, api_base_url)
return genai.Client(
location=api_info.region,
project=api_info.project,
credentials=api_info.credentials,
api_key=api_info.api_key,
vertexai=api_info.api_key is None,
)
@retry(
wait=wait_random_exponential(min=1, max=10),
stop=stop_after_attempt(2),

View File

@@ -23,6 +23,7 @@ from time import perf_counter
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple, Union
from urllib.parse import ParseResult, urlparse
import anthropic
import openai
import psutil
import pyjson5
@@ -30,6 +31,7 @@ import requests
import torch
from asgiref.sync import sync_to_async
from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email
from google import genai
from google.auth.credentials import Credentials
from google.oauth2 import service_account
from magika import Magika
@@ -729,6 +731,60 @@ def get_openai_client(api_key: str, api_base_url: str) -> Union[openai.OpenAI, o
return client
def get_openai_async_client(api_key: str, api_base_url: str) -> Union[openai.AsyncOpenAI, openai.AsyncAzureOpenAI]:
"""Get OpenAI or AzureOpenAI client based on the API Base URL"""
parsed_url = urlparse(api_base_url)
if parsed_url.hostname and parsed_url.hostname.endswith(".openai.azure.com"):
client = openai.AsyncAzureOpenAI(
api_key=api_key,
azure_endpoint=api_base_url,
api_version="2024-10-21",
)
else:
client = openai.AsyncOpenAI(
api_key=api_key,
base_url=api_base_url,
)
return client
def get_anthropic_client(api_key, api_base_url=None) -> anthropic.Anthropic | anthropic.AnthropicVertex:
api_info = get_ai_api_info(api_key, api_base_url)
if api_info.api_key:
client = anthropic.Anthropic(api_key=api_info.api_key)
else:
client = anthropic.AnthropicVertex(
region=api_info.region,
project_id=api_info.project,
credentials=api_info.credentials,
)
return client
def get_anthropic_async_client(api_key, api_base_url=None) -> anthropic.AsyncAnthropic | anthropic.AsyncAnthropicVertex:
api_info = get_ai_api_info(api_key, api_base_url)
if api_info.api_key:
client = anthropic.AsyncAnthropic(api_key=api_info.api_key)
else:
client = anthropic.AsyncAnthropicVertex(
region=api_info.region,
project_id=api_info.project,
credentials=api_info.credentials,
)
return client
def get_gemini_client(api_key, api_base_url=None) -> genai.Client:
api_info = get_ai_api_info(api_key, api_base_url)
return genai.Client(
location=api_info.region,
project=api_info.project,
credentials=api_info.credentials,
api_key=api_info.api_key,
vertexai=api_info.api_key is None,
)
def normalize_email(email: str, check_deliverability=False) -> tuple[str, bool]:
"""Normalize, validate and check deliverability of email address"""
lower_email = email.lower()