mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-08 05:39:13 +00:00
Create async get anthropic, openai client funcs, move to reusable package
This package is where the get openai client functions also reside.
This commit is contained in:
@@ -19,7 +19,7 @@ from khoj.processor.conversation.utils import (
|
|||||||
get_image_from_url,
|
get_image_from_url,
|
||||||
)
|
)
|
||||||
from khoj.utils.helpers import (
|
from khoj.utils.helpers import (
|
||||||
get_ai_api_info,
|
get_anthropic_client,
|
||||||
get_chat_usage_metrics,
|
get_chat_usage_metrics,
|
||||||
is_none_or_empty,
|
is_none_or_empty,
|
||||||
is_promptrace_enabled,
|
is_promptrace_enabled,
|
||||||
@@ -33,19 +33,6 @@ DEFAULT_MAX_TOKENS_ANTHROPIC = 8000
|
|||||||
MAX_REASONING_TOKENS_ANTHROPIC = 12000
|
MAX_REASONING_TOKENS_ANTHROPIC = 12000
|
||||||
|
|
||||||
|
|
||||||
def get_anthropic_client(api_key, api_base_url=None) -> anthropic.Anthropic | anthropic.AnthropicVertex:
|
|
||||||
api_info = get_ai_api_info(api_key, api_base_url)
|
|
||||||
if api_info.api_key:
|
|
||||||
client = anthropic.Anthropic(api_key=api_info.api_key)
|
|
||||||
else:
|
|
||||||
client = anthropic.AnthropicVertex(
|
|
||||||
region=api_info.region,
|
|
||||||
project_id=api_info.project,
|
|
||||||
credentials=api_info.credentials,
|
|
||||||
)
|
|
||||||
return client
|
|
||||||
|
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
wait=wait_random_exponential(min=1, max=10),
|
wait=wait_random_exponential(min=1, max=10),
|
||||||
stop=stop_after_attempt(2),
|
stop=stop_after_attempt(2),
|
||||||
|
|||||||
@@ -25,8 +25,8 @@ from khoj.processor.conversation.utils import (
|
|||||||
get_image_from_url,
|
get_image_from_url,
|
||||||
)
|
)
|
||||||
from khoj.utils.helpers import (
|
from khoj.utils.helpers import (
|
||||||
get_ai_api_info,
|
|
||||||
get_chat_usage_metrics,
|
get_chat_usage_metrics,
|
||||||
|
get_gemini_client,
|
||||||
is_none_or_empty,
|
is_none_or_empty,
|
||||||
is_promptrace_enabled,
|
is_promptrace_enabled,
|
||||||
)
|
)
|
||||||
@@ -62,17 +62,6 @@ SAFETY_SETTINGS = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def get_gemini_client(api_key, api_base_url=None) -> genai.Client:
|
|
||||||
api_info = get_ai_api_info(api_key, api_base_url)
|
|
||||||
return genai.Client(
|
|
||||||
location=api_info.region,
|
|
||||||
project=api_info.project,
|
|
||||||
credentials=api_info.credentials,
|
|
||||||
api_key=api_info.api_key,
|
|
||||||
vertexai=api_info.api_key is None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
wait=wait_random_exponential(min=1, max=10),
|
wait=wait_random_exponential(min=1, max=10),
|
||||||
stop=stop_after_attempt(2),
|
stop=stop_after_attempt(2),
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from time import perf_counter
|
|||||||
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple, Union
|
||||||
from urllib.parse import ParseResult, urlparse
|
from urllib.parse import ParseResult, urlparse
|
||||||
|
|
||||||
|
import anthropic
|
||||||
import openai
|
import openai
|
||||||
import psutil
|
import psutil
|
||||||
import pyjson5
|
import pyjson5
|
||||||
@@ -30,6 +31,7 @@ import requests
|
|||||||
import torch
|
import torch
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email
|
from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email
|
||||||
|
from google import genai
|
||||||
from google.auth.credentials import Credentials
|
from google.auth.credentials import Credentials
|
||||||
from google.oauth2 import service_account
|
from google.oauth2 import service_account
|
||||||
from magika import Magika
|
from magika import Magika
|
||||||
@@ -729,6 +731,60 @@ def get_openai_client(api_key: str, api_base_url: str) -> Union[openai.OpenAI, o
|
|||||||
return client
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
def get_openai_async_client(api_key: str, api_base_url: str) -> Union[openai.AsyncOpenAI, openai.AsyncAzureOpenAI]:
|
||||||
|
"""Get OpenAI or AzureOpenAI client based on the API Base URL"""
|
||||||
|
parsed_url = urlparse(api_base_url)
|
||||||
|
if parsed_url.hostname and parsed_url.hostname.endswith(".openai.azure.com"):
|
||||||
|
client = openai.AsyncAzureOpenAI(
|
||||||
|
api_key=api_key,
|
||||||
|
azure_endpoint=api_base_url,
|
||||||
|
api_version="2024-10-21",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
client = openai.AsyncOpenAI(
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=api_base_url,
|
||||||
|
)
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
def get_anthropic_client(api_key, api_base_url=None) -> anthropic.Anthropic | anthropic.AnthropicVertex:
|
||||||
|
api_info = get_ai_api_info(api_key, api_base_url)
|
||||||
|
if api_info.api_key:
|
||||||
|
client = anthropic.Anthropic(api_key=api_info.api_key)
|
||||||
|
else:
|
||||||
|
client = anthropic.AnthropicVertex(
|
||||||
|
region=api_info.region,
|
||||||
|
project_id=api_info.project,
|
||||||
|
credentials=api_info.credentials,
|
||||||
|
)
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
def get_anthropic_async_client(api_key, api_base_url=None) -> anthropic.AsyncAnthropic | anthropic.AsyncAnthropicVertex:
|
||||||
|
api_info = get_ai_api_info(api_key, api_base_url)
|
||||||
|
if api_info.api_key:
|
||||||
|
client = anthropic.AsyncAnthropic(api_key=api_info.api_key)
|
||||||
|
else:
|
||||||
|
client = anthropic.AsyncAnthropicVertex(
|
||||||
|
region=api_info.region,
|
||||||
|
project_id=api_info.project,
|
||||||
|
credentials=api_info.credentials,
|
||||||
|
)
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
def get_gemini_client(api_key, api_base_url=None) -> genai.Client:
|
||||||
|
api_info = get_ai_api_info(api_key, api_base_url)
|
||||||
|
return genai.Client(
|
||||||
|
location=api_info.region,
|
||||||
|
project=api_info.project,
|
||||||
|
credentials=api_info.credentials,
|
||||||
|
api_key=api_info.api_key,
|
||||||
|
vertexai=api_info.api_key is None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def normalize_email(email: str, check_deliverability=False) -> tuple[str, bool]:
|
def normalize_email(email: str, check_deliverability=False) -> tuple[str, bool]:
|
||||||
"""Normalize, validate and check deliverability of email address"""
|
"""Normalize, validate and check deliverability of email address"""
|
||||||
lower_email = email.lower()
|
lower_email = email.lower()
|
||||||
|
|||||||
Reference in New Issue
Block a user