mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-05 21:29:11 +00:00
Access Claude and Gemini via GCP Vertex AI (#1134)
Support accessing Claude and Gemini AI models via Vertex AI on Google Cloud. See the documentation at docs.khoj.dev for setup details
This commit is contained in:
26
documentation/docs/advanced/gcp-vertex.md
Normal file
26
documentation/docs/advanced/gcp-vertex.md
Normal file
@@ -0,0 +1,26 @@
|
||||
# Google Vertex AI
|
||||
:::info
|
||||
This is only helpful for self-hosted users. If you're using [Khoj Cloud](https://app.khoj.dev), you can directly use any of the pre-configured AI models.
|
||||
:::
|
||||
|
||||
Khoj can use Google's Gemini and Anthropic's Claude family of AI models from [Vertex AI](https://cloud.google.com/vertex-ai) on Google Cloud. Explore Anthropic and Gemini AI models available on Vertex AI's [Model Garden](https://console.cloud.google.com/vertex-ai/model-garden).
|
||||
|
||||
## Setup
|
||||
1. Follow [these instructions](https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude#before_you_begin) to use models on GCP Vertex AI.
|
||||
2. Create [Service Account](https://console.cloud.google.com/apis/credentials/serviceaccountkey) credentials.
|
||||
- Download the credentials keyfile in json format.
|
||||
- Base64 encode the credentials json keyfile. For example by running the following command from your terminal:
|
||||
`base64 -i <service_account_credentials_keyfile.json>`
|
||||
3. Create a new [AI Model API](http://localhost:42110/server/admin/database/aimodelapi/add) on your Khoj admin panel.
|
||||
- **Name**: `Google Vertex` (or whatever friendly name you prefer).
|
||||
- **Api Key**: `base64 encoded json keyfile` from step 2.
|
||||
- **Api Base Url**: `https://{MODEL_GCP_REGION}-aiplatform.googleapis.com/v1/projects/{YOUR_GCP_PROJECT_ID}`
|
||||
- MODEL_GCP_REGION: A region the AI model is available in. For example `us-east5` works for [Claude](https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude#regions).
|
||||
- YOUR_GCP_PROJECT_ID: Get your project id from the [Google cloud dashboard](https://console.cloud.google.com/home/dashboard)
|
||||
4. Create a new [Chat Model](http://localhost:42110/server/admin/database/chatmodel/add) on your Khoj admin panel.
|
||||
- **Name**: `claude-3-7-sonnet@20250219`. Any Claude or Gemini model on Vertex's Model Garden should work.
|
||||
- **Model Type**: `Anthropic` or `Google`
|
||||
- **Ai Model API**: *the Google Vertex Ai Model API you created in step 3*
|
||||
- **Max prompt size**: `60000` (replace with the max prompt size of your model)
|
||||
- **Tokenizer**: *Do not set*
|
||||
5. Select the chat model on [your settings page](http://localhost:42110/settings) and start a conversation.
|
||||
@@ -0,0 +1,17 @@
|
||||
# Generated by Django 5.0.13 on 2025-03-23 04:42
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
    """Expand AiModelApi.api_key to 4000 characters.

    Base64-encoded GCP service account JSON keyfiles (used to access Claude and
    Gemini models via Vertex AI) are far larger than typical API keys, so the
    previous 200 character limit is insufficient.
    """

    dependencies = [
        ("database", "0086_alter_texttoimagemodelconfig_model_type"),
    ]

    operations = [
        migrations.AlterField(
            model_name="aimodelapi",
            name="api_key",
            # 4000 chars comfortably fits a base64-encoded service account keyfile.
            field=models.CharField(max_length=4000),
        ),
    ]
|
||||
@@ -188,7 +188,7 @@ class Subscription(DbBaseModel):
|
||||
|
||||
class AiModelApi(DbBaseModel):
|
||||
name = models.CharField(max_length=200)
|
||||
api_key = models.CharField(max_length=200)
|
||||
api_key = models.CharField(max_length=4000)
|
||||
api_base_url = models.URLField(max_length=200, default=None, blank=True, null=True)
|
||||
|
||||
def __str__(self):
|
||||
|
||||
@@ -34,6 +34,7 @@ def extract_questions_anthropic(
|
||||
model: Optional[str] = "claude-3-7-sonnet-latest",
|
||||
conversation_log={},
|
||||
api_key=None,
|
||||
api_base_url=None,
|
||||
temperature=0.7,
|
||||
location_data: LocationData = None,
|
||||
user: KhojUser = None,
|
||||
@@ -102,6 +103,7 @@ def extract_questions_anthropic(
|
||||
model_name=model,
|
||||
temperature=temperature,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
response_type="json_object",
|
||||
tracer=tracer,
|
||||
)
|
||||
@@ -122,7 +124,9 @@ def extract_questions_anthropic(
|
||||
return questions
|
||||
|
||||
|
||||
def anthropic_send_message_to_model(messages, api_key, model, response_type="text", deepthought=False, tracer={}):
|
||||
def anthropic_send_message_to_model(
|
||||
messages, api_key, api_base_url, model, response_type="text", deepthought=False, tracer={}
|
||||
):
|
||||
"""
|
||||
Send message to model
|
||||
"""
|
||||
@@ -134,6 +138,7 @@ def anthropic_send_message_to_model(messages, api_key, model, response_type="tex
|
||||
system_prompt=system_prompt,
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
response_type=response_type,
|
||||
deepthought=deepthought,
|
||||
tracer=tracer,
|
||||
@@ -148,6 +153,7 @@ def converse_anthropic(
|
||||
conversation_log={},
|
||||
model: Optional[str] = "claude-3-7-sonnet-latest",
|
||||
api_key: Optional[str] = None,
|
||||
api_base_url: Optional[str] = None,
|
||||
completion_func=None,
|
||||
conversation_commands=[ConversationCommand.Default],
|
||||
max_prompt_size=None,
|
||||
@@ -238,6 +244,7 @@ def converse_anthropic(
|
||||
model_name=model,
|
||||
temperature=0,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
system_prompt=system_prompt,
|
||||
completion_func=completion_func,
|
||||
max_prompt_size=max_prompt_size,
|
||||
|
||||
@@ -19,6 +19,7 @@ from khoj.processor.conversation.utils import (
|
||||
get_image_from_url,
|
||||
)
|
||||
from khoj.utils.helpers import (
|
||||
get_ai_api_info,
|
||||
get_chat_usage_metrics,
|
||||
is_none_or_empty,
|
||||
is_promptrace_enabled,
|
||||
@@ -26,13 +27,25 @@ from khoj.utils.helpers import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
anthropic_clients: Dict[str, anthropic.Anthropic] = {}
|
||||
|
||||
anthropic_clients: Dict[str, anthropic.Anthropic | anthropic.AnthropicVertex] = {}
|
||||
|
||||
DEFAULT_MAX_TOKENS_ANTHROPIC = 8000
|
||||
MAX_REASONING_TOKENS_ANTHROPIC = 12000
|
||||
|
||||
|
||||
def get_anthropic_client(api_key, api_base_url=None) -> anthropic.Anthropic | anthropic.AnthropicVertex:
    """
    Construct an Anthropic SDK client for the direct API or GCP Vertex AI.

    Resolves the api_key/api_base_url pair via get_ai_api_info. If a plain API
    key survives resolution, a standard Anthropic client is returned; otherwise
    the key held GCP credentials and an AnthropicVertex client is built from
    the region/project parsed out of the API base URL.
    """
    api_info = get_ai_api_info(api_key, api_base_url)
    if api_info.api_key:
        return anthropic.Anthropic(api_key=api_info.api_key)
    return anthropic.AnthropicVertex(
        region=api_info.region,
        project_id=api_info.project,
        credentials=api_info.credentials,
    )
|
||||
|
||||
|
||||
@retry(
|
||||
wait=wait_random_exponential(min=1, max=10),
|
||||
stop=stop_after_attempt(2),
|
||||
@@ -45,17 +58,17 @@ def anthropic_completion_with_backoff(
|
||||
model_name: str,
|
||||
temperature=0,
|
||||
api_key=None,
|
||||
api_base_url: str = None,
|
||||
model_kwargs=None,
|
||||
max_tokens=None,
|
||||
response_type="text",
|
||||
deepthought=False,
|
||||
tracer={},
|
||||
) -> str:
|
||||
if api_key not in anthropic_clients:
|
||||
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key)
|
||||
client = anthropic_clients.get(api_key)
|
||||
if not client:
|
||||
client = get_anthropic_client(api_key, api_base_url)
|
||||
anthropic_clients[api_key] = client
|
||||
else:
|
||||
client = anthropic_clients[api_key]
|
||||
|
||||
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
||||
aggregated_response = ""
|
||||
@@ -115,6 +128,7 @@ def anthropic_chat_completion_with_backoff(
|
||||
model_name,
|
||||
temperature,
|
||||
api_key,
|
||||
api_base_url,
|
||||
system_prompt,
|
||||
max_prompt_size=None,
|
||||
completion_func=None,
|
||||
@@ -132,6 +146,7 @@ def anthropic_chat_completion_with_backoff(
|
||||
model_name,
|
||||
temperature,
|
||||
api_key,
|
||||
api_base_url,
|
||||
max_prompt_size,
|
||||
deepthought,
|
||||
model_kwargs,
|
||||
@@ -149,17 +164,17 @@ def anthropic_llm_thread(
|
||||
model_name,
|
||||
temperature,
|
||||
api_key,
|
||||
api_base_url=None,
|
||||
max_prompt_size=None,
|
||||
deepthought=False,
|
||||
model_kwargs=None,
|
||||
tracer={},
|
||||
):
|
||||
try:
|
||||
if api_key not in anthropic_clients:
|
||||
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key)
|
||||
client = anthropic_clients.get(api_key)
|
||||
if not client:
|
||||
client = get_anthropic_client(api_key, api_base_url)
|
||||
anthropic_clients[api_key] = client
|
||||
else:
|
||||
client: anthropic.Anthropic = anthropic_clients[api_key]
|
||||
|
||||
model_kwargs = model_kwargs or dict()
|
||||
max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC
|
||||
|
||||
@@ -34,6 +34,7 @@ def extract_questions_gemini(
|
||||
model: Optional[str] = "gemini-2.0-flash",
|
||||
conversation_log={},
|
||||
api_key=None,
|
||||
api_base_url=None,
|
||||
temperature=0.6,
|
||||
max_tokens=None,
|
||||
location_data: LocationData = None,
|
||||
@@ -97,7 +98,13 @@ def extract_questions_gemini(
|
||||
messages.append(ChatMessage(content=system_prompt, role="system"))
|
||||
|
||||
response = gemini_send_message_to_model(
|
||||
messages, api_key, model, response_type="json_object", temperature=temperature, tracer=tracer
|
||||
messages,
|
||||
api_key,
|
||||
model,
|
||||
api_base_url=api_base_url,
|
||||
response_type="json_object",
|
||||
temperature=temperature,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
# Extract, Clean Message from Gemini's Response
|
||||
@@ -120,6 +127,7 @@ def gemini_send_message_to_model(
|
||||
messages,
|
||||
api_key,
|
||||
model,
|
||||
api_base_url=None,
|
||||
response_type="text",
|
||||
response_schema=None,
|
||||
temperature=0.6,
|
||||
@@ -144,6 +152,7 @@ def gemini_send_message_to_model(
|
||||
system_prompt=system_prompt,
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
temperature=temperature,
|
||||
model_kwargs=model_kwargs,
|
||||
tracer=tracer,
|
||||
@@ -158,6 +167,7 @@ def converse_gemini(
|
||||
conversation_log={},
|
||||
model: Optional[str] = "gemini-2.0-flash",
|
||||
api_key: Optional[str] = None,
|
||||
api_base_url: Optional[str] = None,
|
||||
temperature: float = 0.6,
|
||||
completion_func=None,
|
||||
conversation_commands=[ConversationCommand.Default],
|
||||
@@ -249,6 +259,7 @@ def converse_gemini(
|
||||
model_name=model,
|
||||
temperature=temperature,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
system_prompt=system_prompt,
|
||||
completion_func=completion_func,
|
||||
tracer=tracer,
|
||||
|
||||
@@ -3,6 +3,7 @@ import os
|
||||
import random
|
||||
from copy import deepcopy
|
||||
from threading import Thread
|
||||
from typing import Dict
|
||||
|
||||
from google import genai
|
||||
from google.genai import errors as gerrors
|
||||
@@ -23,6 +24,7 @@ from khoj.processor.conversation.utils import (
|
||||
get_image_from_url,
|
||||
)
|
||||
from khoj.utils.helpers import (
|
||||
get_ai_api_info,
|
||||
get_chat_usage_metrics,
|
||||
is_none_or_empty,
|
||||
is_promptrace_enabled,
|
||||
@@ -30,6 +32,7 @@ from khoj.utils.helpers import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
gemini_clients: Dict[str, genai.Client] = {}
|
||||
|
||||
MAX_OUTPUT_TOKENS_GEMINI = 8192
|
||||
SAFETY_SETTINGS = [
|
||||
@@ -52,6 +55,17 @@ SAFETY_SETTINGS = [
|
||||
]
|
||||
|
||||
|
||||
def get_gemini_client(api_key, api_base_url=None) -> genai.Client:
    """
    Construct a google-genai client for the Gemini API or GCP Vertex AI.

    Resolves the api_key/api_base_url pair via get_ai_api_info. When no plain
    API key survives resolution, the client is switched into Vertex AI mode
    and uses the parsed region/project plus GCP credentials instead.
    """
    api_info = get_ai_api_info(api_key, api_base_url)
    # No API key after resolution means the key carried GCP credentials.
    use_vertex = api_info.api_key is None
    client = genai.Client(
        location=api_info.region,
        project=api_info.project,
        credentials=api_info.credentials,
        api_key=api_info.api_key,
        vertexai=use_vertex,
    )
    return client
|
||||
|
||||
|
||||
@retry(
|
||||
wait=wait_random_exponential(min=1, max=10),
|
||||
stop=stop_after_attempt(2),
|
||||
@@ -59,9 +73,13 @@ SAFETY_SETTINGS = [
|
||||
reraise=True,
|
||||
)
|
||||
def gemini_completion_with_backoff(
|
||||
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, tracer={}
|
||||
messages, system_prompt, model_name, temperature=0, api_key=None, api_base_url=None, model_kwargs=None, tracer={}
|
||||
) -> str:
|
||||
client = genai.Client(api_key=api_key)
|
||||
client = gemini_clients.get(api_key)
|
||||
if not client:
|
||||
client = get_gemini_client(api_key, api_base_url)
|
||||
gemini_clients[api_key] = client
|
||||
|
||||
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
|
||||
config = gtypes.GenerateContentConfig(
|
||||
system_instruction=system_prompt,
|
||||
@@ -115,6 +133,7 @@ def gemini_chat_completion_with_backoff(
|
||||
model_name,
|
||||
temperature,
|
||||
api_key,
|
||||
api_base_url,
|
||||
system_prompt,
|
||||
completion_func=None,
|
||||
model_kwargs=None,
|
||||
@@ -123,17 +142,29 @@ def gemini_chat_completion_with_backoff(
|
||||
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
||||
t = Thread(
|
||||
target=gemini_llm_thread,
|
||||
args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs, tracer),
|
||||
args=(g, messages, system_prompt, model_name, temperature, api_key, api_base_url, model_kwargs, tracer),
|
||||
)
|
||||
t.start()
|
||||
return g
|
||||
|
||||
|
||||
def gemini_llm_thread(
|
||||
g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None, tracer: dict = {}
|
||||
g,
|
||||
messages,
|
||||
system_prompt,
|
||||
model_name,
|
||||
temperature,
|
||||
api_key,
|
||||
api_base_url=None,
|
||||
model_kwargs=None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
try:
|
||||
client = genai.Client(api_key=api_key)
|
||||
client = gemini_clients.get(api_key)
|
||||
if not client:
|
||||
client = get_gemini_client(api_key, api_base_url)
|
||||
gemini_clients[api_key] = client
|
||||
|
||||
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
|
||||
config = gtypes.GenerateContentConfig(
|
||||
system_instruction=system_prompt,
|
||||
|
||||
@@ -55,7 +55,7 @@ def completion_with_backoff(
|
||||
tracer: dict = {},
|
||||
) -> str:
|
||||
client_key = f"{openai_api_key}--{api_base_url}"
|
||||
client: openai.OpenAI | None = openai_clients.get(client_key)
|
||||
client = openai_clients.get(client_key)
|
||||
if not client:
|
||||
client = get_openai_client(openai_api_key, api_base_url)
|
||||
openai_clients[client_key] = client
|
||||
@@ -150,9 +150,8 @@ def llm_thread(
|
||||
):
|
||||
try:
|
||||
client_key = f"{openai_api_key}--{api_base_url}"
|
||||
if client_key in openai_clients:
|
||||
client = openai_clients[client_key]
|
||||
else:
|
||||
client = openai_clients.get(client_key)
|
||||
if not client:
|
||||
client = get_openai_client(openai_api_key, api_base_url)
|
||||
openai_clients[client_key] = client
|
||||
|
||||
|
||||
@@ -463,12 +463,14 @@ async def extract_references_and_questions(
|
||||
)
|
||||
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
api_base_url = chat_model.ai_model_api.api_base_url
|
||||
chat_model_name = chat_model.name
|
||||
inferred_queries = extract_questions_anthropic(
|
||||
defiltered_query,
|
||||
query_images=query_images,
|
||||
model=chat_model_name,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
conversation_log=meta_log,
|
||||
location_data=location_data,
|
||||
user=user,
|
||||
@@ -479,12 +481,14 @@ async def extract_references_and_questions(
|
||||
)
|
||||
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
api_base_url = chat_model.ai_model_api.api_base_url
|
||||
chat_model_name = chat_model.name
|
||||
inferred_queries = extract_questions_gemini(
|
||||
defiltered_query,
|
||||
query_images=query_images,
|
||||
model=chat_model_name,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
conversation_log=meta_log,
|
||||
location_data=location_data,
|
||||
max_tokens=chat_model.max_prompt_size,
|
||||
|
||||
@@ -1220,6 +1220,7 @@ async def send_message_to_model_wrapper(
|
||||
)
|
||||
elif model_type == ChatModel.ModelType.ANTHROPIC:
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
api_base_url = chat_model.ai_model_api.api_base_url
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=query,
|
||||
context_message=context,
|
||||
@@ -1239,10 +1240,12 @@ async def send_message_to_model_wrapper(
|
||||
model=chat_model_name,
|
||||
response_type=response_type,
|
||||
deepthought=deepthought,
|
||||
api_base_url=api_base_url,
|
||||
tracer=tracer,
|
||||
)
|
||||
elif model_type == ChatModel.ModelType.GOOGLE:
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
api_base_url = chat_model.ai_model_api.api_base_url
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=query,
|
||||
context_message=context,
|
||||
@@ -1262,6 +1265,7 @@ async def send_message_to_model_wrapper(
|
||||
model=chat_model_name,
|
||||
response_type=response_type,
|
||||
response_schema=response_schema,
|
||||
api_base_url=api_base_url,
|
||||
tracer=tracer,
|
||||
)
|
||||
else:
|
||||
@@ -1328,7 +1332,7 @@ def send_message_to_model_wrapper_sync(
|
||||
query_files=query_files,
|
||||
)
|
||||
|
||||
openai_response = send_message_to_model(
|
||||
return send_message_to_model(
|
||||
messages=truncated_messages,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
@@ -1338,10 +1342,9 @@ def send_message_to_model_wrapper_sync(
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
return openai_response
|
||||
|
||||
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
api_base_url = chat_model.ai_model_api.api_base_url
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=message,
|
||||
system_message=system_message,
|
||||
@@ -1356,6 +1359,7 @@ def send_message_to_model_wrapper_sync(
|
||||
return anthropic_send_message_to_model(
|
||||
messages=truncated_messages,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
model=chat_model_name,
|
||||
response_type=response_type,
|
||||
tracer=tracer,
|
||||
@@ -1363,6 +1367,7 @@ def send_message_to_model_wrapper_sync(
|
||||
|
||||
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
api_base_url = chat_model.ai_model_api.api_base_url
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=message,
|
||||
system_message=system_message,
|
||||
@@ -1377,6 +1382,7 @@ def send_message_to_model_wrapper_sync(
|
||||
return gemini_send_message_to_model(
|
||||
messages=truncated_messages,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
model=chat_model_name,
|
||||
response_type=response_type,
|
||||
response_schema=response_schema,
|
||||
@@ -1510,6 +1516,7 @@ def generate_chat_response(
|
||||
|
||||
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
api_base_url = chat_model.ai_model_api.api_base_url
|
||||
chat_response = converse_anthropic(
|
||||
compiled_references,
|
||||
query_to_run,
|
||||
@@ -1519,6 +1526,7 @@ def generate_chat_response(
|
||||
conversation_log=meta_log,
|
||||
model=chat_model.name,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
completion_func=partial_completion,
|
||||
conversation_commands=conversation_commands,
|
||||
max_prompt_size=chat_model.max_prompt_size,
|
||||
@@ -1536,6 +1544,7 @@ def generate_chat_response(
|
||||
)
|
||||
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
api_base_url = chat_model.ai_model_api.api_base_url
|
||||
chat_response = converse_gemini(
|
||||
compiled_references,
|
||||
query_to_run,
|
||||
@@ -1544,6 +1553,7 @@ def generate_chat_response(
|
||||
meta_log,
|
||||
model=chat_model.name,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
completion_func=partial_completion,
|
||||
conversation_commands=conversation_commands,
|
||||
max_prompt_size=chat_model.max_prompt_size,
|
||||
|
||||
@@ -49,8 +49,10 @@ model_to_cost: Dict[str, Dict[str, float]] = {
|
||||
"gemini-2.0-flash": {"input": 0.10, "output": 0.40},
|
||||
# Anthropic Pricing: https://www.anthropic.com/pricing#anthropic-api_
|
||||
"claude-3-5-haiku-20241022": {"input": 1.0, "output": 5.0},
|
||||
"claude-3-5-haiku@20241022": {"input": 1.0, "output": 5.0},
|
||||
"claude-3-5-sonnet-20241022": {"input": 3.0, "output": 15.0},
|
||||
"claude-3-5-sonnet-latest": {"input": 3.0, "output": 15.0},
|
||||
"claude-3-7-sonnet-20250219": {"input": 3.0, "output": 15.0},
|
||||
"claude-3-7-sonnet@20250219": {"input": 3.0, "output": 15.0},
|
||||
"claude-3-7-sonnet-latest": {"input": 3.0, "output": 15.0},
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations # to avoid quoting type hints
|
||||
|
||||
import base64
|
||||
import copy
|
||||
import datetime
|
||||
import io
|
||||
@@ -19,15 +20,18 @@ from itertools import islice
|
||||
from os import path
|
||||
from pathlib import Path
|
||||
from time import perf_counter
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple, Union
|
||||
from urllib.parse import ParseResult, urlparse
|
||||
|
||||
import openai
|
||||
import psutil
|
||||
import pyjson5
|
||||
import requests
|
||||
import torch
|
||||
from asgiref.sync import sync_to_async
|
||||
from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email
|
||||
from google.auth.credentials import Credentials
|
||||
from google.oauth2 import service_account
|
||||
from magika import Magika
|
||||
from PIL import Image
|
||||
from pytz import country_names, country_timezones
|
||||
@@ -618,6 +622,58 @@ def get_chat_usage_metrics(
|
||||
}
|
||||
|
||||
|
||||
class AiApiInfo(NamedTuple):
    """Resolved AI API connection info: either a plain API key or GCP Vertex AI details."""

    # GCP region the model is served from (set only for Vertex AI URLs)
    region: str
    # GCP project id parsed from the API base URL (set only for Vertex AI URLs)
    project: str
    # GCP service account credentials (set only when the api_key held a keyfile)
    credentials: Credentials
    # Provider API key; cleared to None when GCP credentials are used instead
    api_key: str
||||
|
||||
|
||||
def get_gcp_credentials(credentials_b64: str) -> Optional[Credentials]:
    """
    Get GCP credentials from base64 encoded service account credentials json keyfile
    """
    # NOTE(review): annotated Optional, but this raises (rather than returning
    # None) on malformed base64/JSON input — confirm callers expect exceptions.
    # Decode the base64 payload back into the service account keyfile JSON text.
    credentials_json = base64.b64decode(credentials_b64).decode("utf-8")
    credentials_dict = pyjson5.loads(credentials_json)
    credentials = service_account.Credentials.from_service_account_info(credentials_dict)
    # Scope to full Cloud Platform access, as needed for Vertex AI calls.
    return credentials.with_scopes(scopes=["https://www.googleapis.com/auth/cloud-platform"])
|
||||
|
||||
|
||||
def get_gcp_project_info(parsed_api_url: ParseResult) -> Tuple[str, str]:
    """
    Extract region, project id from GCP API url
    API url is of form https://{REGION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}...
    """
    # Region is whatever precedes "-aiplatform" in the host, e.g. "us-east5".
    host = parsed_api_url.netloc
    region = host.split("-aiplatform")[0] if "-aiplatform" in host else ""

    # Project id is the path segment immediately after "projects".
    segments = parsed_api_url.path.split("/")
    project_id = ""
    try:
        idx = segments.index("projects")
        if idx + 1 < len(segments):
            project_id = segments[idx + 1]
    except ValueError:
        # No "projects" segment in path; leave project id empty.
        pass

    return region, project_id
|
||||
|
||||
|
||||
def get_ai_api_info(api_key, api_base_url: str = None) -> AiApiInfo:
    """
    Get the GCP Vertex or default AI API client info based on the API key and URL.
    """
    region = None
    project_id = None
    credentials = None

    # A googleapis.com base URL signals the model is served via GCP Vertex AI.
    parsed_api_url = urlparse(api_base_url)
    hostname = parsed_api_url.hostname
    if hostname and hostname.endswith(".googleapis.com"):
        region, project_id = get_gcp_project_info(parsed_api_url)
        credentials = get_gcp_credentials(api_key)
        # The "api key" actually held GCP credentials; drop it so clients
        # authenticate with the service account instead.
        if credentials:
            api_key = None

    return AiApiInfo(region=region, project=project_id, credentials=credentials, api_key=api_key)
|
||||
|
||||
|
||||
def get_openai_client(api_key: str, api_base_url: str) -> Union[openai.OpenAI, openai.AzureOpenAI]:
|
||||
"""Get OpenAI or AzureOpenAI client based on the API Base URL"""
|
||||
parsed_url = urlparse(api_base_url)
|
||||
|
||||
Reference in New Issue
Block a user