Support access to Anthropic models via GCP Vertex AI

Enable configuring a Khoj AI model API for Vertex AI using GCP credentials.

Specifically use the api key & api base url fields of the AI Model API
associated with the current chat model to extract gcp region, gcp
project id & credentials. This helps create a AnthropicVertex client.

The api key field should contain the GCP service account keyfile as a
base64 encoded string.

The api base url field should be of the form
`https://{MODEL_GCP_REGION}-aiplatform.googleapis.com/v1/projects/{YOUR_GCP_PROJECT_ID}`

Accepting GCP credentials via the AI model API makes it easy to use
across local and cloud environments. As it bypasses the need for a
separate service account key file on the Khoj server.
This commit is contained in:
Debanjum
2025-03-22 12:36:05 +05:30
parent 8bebcd5f81
commit 603c4bf2df
6 changed files with 101 additions and 13 deletions

View File

@@ -34,6 +34,7 @@ def extract_questions_anthropic(
model: Optional[str] = "claude-3-7-sonnet-latest",
conversation_log={},
api_key=None,
api_base_url=None,
temperature=0.7,
location_data: LocationData = None,
user: KhojUser = None,
@@ -102,6 +103,7 @@ def extract_questions_anthropic(
model_name=model,
temperature=temperature,
api_key=api_key,
api_base_url=api_base_url,
response_type="json_object",
tracer=tracer,
)
@@ -122,7 +124,9 @@ def extract_questions_anthropic(
return questions
def anthropic_send_message_to_model(messages, api_key, model, response_type="text", deepthought=False, tracer={}):
def anthropic_send_message_to_model(
messages, api_key, api_base_url, model, response_type="text", deepthought=False, tracer={}
):
"""
Send message to model
"""
@@ -134,6 +138,7 @@ def anthropic_send_message_to_model(messages, api_key, model, response_type="tex
system_prompt=system_prompt,
model_name=model,
api_key=api_key,
api_base_url=api_base_url,
response_type=response_type,
deepthought=deepthought,
tracer=tracer,
@@ -148,6 +153,7 @@ def converse_anthropic(
conversation_log={},
model: Optional[str] = "claude-3-7-sonnet-latest",
api_key: Optional[str] = None,
api_base_url: Optional[str] = None,
completion_func=None,
conversation_commands=[ConversationCommand.Default],
max_prompt_size=None,
@@ -238,6 +244,7 @@ def converse_anthropic(
model_name=model,
temperature=0,
api_key=api_key,
api_base_url=api_base_url,
system_prompt=system_prompt,
completion_func=completion_func,
max_prompt_size=max_prompt_size,

View File

@@ -19,6 +19,7 @@ from khoj.processor.conversation.utils import (
get_image_from_url,
)
from khoj.utils.helpers import (
get_ai_api_info,
get_chat_usage_metrics,
is_none_or_empty,
is_promptrace_enabled,
@@ -26,13 +27,25 @@ from khoj.utils.helpers import (
logger = logging.getLogger(__name__)
anthropic_clients: Dict[str, anthropic.Anthropic] = {}
anthropic_clients: Dict[str, anthropic.Anthropic | anthropic.AnthropicVertex] = {}
DEFAULT_MAX_TOKENS_ANTHROPIC = 8000
MAX_REASONING_TOKENS_ANTHROPIC = 12000
def get_anthropic_client(api_key, api_base_url=None) -> anthropic.Anthropic | anthropic.AnthropicVertex:
api_info = get_ai_api_info(api_key, api_base_url)
if api_info.api_key:
client = anthropic.Anthropic(api_key=api_info.api_key)
else:
client = anthropic.AnthropicVertex(
region=api_info.region,
project_id=api_info.project,
credentials=api_info.credentials,
)
return client
@retry(
wait=wait_random_exponential(min=1, max=10),
stop=stop_after_attempt(2),
@@ -45,17 +58,17 @@ def anthropic_completion_with_backoff(
model_name: str,
temperature=0,
api_key=None,
api_base_url: str = None,
model_kwargs=None,
max_tokens=None,
response_type="text",
deepthought=False,
tracer={},
) -> str:
if api_key not in anthropic_clients:
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key)
client = anthropic_clients.get(api_key)
if not client:
client = get_anthropic_client(api_key, api_base_url)
anthropic_clients[api_key] = client
else:
client = anthropic_clients[api_key]
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
aggregated_response = ""
@@ -115,6 +128,7 @@ def anthropic_chat_completion_with_backoff(
model_name,
temperature,
api_key,
api_base_url,
system_prompt,
max_prompt_size=None,
completion_func=None,
@@ -132,6 +146,7 @@ def anthropic_chat_completion_with_backoff(
model_name,
temperature,
api_key,
api_base_url,
max_prompt_size,
deepthought,
model_kwargs,
@@ -149,17 +164,17 @@ def anthropic_llm_thread(
model_name,
temperature,
api_key,
api_base_url=None,
max_prompt_size=None,
deepthought=False,
model_kwargs=None,
tracer={},
):
try:
if api_key not in anthropic_clients:
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key)
client = anthropic_clients.get(api_key)
if not client:
client = get_anthropic_client(api_key, api_base_url)
anthropic_clients[api_key] = client
else:
client: anthropic.Anthropic = anthropic_clients[api_key]
model_kwargs = model_kwargs or dict()
max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC

View File

@@ -463,12 +463,14 @@ async def extract_references_and_questions(
)
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
chat_model_name = chat_model.name
inferred_queries = extract_questions_anthropic(
defiltered_query,
query_images=query_images,
model=chat_model_name,
api_key=api_key,
api_base_url=api_base_url,
conversation_log=meta_log,
location_data=location_data,
user=user,

View File

@@ -1220,6 +1220,7 @@ async def send_message_to_model_wrapper(
)
elif model_type == ChatModel.ModelType.ANTHROPIC:
api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
truncated_messages = generate_chatml_messages_with_context(
user_message=query,
context_message=context,
@@ -1239,6 +1240,7 @@ async def send_message_to_model_wrapper(
model=chat_model_name,
response_type=response_type,
deepthought=deepthought,
api_base_url=api_base_url,
tracer=tracer,
)
elif model_type == ChatModel.ModelType.GOOGLE:
@@ -1342,6 +1344,7 @@ def send_message_to_model_wrapper_sync(
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
truncated_messages = generate_chatml_messages_with_context(
user_message=message,
system_message=system_message,
@@ -1356,6 +1359,7 @@ def send_message_to_model_wrapper_sync(
return anthropic_send_message_to_model(
messages=truncated_messages,
api_key=api_key,
api_base_url=api_base_url,
model=chat_model_name,
response_type=response_type,
tracer=tracer,
@@ -1510,6 +1514,7 @@ def generate_chat_response(
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
chat_response = converse_anthropic(
compiled_references,
query_to_run,
@@ -1519,6 +1524,7 @@ def generate_chat_response(
conversation_log=meta_log,
model=chat_model.name,
api_key=api_key,
api_base_url=api_base_url,
completion_func=partial_completion,
conversation_commands=conversation_commands,
max_prompt_size=chat_model.max_prompt_size,

View File

@@ -49,8 +49,10 @@ model_to_cost: Dict[str, Dict[str, float]] = {
"gemini-2.0-flash": {"input": 0.10, "output": 0.40},
# Anthropic Pricing: https://www.anthropic.com/pricing#anthropic-api_
"claude-3-5-haiku-20241022": {"input": 1.0, "output": 5.0},
"claude-3-5-haiku@20241022": {"input": 1.0, "output": 5.0},
"claude-3-5-sonnet-20241022": {"input": 3.0, "output": 15.0},
"claude-3-5-sonnet-latest": {"input": 3.0, "output": 15.0},
"claude-3-7-sonnet-20250219": {"input": 3.0, "output": 15.0},
"claude-3-7-sonnet@20250219": {"input": 3.0, "output": 15.0},
"claude-3-7-sonnet-latest": {"input": 3.0, "output": 15.0},
}

View File

@@ -1,5 +1,6 @@
from __future__ import annotations # to avoid quoting type hints
import base64
import copy
import datetime
import io
@@ -19,15 +20,18 @@ from itertools import islice
from os import path
from pathlib import Path
from time import perf_counter
from typing import TYPE_CHECKING, Any, Optional, Union
from urllib.parse import urlparse
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple, Union
from urllib.parse import ParseResult, urlparse
import openai
import psutil
import pyjson5
import requests
import torch
from asgiref.sync import sync_to_async
from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email
from google.auth.credentials import Credentials
from google.oauth2 import service_account
from magika import Magika
from PIL import Image
from pytz import country_names, country_timezones
@@ -618,6 +622,58 @@ def get_chat_usage_metrics(
}
class AiApiInfo(NamedTuple):
region: str
project: str
credentials: Credentials
api_key: str
def get_gcp_credentials(credentials_b64: str) -> Optional[Credentials]:
"""
Get GCP credentials from base64 encoded service account credentials json keyfile
"""
credentials_json = base64.b64decode(credentials_b64).decode("utf-8")
credentials_dict = pyjson5.loads(credentials_json)
credentials = service_account.Credentials.from_service_account_info(credentials_dict)
return credentials.with_scopes(scopes=["https://www.googleapis.com/auth/cloud-platform"])
def get_gcp_project_info(parsed_api_url: ParseResult) -> Tuple[str, str]:
"""
Extract region, project id from GCP API url
API url is of form https://{REGION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}...
"""
# Extract region from hostname
hostname = parsed_api_url.netloc
region = hostname.split("-aiplatform")[0] if "-aiplatform" in hostname else ""
# Extract project ID from path (e.g., "/v1/projects/my-project/...")
path_parts = parsed_api_url.path.split("/")
project_id = ""
for i, part in enumerate(path_parts):
if part == "projects" and i + 1 < len(path_parts):
project_id = path_parts[i + 1]
break
return region, project_id
def get_ai_api_info(api_key, api_base_url: str = None) -> AiApiInfo:
"""
Get the GCP Vertex or default AI API client info based on the API key and URL.
"""
region, project_id, credentials = None, None, None
# Check if AI model to be used via GCP Vertex API
parsed_api_url = urlparse(api_base_url)
if parsed_api_url.hostname and parsed_api_url.hostname.endswith(".googleapis.com"):
region, project_id = get_gcp_project_info(parsed_api_url)
credentials = get_gcp_credentials(api_key)
if credentials:
api_key = None
return AiApiInfo(region=region, project=project_id, credentials=credentials, api_key=api_key)
def get_openai_client(api_key: str, api_base_url: str) -> Union[openai.OpenAI, openai.AzureOpenAI]:
"""Get OpenAI or AzureOpenAI client based on the API Base URL"""
parsed_url = urlparse(api_base_url)