diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py index 6c2ffb8a..48c6515f 100644 --- a/src/khoj/processor/conversation/anthropic/utils.py +++ b/src/khoj/processor/conversation/anthropic/utils.py @@ -19,7 +19,7 @@ from khoj.processor.conversation.utils import ( get_image_from_url, ) from khoj.utils.helpers import ( - get_ai_api_info, + get_anthropic_client, get_chat_usage_metrics, is_none_or_empty, is_promptrace_enabled, @@ -33,19 +33,6 @@ DEFAULT_MAX_TOKENS_ANTHROPIC = 8000 MAX_REASONING_TOKENS_ANTHROPIC = 12000 -def get_anthropic_client(api_key, api_base_url=None) -> anthropic.Anthropic | anthropic.AnthropicVertex: - api_info = get_ai_api_info(api_key, api_base_url) - if api_info.api_key: - client = anthropic.Anthropic(api_key=api_info.api_key) - else: - client = anthropic.AnthropicVertex( - region=api_info.region, - project_id=api_info.project, - credentials=api_info.credentials, - ) - return client - - @retry( wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(2), diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py index 9a8b4132..b497edec 100644 --- a/src/khoj/processor/conversation/google/utils.py +++ b/src/khoj/processor/conversation/google/utils.py @@ -25,8 +25,8 @@ from khoj.processor.conversation.utils import ( get_image_from_url, ) from khoj.utils.helpers import ( - get_ai_api_info, get_chat_usage_metrics, + get_gemini_client, is_none_or_empty, is_promptrace_enabled, ) @@ -62,17 +62,6 @@ SAFETY_SETTINGS = [ ] -def get_gemini_client(api_key, api_base_url=None) -> genai.Client: - api_info = get_ai_api_info(api_key, api_base_url) - return genai.Client( - location=api_info.region, - project=api_info.project, - credentials=api_info.credentials, - api_key=api_info.api_key, - vertexai=api_info.api_key is None, - ) - - @retry( wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(2), diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index f0aa0cd6..4a756dcb 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -23,6 +23,7 @@ from time import perf_counter from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple, Union from urllib.parse import ParseResult, urlparse +import anthropic import openai import psutil import pyjson5 @@ -30,6 +31,7 @@ import requests import torch from asgiref.sync import sync_to_async from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email +from google import genai from google.auth.credentials import Credentials from google.oauth2 import service_account from magika import Magika @@ -729,6 +731,60 @@ def get_openai_client(api_key: str, api_base_url: str) -> Union[openai.OpenAI, o return client +def get_openai_async_client(api_key: str, api_base_url: str) -> Union[openai.AsyncOpenAI, openai.AsyncAzureOpenAI]: + """Get OpenAI or AzureOpenAI client based on the API Base URL""" + parsed_url = urlparse(api_base_url) + if parsed_url.hostname and parsed_url.hostname.endswith(".openai.azure.com"): + client = openai.AsyncAzureOpenAI( + api_key=api_key, + azure_endpoint=api_base_url, + api_version="2024-10-21", + ) + else: + client = openai.AsyncOpenAI( + api_key=api_key, + base_url=api_base_url, + ) + return client + + +def get_anthropic_client(api_key, api_base_url=None) -> anthropic.Anthropic | anthropic.AnthropicVertex: + api_info = get_ai_api_info(api_key, api_base_url) + if api_info.api_key: + client = anthropic.Anthropic(api_key=api_info.api_key) + else: + client = anthropic.AnthropicVertex( + region=api_info.region, + project_id=api_info.project, + credentials=api_info.credentials, + ) + return client + + +def get_anthropic_async_client(api_key, api_base_url=None) -> anthropic.AsyncAnthropic | anthropic.AsyncAnthropicVertex: + api_info = get_ai_api_info(api_key, api_base_url) + if api_info.api_key: + client = anthropic.AsyncAnthropic(api_key=api_info.api_key) + else: + client = anthropic.AsyncAnthropicVertex( + region=api_info.region, + project_id=api_info.project, + credentials=api_info.credentials, + ) + return client + + +def get_gemini_client(api_key, api_base_url=None) -> genai.Client: + api_info = get_ai_api_info(api_key, api_base_url) + return genai.Client( + location=api_info.region, + project=api_info.project, + credentials=api_info.credentials, + api_key=api_info.api_key, + vertexai=api_info.api_key is None, + ) + + def normalize_email(email: str, check_deliverability=False) -> tuple[str, bool]: """Normalize, validate and check deliverability of email address""" lower_email = email.lower()