mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Create async get anthropic, openai client funcs, move to reusable package
This package is where the get openai client functions also reside.
This commit is contained in:
@@ -19,7 +19,7 @@ from khoj.processor.conversation.utils import (
|
||||
get_image_from_url,
|
||||
)
|
||||
from khoj.utils.helpers import (
|
||||
get_ai_api_info,
|
||||
get_anthropic_client,
|
||||
get_chat_usage_metrics,
|
||||
is_none_or_empty,
|
||||
is_promptrace_enabled,
|
||||
@@ -33,19 +33,6 @@ DEFAULT_MAX_TOKENS_ANTHROPIC = 8000
|
||||
MAX_REASONING_TOKENS_ANTHROPIC = 12000
|
||||
|
||||
|
||||
def get_anthropic_client(api_key, api_base_url=None) -> anthropic.Anthropic | anthropic.AnthropicVertex:
|
||||
api_info = get_ai_api_info(api_key, api_base_url)
|
||||
if api_info.api_key:
|
||||
client = anthropic.Anthropic(api_key=api_info.api_key)
|
||||
else:
|
||||
client = anthropic.AnthropicVertex(
|
||||
region=api_info.region,
|
||||
project_id=api_info.project,
|
||||
credentials=api_info.credentials,
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
@retry(
|
||||
wait=wait_random_exponential(min=1, max=10),
|
||||
stop=stop_after_attempt(2),
|
||||
|
||||
@@ -25,8 +25,8 @@ from khoj.processor.conversation.utils import (
|
||||
get_image_from_url,
|
||||
)
|
||||
from khoj.utils.helpers import (
|
||||
get_ai_api_info,
|
||||
get_chat_usage_metrics,
|
||||
get_gemini_client,
|
||||
is_none_or_empty,
|
||||
is_promptrace_enabled,
|
||||
)
|
||||
@@ -62,17 +62,6 @@ SAFETY_SETTINGS = [
|
||||
]
|
||||
|
||||
|
||||
def get_gemini_client(api_key, api_base_url=None) -> genai.Client:
|
||||
api_info = get_ai_api_info(api_key, api_base_url)
|
||||
return genai.Client(
|
||||
location=api_info.region,
|
||||
project=api_info.project,
|
||||
credentials=api_info.credentials,
|
||||
api_key=api_info.api_key,
|
||||
vertexai=api_info.api_key is None,
|
||||
)
|
||||
|
||||
|
||||
@retry(
|
||||
wait=wait_random_exponential(min=1, max=10),
|
||||
stop=stop_after_attempt(2),
|
||||
|
||||
@@ -23,6 +23,7 @@ from time import perf_counter
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple, Union
|
||||
from urllib.parse import ParseResult, urlparse
|
||||
|
||||
import anthropic
|
||||
import openai
|
||||
import psutil
|
||||
import pyjson5
|
||||
@@ -30,6 +31,7 @@ import requests
|
||||
import torch
|
||||
from asgiref.sync import sync_to_async
|
||||
from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email
|
||||
from google import genai
|
||||
from google.auth.credentials import Credentials
|
||||
from google.oauth2 import service_account
|
||||
from magika import Magika
|
||||
@@ -729,6 +731,60 @@ def get_openai_client(api_key: str, api_base_url: str) -> Union[openai.OpenAI, o
|
||||
return client
|
||||
|
||||
|
||||
def get_openai_async_client(api_key: str, api_base_url: str) -> Union[openai.AsyncOpenAI, openai.AsyncAzureOpenAI]:
|
||||
"""Get OpenAI or AzureOpenAI client based on the API Base URL"""
|
||||
parsed_url = urlparse(api_base_url)
|
||||
if parsed_url.hostname and parsed_url.hostname.endswith(".openai.azure.com"):
|
||||
client = openai.AsyncAzureOpenAI(
|
||||
api_key=api_key,
|
||||
azure_endpoint=api_base_url,
|
||||
api_version="2024-10-21",
|
||||
)
|
||||
else:
|
||||
client = openai.AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=api_base_url,
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
def get_anthropic_client(api_key, api_base_url=None) -> anthropic.Anthropic | anthropic.AnthropicVertex:
|
||||
api_info = get_ai_api_info(api_key, api_base_url)
|
||||
if api_info.api_key:
|
||||
client = anthropic.Anthropic(api_key=api_info.api_key)
|
||||
else:
|
||||
client = anthropic.AnthropicVertex(
|
||||
region=api_info.region,
|
||||
project_id=api_info.project,
|
||||
credentials=api_info.credentials,
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
def get_anthropic_async_client(api_key, api_base_url=None) -> anthropic.AsyncAnthropic | anthropic.AsyncAnthropicVertex:
|
||||
api_info = get_ai_api_info(api_key, api_base_url)
|
||||
if api_info.api_key:
|
||||
client = anthropic.AsyncAnthropic(api_key=api_info.api_key)
|
||||
else:
|
||||
client = anthropic.AsyncAnthropicVertex(
|
||||
region=api_info.region,
|
||||
project_id=api_info.project,
|
||||
credentials=api_info.credentials,
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
def get_gemini_client(api_key, api_base_url=None) -> genai.Client:
|
||||
api_info = get_ai_api_info(api_key, api_base_url)
|
||||
return genai.Client(
|
||||
location=api_info.region,
|
||||
project=api_info.project,
|
||||
credentials=api_info.credentials,
|
||||
api_key=api_info.api_key,
|
||||
vertexai=api_info.api_key is None,
|
||||
)
|
||||
|
||||
|
||||
def normalize_email(email: str, check_deliverability=False) -> tuple[str, bool]:
|
||||
"""Normalize, validate and check deliverability of email address"""
|
||||
lower_email = email.lower()
|
||||
|
||||
Reference in New Issue
Block a user