Create async get anthropic, openai client funcs, move to reusable package

This package is where the get openai client functions also reside.
This commit is contained in:
Debanjum
2025-03-24 15:14:19 +05:30
parent 973aded6c5
commit c93c0d982e
3 changed files with 58 additions and 26 deletions

View File

@@ -19,7 +19,7 @@ from khoj.processor.conversation.utils import (
get_image_from_url, get_image_from_url,
) )
from khoj.utils.helpers import ( from khoj.utils.helpers import (
get_ai_api_info, get_anthropic_client,
get_chat_usage_metrics, get_chat_usage_metrics,
is_none_or_empty, is_none_or_empty,
is_promptrace_enabled, is_promptrace_enabled,
@@ -33,19 +33,6 @@ DEFAULT_MAX_TOKENS_ANTHROPIC = 8000
MAX_REASONING_TOKENS_ANTHROPIC = 12000 MAX_REASONING_TOKENS_ANTHROPIC = 12000
def get_anthropic_client(api_key, api_base_url=None) -> anthropic.Anthropic | anthropic.AnthropicVertex:
api_info = get_ai_api_info(api_key, api_base_url)
if api_info.api_key:
client = anthropic.Anthropic(api_key=api_info.api_key)
else:
client = anthropic.AnthropicVertex(
region=api_info.region,
project_id=api_info.project,
credentials=api_info.credentials,
)
return client
@retry( @retry(
wait=wait_random_exponential(min=1, max=10), wait=wait_random_exponential(min=1, max=10),
stop=stop_after_attempt(2), stop=stop_after_attempt(2),

View File

@@ -25,8 +25,8 @@ from khoj.processor.conversation.utils import (
get_image_from_url, get_image_from_url,
) )
from khoj.utils.helpers import ( from khoj.utils.helpers import (
get_ai_api_info,
get_chat_usage_metrics, get_chat_usage_metrics,
get_gemini_client,
is_none_or_empty, is_none_or_empty,
is_promptrace_enabled, is_promptrace_enabled,
) )
@@ -62,17 +62,6 @@ SAFETY_SETTINGS = [
] ]
def get_gemini_client(api_key, api_base_url=None) -> genai.Client:
api_info = get_ai_api_info(api_key, api_base_url)
return genai.Client(
location=api_info.region,
project=api_info.project,
credentials=api_info.credentials,
api_key=api_info.api_key,
vertexai=api_info.api_key is None,
)
@retry( @retry(
wait=wait_random_exponential(min=1, max=10), wait=wait_random_exponential(min=1, max=10),
stop=stop_after_attempt(2), stop=stop_after_attempt(2),

View File

@@ -23,6 +23,7 @@ from time import perf_counter
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple, Union
from urllib.parse import ParseResult, urlparse from urllib.parse import ParseResult, urlparse
import anthropic
import openai import openai
import psutil import psutil
import pyjson5 import pyjson5
@@ -30,6 +31,7 @@ import requests
import torch import torch
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email
from google import genai
from google.auth.credentials import Credentials from google.auth.credentials import Credentials
from google.oauth2 import service_account from google.oauth2 import service_account
from magika import Magika from magika import Magika
@@ -729,6 +731,60 @@ def get_openai_client(api_key: str, api_base_url: str) -> Union[openai.OpenAI, o
return client return client
def get_openai_async_client(api_key: str, api_base_url: str) -> Union[openai.AsyncOpenAI, openai.AsyncAzureOpenAI]:
"""Get OpenAI or AzureOpenAI client based on the API Base URL"""
parsed_url = urlparse(api_base_url)
if parsed_url.hostname and parsed_url.hostname.endswith(".openai.azure.com"):
client = openai.AsyncAzureOpenAI(
api_key=api_key,
azure_endpoint=api_base_url,
api_version="2024-10-21",
)
else:
client = openai.AsyncOpenAI(
api_key=api_key,
base_url=api_base_url,
)
return client
def get_anthropic_client(api_key, api_base_url=None) -> anthropic.Anthropic | anthropic.AnthropicVertex:
api_info = get_ai_api_info(api_key, api_base_url)
if api_info.api_key:
client = anthropic.Anthropic(api_key=api_info.api_key)
else:
client = anthropic.AnthropicVertex(
region=api_info.region,
project_id=api_info.project,
credentials=api_info.credentials,
)
return client
def get_anthropic_async_client(api_key, api_base_url=None) -> anthropic.AsyncAnthropic | anthropic.AsyncAnthropicVertex:
api_info = get_ai_api_info(api_key, api_base_url)
if api_info.api_key:
client = anthropic.AsyncAnthropic(api_key=api_info.api_key)
else:
client = anthropic.AsyncAnthropicVertex(
region=api_info.region,
project_id=api_info.project,
credentials=api_info.credentials,
)
return client
def get_gemini_client(api_key, api_base_url=None) -> genai.Client:
api_info = get_ai_api_info(api_key, api_base_url)
return genai.Client(
location=api_info.region,
project=api_info.project,
credentials=api_info.credentials,
api_key=api_info.api_key,
vertexai=api_info.api_key is None,
)
def normalize_email(email: str, check_deliverability=False) -> tuple[str, bool]: def normalize_email(email: str, check_deliverability=False) -> tuple[str, bool]:
"""Normalize, validate and check deliverability of email address""" """Normalize, validate and check deliverability of email address"""
lower_email = email.lower() lower_email = email.lower()