Access Claude and Gemini via GCP Vertex AI (#1134)

Support accessing Claude and Gemini AI models via Vertex AI on Google Cloud. 

See the documentation at docs.khoj.dev for setup details
This commit is contained in:
Debanjum
2025-03-23 16:26:02 +05:30
committed by GitHub
12 changed files with 205 additions and 27 deletions

View File

@@ -0,0 +1,26 @@
# Google Vertex AI
:::info
This is only helpful for self-hosted users. If you're using [Khoj Cloud](https://app.khoj.dev), you can directly use any of the pre-configured AI models.
:::
Khoj can use Google's Gemini and Anthropic's Claude family of AI models from [Vertex AI](https://cloud.google.com/vertex-ai) on Google Cloud. Explore Anthropic and Gemini AI models available on Vertex AI's [Model Garden](https://console.cloud.google.com/vertex-ai/model-garden).
## Setup
1. Follow [these instructions](https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude#before_you_begin) to use models on GCP Vertex AI.
2. Create [Service Account](https://console.cloud.google.com/apis/credentials/serviceaccountkey) credentials.
- Download the credentials keyfile in json format.
- Base64 encode the credentials json keyfile. For example by running the following command from your terminal:
`base64 -i <service_account_credentials_keyfile.json>`
3. Create a new [AI Model API](http://localhost:42110/server/admin/database/aimodelapi/add) on your Khoj admin panel.
- **Name**: `Google Vertex` (or whatever friendly name you prefer).
- **Api Key**: `base64 encoded json keyfile` from step 2.
- **Api Base Url**: `https://{MODEL_GCP_REGION}-aiplatform.googleapis.com/v1/projects/{YOUR_GCP_PROJECT_ID}`
- MODEL_GCP_REGION: A region the AI model is available in. For example `us-east5` works for [Claude](https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude#regions).
    - YOUR_GCP_PROJECT_ID: Get your project id from the [Google Cloud dashboard](https://console.cloud.google.com/home/dashboard)
4. Create a new [Chat Model](http://localhost:42110/server/admin/database/chatmodel/add) on your Khoj admin panel.
- **Name**: `claude-3-7-sonnet@20250219`. Any Claude or Gemini model on Vertex's Model Garden should work.
- **Model Type**: `Anthropic` or `Google`
- **Ai Model API**: *the Google Vertex Ai Model API you created in step 3*
- **Max prompt size**: `60000` (replace with the max prompt size of your model)
- **Tokenizer**: *Do not set*
5. Select the chat model on [your settings page](http://localhost:42110/settings) and start a conversation.

View File

@@ -0,0 +1,17 @@
# Generated by Django 5.0.13 on 2025-03-23 04:42
from django.db import migrations, models
class Migration(migrations.Migration):
    # Widen AiModelApi.api_key from 200 to 4000 characters so it can hold
    # base64-encoded GCP service account JSON keyfiles used for Vertex AI.

    dependencies = [
        ("database", "0086_alter_texttoimagemodelconfig_model_type"),
    ]

    operations = [
        migrations.AlterField(
            model_name="aimodelapi",
            name="api_key",
            field=models.CharField(max_length=4000),
        ),
    ]

View File

@@ -188,7 +188,7 @@ class Subscription(DbBaseModel):
class AiModelApi(DbBaseModel): class AiModelApi(DbBaseModel):
name = models.CharField(max_length=200) name = models.CharField(max_length=200)
api_key = models.CharField(max_length=200) api_key = models.CharField(max_length=4000)
api_base_url = models.URLField(max_length=200, default=None, blank=True, null=True) api_base_url = models.URLField(max_length=200, default=None, blank=True, null=True)
def __str__(self): def __str__(self):

View File

@@ -34,6 +34,7 @@ def extract_questions_anthropic(
model: Optional[str] = "claude-3-7-sonnet-latest", model: Optional[str] = "claude-3-7-sonnet-latest",
conversation_log={}, conversation_log={},
api_key=None, api_key=None,
api_base_url=None,
temperature=0.7, temperature=0.7,
location_data: LocationData = None, location_data: LocationData = None,
user: KhojUser = None, user: KhojUser = None,
@@ -102,6 +103,7 @@ def extract_questions_anthropic(
model_name=model, model_name=model,
temperature=temperature, temperature=temperature,
api_key=api_key, api_key=api_key,
api_base_url=api_base_url,
response_type="json_object", response_type="json_object",
tracer=tracer, tracer=tracer,
) )
@@ -122,7 +124,9 @@ def extract_questions_anthropic(
return questions return questions
def anthropic_send_message_to_model(messages, api_key, model, response_type="text", deepthought=False, tracer={}): def anthropic_send_message_to_model(
messages, api_key, api_base_url, model, response_type="text", deepthought=False, tracer={}
):
""" """
Send message to model Send message to model
""" """
@@ -134,6 +138,7 @@ def anthropic_send_message_to_model(messages, api_key, model, response_type="tex
system_prompt=system_prompt, system_prompt=system_prompt,
model_name=model, model_name=model,
api_key=api_key, api_key=api_key,
api_base_url=api_base_url,
response_type=response_type, response_type=response_type,
deepthought=deepthought, deepthought=deepthought,
tracer=tracer, tracer=tracer,
@@ -148,6 +153,7 @@ def converse_anthropic(
conversation_log={}, conversation_log={},
model: Optional[str] = "claude-3-7-sonnet-latest", model: Optional[str] = "claude-3-7-sonnet-latest",
api_key: Optional[str] = None, api_key: Optional[str] = None,
api_base_url: Optional[str] = None,
completion_func=None, completion_func=None,
conversation_commands=[ConversationCommand.Default], conversation_commands=[ConversationCommand.Default],
max_prompt_size=None, max_prompt_size=None,
@@ -238,6 +244,7 @@ def converse_anthropic(
model_name=model, model_name=model,
temperature=0, temperature=0,
api_key=api_key, api_key=api_key,
api_base_url=api_base_url,
system_prompt=system_prompt, system_prompt=system_prompt,
completion_func=completion_func, completion_func=completion_func,
max_prompt_size=max_prompt_size, max_prompt_size=max_prompt_size,

View File

@@ -19,6 +19,7 @@ from khoj.processor.conversation.utils import (
get_image_from_url, get_image_from_url,
) )
from khoj.utils.helpers import ( from khoj.utils.helpers import (
get_ai_api_info,
get_chat_usage_metrics, get_chat_usage_metrics,
is_none_or_empty, is_none_or_empty,
is_promptrace_enabled, is_promptrace_enabled,
@@ -26,13 +27,25 @@ from khoj.utils.helpers import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
anthropic_clients: Dict[str, anthropic.Anthropic] = {} anthropic_clients: Dict[str, anthropic.Anthropic | anthropic.AnthropicVertex] = {}
DEFAULT_MAX_TOKENS_ANTHROPIC = 8000 DEFAULT_MAX_TOKENS_ANTHROPIC = 8000
MAX_REASONING_TOKENS_ANTHROPIC = 12000 MAX_REASONING_TOKENS_ANTHROPIC = 12000
def get_anthropic_client(api_key, api_base_url=None) -> anthropic.Anthropic | anthropic.AnthropicVertex:
    """
    Build an Anthropic client for either the direct API or GCP Vertex AI.

    The api_key/api_base_url pair is resolved via get_ai_api_info. When the
    resolved info still carries an api_key, a plain anthropic.Anthropic client
    is returned. Otherwise the key held GCP service account credentials, so an
    anthropic.AnthropicVertex client is constructed from the region, project
    and credentials derived from the base URL.
    """
    api_info = get_ai_api_info(api_key, api_base_url)
    if api_info.api_key:
        return anthropic.Anthropic(api_key=api_info.api_key)
    return anthropic.AnthropicVertex(
        region=api_info.region,
        project_id=api_info.project,
        credentials=api_info.credentials,
    )
@retry( @retry(
wait=wait_random_exponential(min=1, max=10), wait=wait_random_exponential(min=1, max=10),
stop=stop_after_attempt(2), stop=stop_after_attempt(2),
@@ -45,17 +58,17 @@ def anthropic_completion_with_backoff(
model_name: str, model_name: str,
temperature=0, temperature=0,
api_key=None, api_key=None,
api_base_url: str = None,
model_kwargs=None, model_kwargs=None,
max_tokens=None, max_tokens=None,
response_type="text", response_type="text",
deepthought=False, deepthought=False,
tracer={}, tracer={},
) -> str: ) -> str:
if api_key not in anthropic_clients: client = anthropic_clients.get(api_key)
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key) if not client:
client = get_anthropic_client(api_key, api_base_url)
anthropic_clients[api_key] = client anthropic_clients[api_key] = client
else:
client = anthropic_clients[api_key]
formatted_messages = [{"role": message.role, "content": message.content} for message in messages] formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
aggregated_response = "" aggregated_response = ""
@@ -115,6 +128,7 @@ def anthropic_chat_completion_with_backoff(
model_name, model_name,
temperature, temperature,
api_key, api_key,
api_base_url,
system_prompt, system_prompt,
max_prompt_size=None, max_prompt_size=None,
completion_func=None, completion_func=None,
@@ -132,6 +146,7 @@ def anthropic_chat_completion_with_backoff(
model_name, model_name,
temperature, temperature,
api_key, api_key,
api_base_url,
max_prompt_size, max_prompt_size,
deepthought, deepthought,
model_kwargs, model_kwargs,
@@ -149,17 +164,17 @@ def anthropic_llm_thread(
model_name, model_name,
temperature, temperature,
api_key, api_key,
api_base_url=None,
max_prompt_size=None, max_prompt_size=None,
deepthought=False, deepthought=False,
model_kwargs=None, model_kwargs=None,
tracer={}, tracer={},
): ):
try: try:
if api_key not in anthropic_clients: client = anthropic_clients.get(api_key)
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key) if not client:
client = get_anthropic_client(api_key, api_base_url)
anthropic_clients[api_key] = client anthropic_clients[api_key] = client
else:
client: anthropic.Anthropic = anthropic_clients[api_key]
model_kwargs = model_kwargs or dict() model_kwargs = model_kwargs or dict()
max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC

View File

@@ -34,6 +34,7 @@ def extract_questions_gemini(
model: Optional[str] = "gemini-2.0-flash", model: Optional[str] = "gemini-2.0-flash",
conversation_log={}, conversation_log={},
api_key=None, api_key=None,
api_base_url=None,
temperature=0.6, temperature=0.6,
max_tokens=None, max_tokens=None,
location_data: LocationData = None, location_data: LocationData = None,
@@ -97,7 +98,13 @@ def extract_questions_gemini(
messages.append(ChatMessage(content=system_prompt, role="system")) messages.append(ChatMessage(content=system_prompt, role="system"))
response = gemini_send_message_to_model( response = gemini_send_message_to_model(
messages, api_key, model, response_type="json_object", temperature=temperature, tracer=tracer messages,
api_key,
model,
api_base_url=api_base_url,
response_type="json_object",
temperature=temperature,
tracer=tracer,
) )
# Extract, Clean Message from Gemini's Response # Extract, Clean Message from Gemini's Response
@@ -120,6 +127,7 @@ def gemini_send_message_to_model(
messages, messages,
api_key, api_key,
model, model,
api_base_url=None,
response_type="text", response_type="text",
response_schema=None, response_schema=None,
temperature=0.6, temperature=0.6,
@@ -144,6 +152,7 @@ def gemini_send_message_to_model(
system_prompt=system_prompt, system_prompt=system_prompt,
model_name=model, model_name=model,
api_key=api_key, api_key=api_key,
api_base_url=api_base_url,
temperature=temperature, temperature=temperature,
model_kwargs=model_kwargs, model_kwargs=model_kwargs,
tracer=tracer, tracer=tracer,
@@ -158,6 +167,7 @@ def converse_gemini(
conversation_log={}, conversation_log={},
model: Optional[str] = "gemini-2.0-flash", model: Optional[str] = "gemini-2.0-flash",
api_key: Optional[str] = None, api_key: Optional[str] = None,
api_base_url: Optional[str] = None,
temperature: float = 0.6, temperature: float = 0.6,
completion_func=None, completion_func=None,
conversation_commands=[ConversationCommand.Default], conversation_commands=[ConversationCommand.Default],
@@ -249,6 +259,7 @@ def converse_gemini(
model_name=model, model_name=model,
temperature=temperature, temperature=temperature,
api_key=api_key, api_key=api_key,
api_base_url=api_base_url,
system_prompt=system_prompt, system_prompt=system_prompt,
completion_func=completion_func, completion_func=completion_func,
tracer=tracer, tracer=tracer,

View File

@@ -3,6 +3,7 @@ import os
import random import random
from copy import deepcopy from copy import deepcopy
from threading import Thread from threading import Thread
from typing import Dict
from google import genai from google import genai
from google.genai import errors as gerrors from google.genai import errors as gerrors
@@ -23,6 +24,7 @@ from khoj.processor.conversation.utils import (
get_image_from_url, get_image_from_url,
) )
from khoj.utils.helpers import ( from khoj.utils.helpers import (
get_ai_api_info,
get_chat_usage_metrics, get_chat_usage_metrics,
is_none_or_empty, is_none_or_empty,
is_promptrace_enabled, is_promptrace_enabled,
@@ -30,6 +32,7 @@ from khoj.utils.helpers import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
gemini_clients: Dict[str, genai.Client] = {}
MAX_OUTPUT_TOKENS_GEMINI = 8192 MAX_OUTPUT_TOKENS_GEMINI = 8192
SAFETY_SETTINGS = [ SAFETY_SETTINGS = [
@@ -52,6 +55,17 @@ SAFETY_SETTINGS = [
] ]
def get_gemini_client(api_key, api_base_url=None) -> genai.Client:
    """
    Build a Gemini client for either the public Gemini API or GCP Vertex AI.

    A single genai.Client handles both transports: when no api_key survives
    resolution (it held GCP service account credentials), the client is put
    into Vertex AI mode with the region/project/credentials derived from the
    base URL.
    """
    api_info = get_ai_api_info(api_key, api_base_url)
    # No api_key after resolution means GCP credentials were extracted from it
    use_vertex = api_info.api_key is None
    return genai.Client(
        location=api_info.region,
        project=api_info.project,
        credentials=api_info.credentials,
        api_key=api_info.api_key,
        vertexai=use_vertex,
    )
@retry( @retry(
wait=wait_random_exponential(min=1, max=10), wait=wait_random_exponential(min=1, max=10),
stop=stop_after_attempt(2), stop=stop_after_attempt(2),
@@ -59,9 +73,13 @@ SAFETY_SETTINGS = [
reraise=True, reraise=True,
) )
def gemini_completion_with_backoff( def gemini_completion_with_backoff(
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, tracer={} messages, system_prompt, model_name, temperature=0, api_key=None, api_base_url=None, model_kwargs=None, tracer={}
) -> str: ) -> str:
client = genai.Client(api_key=api_key) client = gemini_clients.get(api_key)
if not client:
client = get_gemini_client(api_key, api_base_url)
gemini_clients[api_key] = client
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
config = gtypes.GenerateContentConfig( config = gtypes.GenerateContentConfig(
system_instruction=system_prompt, system_instruction=system_prompt,
@@ -115,6 +133,7 @@ def gemini_chat_completion_with_backoff(
model_name, model_name,
temperature, temperature,
api_key, api_key,
api_base_url,
system_prompt, system_prompt,
completion_func=None, completion_func=None,
model_kwargs=None, model_kwargs=None,
@@ -123,17 +142,29 @@ def gemini_chat_completion_with_backoff(
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func) g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
t = Thread( t = Thread(
target=gemini_llm_thread, target=gemini_llm_thread,
args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs, tracer), args=(g, messages, system_prompt, model_name, temperature, api_key, api_base_url, model_kwargs, tracer),
) )
t.start() t.start()
return g return g
def gemini_llm_thread( def gemini_llm_thread(
g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None, tracer: dict = {} g,
messages,
system_prompt,
model_name,
temperature,
api_key,
api_base_url=None,
model_kwargs=None,
tracer: dict = {},
): ):
try: try:
client = genai.Client(api_key=api_key) client = gemini_clients.get(api_key)
if not client:
client = get_gemini_client(api_key, api_base_url)
gemini_clients[api_key] = client
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
config = gtypes.GenerateContentConfig( config = gtypes.GenerateContentConfig(
system_instruction=system_prompt, system_instruction=system_prompt,

View File

@@ -55,7 +55,7 @@ def completion_with_backoff(
tracer: dict = {}, tracer: dict = {},
) -> str: ) -> str:
client_key = f"{openai_api_key}--{api_base_url}" client_key = f"{openai_api_key}--{api_base_url}"
client: openai.OpenAI | None = openai_clients.get(client_key) client = openai_clients.get(client_key)
if not client: if not client:
client = get_openai_client(openai_api_key, api_base_url) client = get_openai_client(openai_api_key, api_base_url)
openai_clients[client_key] = client openai_clients[client_key] = client
@@ -150,9 +150,8 @@ def llm_thread(
): ):
try: try:
client_key = f"{openai_api_key}--{api_base_url}" client_key = f"{openai_api_key}--{api_base_url}"
if client_key in openai_clients: client = openai_clients.get(client_key)
client = openai_clients[client_key] if not client:
else:
client = get_openai_client(openai_api_key, api_base_url) client = get_openai_client(openai_api_key, api_base_url)
openai_clients[client_key] = client openai_clients[client_key] = client

View File

@@ -463,12 +463,14 @@ async def extract_references_and_questions(
) )
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC: elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
api_key = chat_model.ai_model_api.api_key api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
chat_model_name = chat_model.name chat_model_name = chat_model.name
inferred_queries = extract_questions_anthropic( inferred_queries = extract_questions_anthropic(
defiltered_query, defiltered_query,
query_images=query_images, query_images=query_images,
model=chat_model_name, model=chat_model_name,
api_key=api_key, api_key=api_key,
api_base_url=api_base_url,
conversation_log=meta_log, conversation_log=meta_log,
location_data=location_data, location_data=location_data,
user=user, user=user,
@@ -479,12 +481,14 @@ async def extract_references_and_questions(
) )
elif chat_model.model_type == ChatModel.ModelType.GOOGLE: elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
api_key = chat_model.ai_model_api.api_key api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
chat_model_name = chat_model.name chat_model_name = chat_model.name
inferred_queries = extract_questions_gemini( inferred_queries = extract_questions_gemini(
defiltered_query, defiltered_query,
query_images=query_images, query_images=query_images,
model=chat_model_name, model=chat_model_name,
api_key=api_key, api_key=api_key,
api_base_url=api_base_url,
conversation_log=meta_log, conversation_log=meta_log,
location_data=location_data, location_data=location_data,
max_tokens=chat_model.max_prompt_size, max_tokens=chat_model.max_prompt_size,

View File

@@ -1220,6 +1220,7 @@ async def send_message_to_model_wrapper(
) )
elif model_type == ChatModel.ModelType.ANTHROPIC: elif model_type == ChatModel.ModelType.ANTHROPIC:
api_key = chat_model.ai_model_api.api_key api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
truncated_messages = generate_chatml_messages_with_context( truncated_messages = generate_chatml_messages_with_context(
user_message=query, user_message=query,
context_message=context, context_message=context,
@@ -1239,10 +1240,12 @@ async def send_message_to_model_wrapper(
model=chat_model_name, model=chat_model_name,
response_type=response_type, response_type=response_type,
deepthought=deepthought, deepthought=deepthought,
api_base_url=api_base_url,
tracer=tracer, tracer=tracer,
) )
elif model_type == ChatModel.ModelType.GOOGLE: elif model_type == ChatModel.ModelType.GOOGLE:
api_key = chat_model.ai_model_api.api_key api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
truncated_messages = generate_chatml_messages_with_context( truncated_messages = generate_chatml_messages_with_context(
user_message=query, user_message=query,
context_message=context, context_message=context,
@@ -1262,6 +1265,7 @@ async def send_message_to_model_wrapper(
model=chat_model_name, model=chat_model_name,
response_type=response_type, response_type=response_type,
response_schema=response_schema, response_schema=response_schema,
api_base_url=api_base_url,
tracer=tracer, tracer=tracer,
) )
else: else:
@@ -1328,7 +1332,7 @@ def send_message_to_model_wrapper_sync(
query_files=query_files, query_files=query_files,
) )
openai_response = send_message_to_model( return send_message_to_model(
messages=truncated_messages, messages=truncated_messages,
api_key=api_key, api_key=api_key,
api_base_url=api_base_url, api_base_url=api_base_url,
@@ -1338,10 +1342,9 @@ def send_message_to_model_wrapper_sync(
tracer=tracer, tracer=tracer,
) )
return openai_response
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC: elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
api_key = chat_model.ai_model_api.api_key api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
truncated_messages = generate_chatml_messages_with_context( truncated_messages = generate_chatml_messages_with_context(
user_message=message, user_message=message,
system_message=system_message, system_message=system_message,
@@ -1356,6 +1359,7 @@ def send_message_to_model_wrapper_sync(
return anthropic_send_message_to_model( return anthropic_send_message_to_model(
messages=truncated_messages, messages=truncated_messages,
api_key=api_key, api_key=api_key,
api_base_url=api_base_url,
model=chat_model_name, model=chat_model_name,
response_type=response_type, response_type=response_type,
tracer=tracer, tracer=tracer,
@@ -1363,6 +1367,7 @@ def send_message_to_model_wrapper_sync(
elif chat_model.model_type == ChatModel.ModelType.GOOGLE: elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
api_key = chat_model.ai_model_api.api_key api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
truncated_messages = generate_chatml_messages_with_context( truncated_messages = generate_chatml_messages_with_context(
user_message=message, user_message=message,
system_message=system_message, system_message=system_message,
@@ -1377,6 +1382,7 @@ def send_message_to_model_wrapper_sync(
return gemini_send_message_to_model( return gemini_send_message_to_model(
messages=truncated_messages, messages=truncated_messages,
api_key=api_key, api_key=api_key,
api_base_url=api_base_url,
model=chat_model_name, model=chat_model_name,
response_type=response_type, response_type=response_type,
response_schema=response_schema, response_schema=response_schema,
@@ -1510,6 +1516,7 @@ def generate_chat_response(
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC: elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
api_key = chat_model.ai_model_api.api_key api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
chat_response = converse_anthropic( chat_response = converse_anthropic(
compiled_references, compiled_references,
query_to_run, query_to_run,
@@ -1519,6 +1526,7 @@ def generate_chat_response(
conversation_log=meta_log, conversation_log=meta_log,
model=chat_model.name, model=chat_model.name,
api_key=api_key, api_key=api_key,
api_base_url=api_base_url,
completion_func=partial_completion, completion_func=partial_completion,
conversation_commands=conversation_commands, conversation_commands=conversation_commands,
max_prompt_size=chat_model.max_prompt_size, max_prompt_size=chat_model.max_prompt_size,
@@ -1536,6 +1544,7 @@ def generate_chat_response(
) )
elif chat_model.model_type == ChatModel.ModelType.GOOGLE: elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
api_key = chat_model.ai_model_api.api_key api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
chat_response = converse_gemini( chat_response = converse_gemini(
compiled_references, compiled_references,
query_to_run, query_to_run,
@@ -1544,6 +1553,7 @@ def generate_chat_response(
meta_log, meta_log,
model=chat_model.name, model=chat_model.name,
api_key=api_key, api_key=api_key,
api_base_url=api_base_url,
completion_func=partial_completion, completion_func=partial_completion,
conversation_commands=conversation_commands, conversation_commands=conversation_commands,
max_prompt_size=chat_model.max_prompt_size, max_prompt_size=chat_model.max_prompt_size,

View File

@@ -49,8 +49,10 @@ model_to_cost: Dict[str, Dict[str, float]] = {
"gemini-2.0-flash": {"input": 0.10, "output": 0.40}, "gemini-2.0-flash": {"input": 0.10, "output": 0.40},
# Anthropic Pricing: https://www.anthropic.com/pricing#anthropic-api_ # Anthropic Pricing: https://www.anthropic.com/pricing#anthropic-api_
"claude-3-5-haiku-20241022": {"input": 1.0, "output": 5.0}, "claude-3-5-haiku-20241022": {"input": 1.0, "output": 5.0},
"claude-3-5-haiku@20241022": {"input": 1.0, "output": 5.0},
"claude-3-5-sonnet-20241022": {"input": 3.0, "output": 15.0}, "claude-3-5-sonnet-20241022": {"input": 3.0, "output": 15.0},
"claude-3-5-sonnet-latest": {"input": 3.0, "output": 15.0}, "claude-3-5-sonnet-latest": {"input": 3.0, "output": 15.0},
"claude-3-7-sonnet-20250219": {"input": 3.0, "output": 15.0}, "claude-3-7-sonnet-20250219": {"input": 3.0, "output": 15.0},
"claude-3-7-sonnet@20250219": {"input": 3.0, "output": 15.0},
"claude-3-7-sonnet-latest": {"input": 3.0, "output": 15.0}, "claude-3-7-sonnet-latest": {"input": 3.0, "output": 15.0},
} }

View File

@@ -1,5 +1,6 @@
from __future__ import annotations # to avoid quoting type hints from __future__ import annotations # to avoid quoting type hints
import base64
import copy import copy
import datetime import datetime
import io import io
@@ -19,15 +20,18 @@ from itertools import islice
from os import path from os import path
from pathlib import Path from pathlib import Path
from time import perf_counter from time import perf_counter
from typing import TYPE_CHECKING, Any, Optional, Union from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple, Union
from urllib.parse import urlparse from urllib.parse import ParseResult, urlparse
import openai import openai
import psutil import psutil
import pyjson5
import requests import requests
import torch import torch
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email
from google.auth.credentials import Credentials
from google.oauth2 import service_account
from magika import Magika from magika import Magika
from PIL import Image from PIL import Image
from pytz import country_names, country_timezones from pytz import country_names, country_timezones
@@ -618,6 +622,58 @@ def get_chat_usage_metrics(
} }
class AiApiInfo(NamedTuple):
    """
    Resolved connection info for an AI model API.

    Exactly one access mode is populated: either api_key is set for direct
    API access, or region/project/credentials are set for GCP Vertex AI
    access (in which case api_key is None).
    """

    # GCP region the model is served from; None for direct API access
    region: Optional[str]
    # GCP project id; None for direct API access
    project: Optional[str]
    # GCP service account credentials; None for direct API access
    credentials: Optional[Credentials]
    # Plain API key for direct (non-Vertex) API access; None for Vertex AI
    api_key: Optional[str]
def get_gcp_credentials(credentials_b64: str) -> Optional[Credentials]:
    """
    Get GCP credentials from a base64 encoded service account credentials json keyfile.

    Returns None when no credentials string is provided, matching the declared
    Optional return type (previously the body could never return None and
    b64decode would crash on a falsy input).
    """
    if not credentials_b64:
        return None
    credentials_json = base64.b64decode(credentials_b64).decode("utf-8")
    credentials_dict = pyjson5.loads(credentials_json)
    credentials = service_account.Credentials.from_service_account_info(credentials_dict)
    # Vertex AI requires the cloud-platform OAuth scope
    return credentials.with_scopes(scopes=["https://www.googleapis.com/auth/cloud-platform"])
def get_gcp_project_info(parsed_api_url: ParseResult) -> Tuple[str, str]:
    """
    Extract (region, project id) from a parsed GCP API url.

    The API url has the form
    https://{REGION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}...
    Either component defaults to "" when it cannot be found.
    """
    # Region prefixes "-aiplatform" in the host, e.g. "us-east5-aiplatform.googleapis.com"
    hostname = parsed_api_url.netloc
    region = hostname.split("-aiplatform")[0] if "-aiplatform" in hostname else ""

    # Project id is the path segment immediately after "projects",
    # e.g. "/v1/projects/my-project/..."
    segments = parsed_api_url.path.split("/")
    project_id = next(
        (segments[idx + 1] for idx, segment in enumerate(segments[:-1]) if segment == "projects"),
        "",
    )
    return region, project_id
def get_ai_api_info(api_key, api_base_url: str = None) -> AiApiInfo:
    """
    Get the GCP Vertex or default AI API client info based on the API key and URL.

    If api_base_url points at a *.googleapis.com host, api_key is interpreted
    as base64 encoded GCP service account credentials and the region/project
    are derived from the URL. Otherwise api_key is returned as-is for direct
    API access.
    """
    region, project_id, credentials = None, None, None
    # Guard explicitly: api_base_url defaults to None at call sites, and the
    # previous code relied on urlparse(None) quietly yielding an empty result.
    if api_base_url:
        parsed_api_url = urlparse(api_base_url)
        # Check if AI model is to be used via the GCP Vertex API
        if parsed_api_url.hostname and parsed_api_url.hostname.endswith(".googleapis.com"):
            region, project_id = get_gcp_project_info(parsed_api_url)
            credentials = get_gcp_credentials(api_key)
    if credentials:
        # Service account credentials replace the raw API key
        api_key = None
    return AiApiInfo(region=region, project=project_id, credentials=credentials, api_key=api_key)
def get_openai_client(api_key: str, api_base_url: str) -> Union[openai.OpenAI, openai.AzureOpenAI]: def get_openai_client(api_key: str, api_base_url: str) -> Union[openai.OpenAI, openai.AzureOpenAI]:
"""Get OpenAI or AzureOpenAI client based on the API Base URL""" """Get OpenAI or AzureOpenAI client based on the API Base URL"""
parsed_url = urlparse(api_base_url) parsed_url = urlparse(api_base_url)