Support access to Anthropic models via GCP Vertex AI

Enable configuring a Khoj AI model API for Vertex AI using GCP credentials.

Specifically use the api key & api base url fields of the AI Model API
associated with the current chat model to extract gcp region, gcp
project id & credentials. This helps create an AnthropicVertex client.

The api key field should contain the GCP service account keyfile as a
base64 encoded string.

The api base url field should be of the form
`https://{MODEL_GCP_REGION}-aiplatform.googleapis.com/v1/projects/{YOUR_GCP_PROJECT_ID}`

Accepting GCP credentials via the AI model API makes it easy to use
across local and cloud environments, as it bypasses the need for a
separate service account key file on the Khoj server.
This commit is contained in:
Debanjum
2025-03-22 12:36:05 +05:30
parent 8bebcd5f81
commit 603c4bf2df
6 changed files with 101 additions and 13 deletions

View File

@@ -34,6 +34,7 @@ def extract_questions_anthropic(
model: Optional[str] = "claude-3-7-sonnet-latest", model: Optional[str] = "claude-3-7-sonnet-latest",
conversation_log={}, conversation_log={},
api_key=None, api_key=None,
api_base_url=None,
temperature=0.7, temperature=0.7,
location_data: LocationData = None, location_data: LocationData = None,
user: KhojUser = None, user: KhojUser = None,
@@ -102,6 +103,7 @@ def extract_questions_anthropic(
model_name=model, model_name=model,
temperature=temperature, temperature=temperature,
api_key=api_key, api_key=api_key,
api_base_url=api_base_url,
response_type="json_object", response_type="json_object",
tracer=tracer, tracer=tracer,
) )
@@ -122,7 +124,9 @@ def extract_questions_anthropic(
return questions return questions
def anthropic_send_message_to_model(messages, api_key, model, response_type="text", deepthought=False, tracer={}): def anthropic_send_message_to_model(
messages, api_key, api_base_url, model, response_type="text", deepthought=False, tracer={}
):
""" """
Send message to model Send message to model
""" """
@@ -134,6 +138,7 @@ def anthropic_send_message_to_model(messages, api_key, model, response_type="tex
system_prompt=system_prompt, system_prompt=system_prompt,
model_name=model, model_name=model,
api_key=api_key, api_key=api_key,
api_base_url=api_base_url,
response_type=response_type, response_type=response_type,
deepthought=deepthought, deepthought=deepthought,
tracer=tracer, tracer=tracer,
@@ -148,6 +153,7 @@ def converse_anthropic(
conversation_log={}, conversation_log={},
model: Optional[str] = "claude-3-7-sonnet-latest", model: Optional[str] = "claude-3-7-sonnet-latest",
api_key: Optional[str] = None, api_key: Optional[str] = None,
api_base_url: Optional[str] = None,
completion_func=None, completion_func=None,
conversation_commands=[ConversationCommand.Default], conversation_commands=[ConversationCommand.Default],
max_prompt_size=None, max_prompt_size=None,
@@ -238,6 +244,7 @@ def converse_anthropic(
model_name=model, model_name=model,
temperature=0, temperature=0,
api_key=api_key, api_key=api_key,
api_base_url=api_base_url,
system_prompt=system_prompt, system_prompt=system_prompt,
completion_func=completion_func, completion_func=completion_func,
max_prompt_size=max_prompt_size, max_prompt_size=max_prompt_size,

View File

@@ -19,6 +19,7 @@ from khoj.processor.conversation.utils import (
get_image_from_url, get_image_from_url,
) )
from khoj.utils.helpers import ( from khoj.utils.helpers import (
get_ai_api_info,
get_chat_usage_metrics, get_chat_usage_metrics,
is_none_or_empty, is_none_or_empty,
is_promptrace_enabled, is_promptrace_enabled,
@@ -26,13 +27,25 @@ from khoj.utils.helpers import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
anthropic_clients: Dict[str, anthropic.Anthropic] = {} anthropic_clients: Dict[str, anthropic.Anthropic | anthropic.AnthropicVertex] = {}
DEFAULT_MAX_TOKENS_ANTHROPIC = 8000 DEFAULT_MAX_TOKENS_ANTHROPIC = 8000
MAX_REASONING_TOKENS_ANTHROPIC = 12000 MAX_REASONING_TOKENS_ANTHROPIC = 12000
def get_anthropic_client(api_key, api_base_url=None) -> anthropic.Anthropic | anthropic.AnthropicVertex:
    """Create an Anthropic API or GCP Vertex AI client from the given credentials.

    A plain API key yields a standard Anthropic client; when the api info
    resolves to GCP credentials instead, a Vertex AI client is built from the
    extracted region, project id and service account credentials.
    """
    api_info = get_ai_api_info(api_key, api_base_url)
    # No direct API key means the model is accessed via GCP Vertex AI.
    if not api_info.api_key:
        return anthropic.AnthropicVertex(
            region=api_info.region,
            project_id=api_info.project,
            credentials=api_info.credentials,
        )
    return anthropic.Anthropic(api_key=api_info.api_key)
@retry( @retry(
wait=wait_random_exponential(min=1, max=10), wait=wait_random_exponential(min=1, max=10),
stop=stop_after_attempt(2), stop=stop_after_attempt(2),
@@ -45,17 +58,17 @@ def anthropic_completion_with_backoff(
model_name: str, model_name: str,
temperature=0, temperature=0,
api_key=None, api_key=None,
api_base_url: str = None,
model_kwargs=None, model_kwargs=None,
max_tokens=None, max_tokens=None,
response_type="text", response_type="text",
deepthought=False, deepthought=False,
tracer={}, tracer={},
) -> str: ) -> str:
if api_key not in anthropic_clients: client = anthropic_clients.get(api_key)
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key) if not client:
client = get_anthropic_client(api_key, api_base_url)
anthropic_clients[api_key] = client anthropic_clients[api_key] = client
else:
client = anthropic_clients[api_key]
formatted_messages = [{"role": message.role, "content": message.content} for message in messages] formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
aggregated_response = "" aggregated_response = ""
@@ -115,6 +128,7 @@ def anthropic_chat_completion_with_backoff(
model_name, model_name,
temperature, temperature,
api_key, api_key,
api_base_url,
system_prompt, system_prompt,
max_prompt_size=None, max_prompt_size=None,
completion_func=None, completion_func=None,
@@ -132,6 +146,7 @@ def anthropic_chat_completion_with_backoff(
model_name, model_name,
temperature, temperature,
api_key, api_key,
api_base_url,
max_prompt_size, max_prompt_size,
deepthought, deepthought,
model_kwargs, model_kwargs,
@@ -149,17 +164,17 @@ def anthropic_llm_thread(
model_name, model_name,
temperature, temperature,
api_key, api_key,
api_base_url=None,
max_prompt_size=None, max_prompt_size=None,
deepthought=False, deepthought=False,
model_kwargs=None, model_kwargs=None,
tracer={}, tracer={},
): ):
try: try:
if api_key not in anthropic_clients: client = anthropic_clients.get(api_key)
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key) if not client:
client = get_anthropic_client(api_key, api_base_url)
anthropic_clients[api_key] = client anthropic_clients[api_key] = client
else:
client: anthropic.Anthropic = anthropic_clients[api_key]
model_kwargs = model_kwargs or dict() model_kwargs = model_kwargs or dict()
max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC

View File

@@ -463,12 +463,14 @@ async def extract_references_and_questions(
) )
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC: elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
api_key = chat_model.ai_model_api.api_key api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
chat_model_name = chat_model.name chat_model_name = chat_model.name
inferred_queries = extract_questions_anthropic( inferred_queries = extract_questions_anthropic(
defiltered_query, defiltered_query,
query_images=query_images, query_images=query_images,
model=chat_model_name, model=chat_model_name,
api_key=api_key, api_key=api_key,
api_base_url=api_base_url,
conversation_log=meta_log, conversation_log=meta_log,
location_data=location_data, location_data=location_data,
user=user, user=user,

View File

@@ -1220,6 +1220,7 @@ async def send_message_to_model_wrapper(
) )
elif model_type == ChatModel.ModelType.ANTHROPIC: elif model_type == ChatModel.ModelType.ANTHROPIC:
api_key = chat_model.ai_model_api.api_key api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
truncated_messages = generate_chatml_messages_with_context( truncated_messages = generate_chatml_messages_with_context(
user_message=query, user_message=query,
context_message=context, context_message=context,
@@ -1239,6 +1240,7 @@ async def send_message_to_model_wrapper(
model=chat_model_name, model=chat_model_name,
response_type=response_type, response_type=response_type,
deepthought=deepthought, deepthought=deepthought,
api_base_url=api_base_url,
tracer=tracer, tracer=tracer,
) )
elif model_type == ChatModel.ModelType.GOOGLE: elif model_type == ChatModel.ModelType.GOOGLE:
@@ -1342,6 +1344,7 @@ def send_message_to_model_wrapper_sync(
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC: elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
api_key = chat_model.ai_model_api.api_key api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
truncated_messages = generate_chatml_messages_with_context( truncated_messages = generate_chatml_messages_with_context(
user_message=message, user_message=message,
system_message=system_message, system_message=system_message,
@@ -1356,6 +1359,7 @@ def send_message_to_model_wrapper_sync(
return anthropic_send_message_to_model( return anthropic_send_message_to_model(
messages=truncated_messages, messages=truncated_messages,
api_key=api_key, api_key=api_key,
api_base_url=api_base_url,
model=chat_model_name, model=chat_model_name,
response_type=response_type, response_type=response_type,
tracer=tracer, tracer=tracer,
@@ -1510,6 +1514,7 @@ def generate_chat_response(
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC: elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
api_key = chat_model.ai_model_api.api_key api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
chat_response = converse_anthropic( chat_response = converse_anthropic(
compiled_references, compiled_references,
query_to_run, query_to_run,
@@ -1519,6 +1524,7 @@ def generate_chat_response(
conversation_log=meta_log, conversation_log=meta_log,
model=chat_model.name, model=chat_model.name,
api_key=api_key, api_key=api_key,
api_base_url=api_base_url,
completion_func=partial_completion, completion_func=partial_completion,
conversation_commands=conversation_commands, conversation_commands=conversation_commands,
max_prompt_size=chat_model.max_prompt_size, max_prompt_size=chat_model.max_prompt_size,

View File

@@ -49,8 +49,10 @@ model_to_cost: Dict[str, Dict[str, float]] = {
"gemini-2.0-flash": {"input": 0.10, "output": 0.40}, "gemini-2.0-flash": {"input": 0.10, "output": 0.40},
# Anthropic Pricing: https://www.anthropic.com/pricing#anthropic-api_ # Anthropic Pricing: https://www.anthropic.com/pricing#anthropic-api_
"claude-3-5-haiku-20241022": {"input": 1.0, "output": 5.0}, "claude-3-5-haiku-20241022": {"input": 1.0, "output": 5.0},
"claude-3-5-haiku@20241022": {"input": 1.0, "output": 5.0},
"claude-3-5-sonnet-20241022": {"input": 3.0, "output": 15.0}, "claude-3-5-sonnet-20241022": {"input": 3.0, "output": 15.0},
"claude-3-5-sonnet-latest": {"input": 3.0, "output": 15.0}, "claude-3-5-sonnet-latest": {"input": 3.0, "output": 15.0},
"claude-3-7-sonnet-20250219": {"input": 3.0, "output": 15.0}, "claude-3-7-sonnet-20250219": {"input": 3.0, "output": 15.0},
"claude-3-7-sonnet@20250219": {"input": 3.0, "output": 15.0},
"claude-3-7-sonnet-latest": {"input": 3.0, "output": 15.0}, "claude-3-7-sonnet-latest": {"input": 3.0, "output": 15.0},
} }

View File

@@ -1,5 +1,6 @@
from __future__ import annotations # to avoid quoting type hints from __future__ import annotations # to avoid quoting type hints
import base64
import copy import copy
import datetime import datetime
import io import io
@@ -19,15 +20,18 @@ from itertools import islice
from os import path from os import path
from pathlib import Path from pathlib import Path
from time import perf_counter from time import perf_counter
from typing import TYPE_CHECKING, Any, Optional, Union from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple, Union
from urllib.parse import urlparse from urllib.parse import ParseResult, urlparse
import openai import openai
import psutil import psutil
import pyjson5
import requests import requests
import torch import torch
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email
from google.auth.credentials import Credentials
from google.oauth2 import service_account
from magika import Magika from magika import Magika
from PIL import Image from PIL import Image
from pytz import country_names, country_timezones from pytz import country_names, country_timezones
@@ -618,6 +622,58 @@ def get_chat_usage_metrics(
} }
class AiApiInfo(NamedTuple):
    """Resolved connection info for an AI model API: either a plain api key or GCP Vertex credentials."""

    # GCP region hosting the Vertex AI model; None when using a plain API key
    region: str
    # GCP project id; None when using a plain API key
    project: str
    # Scoped GCP service account credentials; None when using a plain API key
    credentials: Credentials
    # Provider API key; set to None when GCP credentials are used instead
    api_key: str
def get_gcp_credentials(credentials_b64: str) -> Optional[Credentials]:
    """Build scoped GCP service account credentials from a base64 encoded JSON keyfile string."""
    # Decode the base64 payload back into the service account keyfile JSON
    keyfile_json = base64.b64decode(credentials_b64).decode("utf-8")
    keyfile_info = pyjson5.loads(keyfile_json)
    credentials = service_account.Credentials.from_service_account_info(keyfile_info)
    # Grant the broad cloud-platform scope required to call Vertex AI
    return credentials.with_scopes(scopes=["https://www.googleapis.com/auth/cloud-platform"])
def get_gcp_project_info(parsed_api_url: ParseResult) -> Tuple[str, str]:
    """Extract the GCP region and project id from a parsed Vertex AI API url.

    The url has the form
    https://{REGION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}...
    Returns empty strings for components that cannot be found.
    """
    # The region is whatever prefixes the "-aiplatform" marker in the host
    host = parsed_api_url.netloc
    region = host.split("-aiplatform")[0] if "-aiplatform" in host else ""

    # The project id is the path segment immediately after "projects"
    segments = parsed_api_url.path.split("/")
    project_id = next(
        (segments[i + 1] for i, segment in enumerate(segments[:-1]) if segment == "projects"),
        "",
    )
    return region, project_id
def get_ai_api_info(api_key, api_base_url: str = None) -> AiApiInfo:
    """Resolve AI API client info: GCP Vertex credentials or the plain API key.

    A *.googleapis.com api base url signals the model is served via GCP Vertex
    AI; in that case the region/project are parsed from the url and the api key
    field is decoded as base64 service account credentials.
    """
    region = project_id = credentials = None

    parsed_api_url = urlparse(api_base_url)
    hostname = parsed_api_url.hostname
    if hostname and hostname.endswith(".googleapis.com"):
        region, project_id = get_gcp_project_info(parsed_api_url)
        credentials = get_gcp_credentials(api_key)
        # Prefer GCP credentials over the raw key once they decode successfully
        if credentials:
            api_key = None

    return AiApiInfo(region=region, project=project_id, credentials=credentials, api_key=api_key)
def get_openai_client(api_key: str, api_base_url: str) -> Union[openai.OpenAI, openai.AzureOpenAI]: def get_openai_client(api_key: str, api_base_url: str) -> Union[openai.OpenAI, openai.AzureOpenAI]:
"""Get OpenAI or AzureOpenAI client based on the API Base URL""" """Get OpenAI or AzureOpenAI client based on the API Base URL"""
parsed_url = urlparse(api_base_url) parsed_url = urlparse(api_base_url)