Access Claude and Gemini via GCP Vertex AI (#1134)

Support accessing Claude and Gemini AI models via Vertex AI on Google Cloud. 

See the documentation at docs.khoj.dev for setup details
This commit is contained in:
Debanjum
2025-03-23 16:26:02 +05:30
committed by GitHub
12 changed files with 205 additions and 27 deletions

View File

@@ -0,0 +1,26 @@
# Google Vertex AI
:::info
This is only helpful for self-hosted users. If you're using [Khoj Cloud](https://app.khoj.dev), you can directly use any of the pre-configured AI models.
:::
Khoj can use Google's Gemini and Anthropic's Claude family of AI models from [Vertex AI](https://cloud.google.com/vertex-ai) on Google Cloud. Explore Anthropic and Gemini AI models available on Vertex AI's [Model Garden](https://console.cloud.google.com/vertex-ai/model-garden).
## Setup
1. Follow [these instructions](https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude#before_you_begin) to use models on GCP Vertex AI.
2. Create [Service Account](https://console.cloud.google.com/apis/credentials/serviceaccountkey) credentials.
- Download the credentials keyfile in json format.
- Base64 encode the credentials json keyfile. For example by running the following command from your terminal:
`base64 -i <service_account_credentials_keyfile.json>`
3. Create a new [API Model API](http://localhost:42110/server/admin/database/aimodelapi/add) on your Khoj admin panel.
- **Name**: `Google Vertex` (or whatever friendly name you prefer).
- **Api Key**: `base64 encoded json keyfile` from step 2.
- **Api Base Url**: `https://{MODEL_GCP_REGION}-aiplatform.googleapis.com/v1/projects/{YOUR_GCP_PROJECT_ID}`
- MODEL_GCP_REGION: A region the AI model is available in. For example `us-east5` works for [Claude](https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude#regions).
- YOUR_GCP_PROJECT_ID: Get your project id from the [Google cloud dashboard](https://console.cloud.google.com/home/dashboard)
4. Create a new [Chat Model](http://localhost:42110/server/admin/database/chatmodel/add) on your Khoj admin panel.
- **Name**: `claude-3-7-sonnet@20250219`. Any Claude or Gemini model on Vertex's Model Garden should work.
- **Model Type**: `Anthropic` or `Google`
- **Ai Model API**: *the Google Vertex Ai Model API you created in step 3*
- **Max prompt size**: `60000` (replace with the max prompt size of your model)
- **Tokenizer**: *Do not set*
5. Select the chat model on [your settings page](http://localhost:42110/settings) and start a conversation.

View File

@@ -0,0 +1,17 @@
# Generated by Django 5.0.13 on 2025-03-23 04:42
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0086_alter_texttoimagemodelconfig_model_type"),
]
operations = [
migrations.AlterField(
model_name="aimodelapi",
name="api_key",
field=models.CharField(max_length=4000),
),
]

View File

@@ -188,7 +188,7 @@ class Subscription(DbBaseModel):
class AiModelApi(DbBaseModel):
name = models.CharField(max_length=200)
api_key = models.CharField(max_length=200)
api_key = models.CharField(max_length=4000)
api_base_url = models.URLField(max_length=200, default=None, blank=True, null=True)
def __str__(self):

View File

@@ -34,6 +34,7 @@ def extract_questions_anthropic(
model: Optional[str] = "claude-3-7-sonnet-latest",
conversation_log={},
api_key=None,
api_base_url=None,
temperature=0.7,
location_data: LocationData = None,
user: KhojUser = None,
@@ -102,6 +103,7 @@ def extract_questions_anthropic(
model_name=model,
temperature=temperature,
api_key=api_key,
api_base_url=api_base_url,
response_type="json_object",
tracer=tracer,
)
@@ -122,7 +124,9 @@ def extract_questions_anthropic(
return questions
def anthropic_send_message_to_model(messages, api_key, model, response_type="text", deepthought=False, tracer={}):
def anthropic_send_message_to_model(
messages, api_key, api_base_url, model, response_type="text", deepthought=False, tracer={}
):
"""
Send message to model
"""
@@ -134,6 +138,7 @@ def anthropic_send_message_to_model(messages, api_key, model, response_type="tex
system_prompt=system_prompt,
model_name=model,
api_key=api_key,
api_base_url=api_base_url,
response_type=response_type,
deepthought=deepthought,
tracer=tracer,
@@ -148,6 +153,7 @@ def converse_anthropic(
conversation_log={},
model: Optional[str] = "claude-3-7-sonnet-latest",
api_key: Optional[str] = None,
api_base_url: Optional[str] = None,
completion_func=None,
conversation_commands=[ConversationCommand.Default],
max_prompt_size=None,
@@ -238,6 +244,7 @@ def converse_anthropic(
model_name=model,
temperature=0,
api_key=api_key,
api_base_url=api_base_url,
system_prompt=system_prompt,
completion_func=completion_func,
max_prompt_size=max_prompt_size,

View File

@@ -19,6 +19,7 @@ from khoj.processor.conversation.utils import (
get_image_from_url,
)
from khoj.utils.helpers import (
get_ai_api_info,
get_chat_usage_metrics,
is_none_or_empty,
is_promptrace_enabled,
@@ -26,13 +27,25 @@ from khoj.utils.helpers import (
logger = logging.getLogger(__name__)
anthropic_clients: Dict[str, anthropic.Anthropic] = {}
anthropic_clients: Dict[str, anthropic.Anthropic | anthropic.AnthropicVertex] = {}
DEFAULT_MAX_TOKENS_ANTHROPIC = 8000
MAX_REASONING_TOKENS_ANTHROPIC = 12000
def get_anthropic_client(api_key, api_base_url=None) -> anthropic.Anthropic | anthropic.AnthropicVertex:
api_info = get_ai_api_info(api_key, api_base_url)
if api_info.api_key:
client = anthropic.Anthropic(api_key=api_info.api_key)
else:
client = anthropic.AnthropicVertex(
region=api_info.region,
project_id=api_info.project,
credentials=api_info.credentials,
)
return client
@retry(
wait=wait_random_exponential(min=1, max=10),
stop=stop_after_attempt(2),
@@ -45,17 +58,17 @@ def anthropic_completion_with_backoff(
model_name: str,
temperature=0,
api_key=None,
api_base_url: str = None,
model_kwargs=None,
max_tokens=None,
response_type="text",
deepthought=False,
tracer={},
) -> str:
if api_key not in anthropic_clients:
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key)
client = anthropic_clients.get(api_key)
if not client:
client = get_anthropic_client(api_key, api_base_url)
anthropic_clients[api_key] = client
else:
client = anthropic_clients[api_key]
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
aggregated_response = ""
@@ -115,6 +128,7 @@ def anthropic_chat_completion_with_backoff(
model_name,
temperature,
api_key,
api_base_url,
system_prompt,
max_prompt_size=None,
completion_func=None,
@@ -132,6 +146,7 @@ def anthropic_chat_completion_with_backoff(
model_name,
temperature,
api_key,
api_base_url,
max_prompt_size,
deepthought,
model_kwargs,
@@ -149,17 +164,17 @@ def anthropic_llm_thread(
model_name,
temperature,
api_key,
api_base_url=None,
max_prompt_size=None,
deepthought=False,
model_kwargs=None,
tracer={},
):
try:
if api_key not in anthropic_clients:
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key)
client = anthropic_clients.get(api_key)
if not client:
client = get_anthropic_client(api_key, api_base_url)
anthropic_clients[api_key] = client
else:
client: anthropic.Anthropic = anthropic_clients[api_key]
model_kwargs = model_kwargs or dict()
max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC

View File

@@ -34,6 +34,7 @@ def extract_questions_gemini(
model: Optional[str] = "gemini-2.0-flash",
conversation_log={},
api_key=None,
api_base_url=None,
temperature=0.6,
max_tokens=None,
location_data: LocationData = None,
@@ -97,7 +98,13 @@ def extract_questions_gemini(
messages.append(ChatMessage(content=system_prompt, role="system"))
response = gemini_send_message_to_model(
messages, api_key, model, response_type="json_object", temperature=temperature, tracer=tracer
messages,
api_key,
model,
api_base_url=api_base_url,
response_type="json_object",
temperature=temperature,
tracer=tracer,
)
# Extract, Clean Message from Gemini's Response
@@ -120,6 +127,7 @@ def gemini_send_message_to_model(
messages,
api_key,
model,
api_base_url=None,
response_type="text",
response_schema=None,
temperature=0.6,
@@ -144,6 +152,7 @@ def gemini_send_message_to_model(
system_prompt=system_prompt,
model_name=model,
api_key=api_key,
api_base_url=api_base_url,
temperature=temperature,
model_kwargs=model_kwargs,
tracer=tracer,
@@ -158,6 +167,7 @@ def converse_gemini(
conversation_log={},
model: Optional[str] = "gemini-2.0-flash",
api_key: Optional[str] = None,
api_base_url: Optional[str] = None,
temperature: float = 0.6,
completion_func=None,
conversation_commands=[ConversationCommand.Default],
@@ -249,6 +259,7 @@ def converse_gemini(
model_name=model,
temperature=temperature,
api_key=api_key,
api_base_url=api_base_url,
system_prompt=system_prompt,
completion_func=completion_func,
tracer=tracer,

View File

@@ -3,6 +3,7 @@ import os
import random
from copy import deepcopy
from threading import Thread
from typing import Dict
from google import genai
from google.genai import errors as gerrors
@@ -23,6 +24,7 @@ from khoj.processor.conversation.utils import (
get_image_from_url,
)
from khoj.utils.helpers import (
get_ai_api_info,
get_chat_usage_metrics,
is_none_or_empty,
is_promptrace_enabled,
@@ -30,6 +32,7 @@ from khoj.utils.helpers import (
logger = logging.getLogger(__name__)
gemini_clients: Dict[str, genai.Client] = {}
MAX_OUTPUT_TOKENS_GEMINI = 8192
SAFETY_SETTINGS = [
@@ -52,6 +55,17 @@ SAFETY_SETTINGS = [
]
def get_gemini_client(api_key, api_base_url=None) -> genai.Client:
api_info = get_ai_api_info(api_key, api_base_url)
return genai.Client(
location=api_info.region,
project=api_info.project,
credentials=api_info.credentials,
api_key=api_info.api_key,
vertexai=api_info.api_key is None,
)
@retry(
wait=wait_random_exponential(min=1, max=10),
stop=stop_after_attempt(2),
@@ -59,9 +73,13 @@ SAFETY_SETTINGS = [
reraise=True,
)
def gemini_completion_with_backoff(
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, tracer={}
messages, system_prompt, model_name, temperature=0, api_key=None, api_base_url=None, model_kwargs=None, tracer={}
) -> str:
client = genai.Client(api_key=api_key)
client = gemini_clients.get(api_key)
if not client:
client = get_gemini_client(api_key, api_base_url)
gemini_clients[api_key] = client
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
config = gtypes.GenerateContentConfig(
system_instruction=system_prompt,
@@ -115,6 +133,7 @@ def gemini_chat_completion_with_backoff(
model_name,
temperature,
api_key,
api_base_url,
system_prompt,
completion_func=None,
model_kwargs=None,
@@ -123,17 +142,29 @@ def gemini_chat_completion_with_backoff(
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
t = Thread(
target=gemini_llm_thread,
args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs, tracer),
args=(g, messages, system_prompt, model_name, temperature, api_key, api_base_url, model_kwargs, tracer),
)
t.start()
return g
def gemini_llm_thread(
g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None, tracer: dict = {}
g,
messages,
system_prompt,
model_name,
temperature,
api_key,
api_base_url=None,
model_kwargs=None,
tracer: dict = {},
):
try:
client = genai.Client(api_key=api_key)
client = gemini_clients.get(api_key)
if not client:
client = get_gemini_client(api_key, api_base_url)
gemini_clients[api_key] = client
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
config = gtypes.GenerateContentConfig(
system_instruction=system_prompt,

View File

@@ -55,7 +55,7 @@ def completion_with_backoff(
tracer: dict = {},
) -> str:
client_key = f"{openai_api_key}--{api_base_url}"
client: openai.OpenAI | None = openai_clients.get(client_key)
client = openai_clients.get(client_key)
if not client:
client = get_openai_client(openai_api_key, api_base_url)
openai_clients[client_key] = client
@@ -150,9 +150,8 @@ def llm_thread(
):
try:
client_key = f"{openai_api_key}--{api_base_url}"
if client_key in openai_clients:
client = openai_clients[client_key]
else:
client = openai_clients.get(client_key)
if not client:
client = get_openai_client(openai_api_key, api_base_url)
openai_clients[client_key] = client

View File

@@ -463,12 +463,14 @@ async def extract_references_and_questions(
)
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
chat_model_name = chat_model.name
inferred_queries = extract_questions_anthropic(
defiltered_query,
query_images=query_images,
model=chat_model_name,
api_key=api_key,
api_base_url=api_base_url,
conversation_log=meta_log,
location_data=location_data,
user=user,
@@ -479,12 +481,14 @@ async def extract_references_and_questions(
)
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
chat_model_name = chat_model.name
inferred_queries = extract_questions_gemini(
defiltered_query,
query_images=query_images,
model=chat_model_name,
api_key=api_key,
api_base_url=api_base_url,
conversation_log=meta_log,
location_data=location_data,
max_tokens=chat_model.max_prompt_size,

View File

@@ -1220,6 +1220,7 @@ async def send_message_to_model_wrapper(
)
elif model_type == ChatModel.ModelType.ANTHROPIC:
api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
truncated_messages = generate_chatml_messages_with_context(
user_message=query,
context_message=context,
@@ -1239,10 +1240,12 @@ async def send_message_to_model_wrapper(
model=chat_model_name,
response_type=response_type,
deepthought=deepthought,
api_base_url=api_base_url,
tracer=tracer,
)
elif model_type == ChatModel.ModelType.GOOGLE:
api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
truncated_messages = generate_chatml_messages_with_context(
user_message=query,
context_message=context,
@@ -1262,6 +1265,7 @@ async def send_message_to_model_wrapper(
model=chat_model_name,
response_type=response_type,
response_schema=response_schema,
api_base_url=api_base_url,
tracer=tracer,
)
else:
@@ -1328,7 +1332,7 @@ def send_message_to_model_wrapper_sync(
query_files=query_files,
)
openai_response = send_message_to_model(
return send_message_to_model(
messages=truncated_messages,
api_key=api_key,
api_base_url=api_base_url,
@@ -1338,10 +1342,9 @@ def send_message_to_model_wrapper_sync(
tracer=tracer,
)
return openai_response
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
truncated_messages = generate_chatml_messages_with_context(
user_message=message,
system_message=system_message,
@@ -1356,6 +1359,7 @@ def send_message_to_model_wrapper_sync(
return anthropic_send_message_to_model(
messages=truncated_messages,
api_key=api_key,
api_base_url=api_base_url,
model=chat_model_name,
response_type=response_type,
tracer=tracer,
@@ -1363,6 +1367,7 @@ def send_message_to_model_wrapper_sync(
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
truncated_messages = generate_chatml_messages_with_context(
user_message=message,
system_message=system_message,
@@ -1377,6 +1382,7 @@ def send_message_to_model_wrapper_sync(
return gemini_send_message_to_model(
messages=truncated_messages,
api_key=api_key,
api_base_url=api_base_url,
model=chat_model_name,
response_type=response_type,
response_schema=response_schema,
@@ -1510,6 +1516,7 @@ def generate_chat_response(
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
chat_response = converse_anthropic(
compiled_references,
query_to_run,
@@ -1519,6 +1526,7 @@ def generate_chat_response(
conversation_log=meta_log,
model=chat_model.name,
api_key=api_key,
api_base_url=api_base_url,
completion_func=partial_completion,
conversation_commands=conversation_commands,
max_prompt_size=chat_model.max_prompt_size,
@@ -1536,6 +1544,7 @@ def generate_chat_response(
)
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
chat_response = converse_gemini(
compiled_references,
query_to_run,
@@ -1544,6 +1553,7 @@ def generate_chat_response(
meta_log,
model=chat_model.name,
api_key=api_key,
api_base_url=api_base_url,
completion_func=partial_completion,
conversation_commands=conversation_commands,
max_prompt_size=chat_model.max_prompt_size,

View File

@@ -49,8 +49,10 @@ model_to_cost: Dict[str, Dict[str, float]] = {
"gemini-2.0-flash": {"input": 0.10, "output": 0.40},
# Anthropic Pricing: https://www.anthropic.com/pricing#anthropic-api_
"claude-3-5-haiku-20241022": {"input": 1.0, "output": 5.0},
"claude-3-5-haiku@20241022": {"input": 1.0, "output": 5.0},
"claude-3-5-sonnet-20241022": {"input": 3.0, "output": 15.0},
"claude-3-5-sonnet-latest": {"input": 3.0, "output": 15.0},
"claude-3-7-sonnet-20250219": {"input": 3.0, "output": 15.0},
"claude-3-7-sonnet@20250219": {"input": 3.0, "output": 15.0},
"claude-3-7-sonnet-latest": {"input": 3.0, "output": 15.0},
}

View File

@@ -1,5 +1,6 @@
from __future__ import annotations # to avoid quoting type hints
import base64
import copy
import datetime
import io
@@ -19,15 +20,18 @@ from itertools import islice
from os import path
from pathlib import Path
from time import perf_counter
from typing import TYPE_CHECKING, Any, Optional, Union
from urllib.parse import urlparse
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple, Union
from urllib.parse import ParseResult, urlparse
import openai
import psutil
import pyjson5
import requests
import torch
from asgiref.sync import sync_to_async
from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email
from google.auth.credentials import Credentials
from google.oauth2 import service_account
from magika import Magika
from PIL import Image
from pytz import country_names, country_timezones
@@ -618,6 +622,58 @@ def get_chat_usage_metrics(
}
class AiApiInfo(NamedTuple):
region: str
project: str
credentials: Credentials
api_key: str
def get_gcp_credentials(credentials_b64: str) -> Optional[Credentials]:
"""
Get GCP credentials from base64 encoded service account credentials json keyfile
"""
credentials_json = base64.b64decode(credentials_b64).decode("utf-8")
credentials_dict = pyjson5.loads(credentials_json)
credentials = service_account.Credentials.from_service_account_info(credentials_dict)
return credentials.with_scopes(scopes=["https://www.googleapis.com/auth/cloud-platform"])
def get_gcp_project_info(parsed_api_url: ParseResult) -> Tuple[str, str]:
"""
Extract region, project id from GCP API url
API url is of form https://{REGION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}...
"""
# Extract region from hostname
hostname = parsed_api_url.netloc
region = hostname.split("-aiplatform")[0] if "-aiplatform" in hostname else ""
# Extract project ID from path (e.g., "/v1/projects/my-project/...")
path_parts = parsed_api_url.path.split("/")
project_id = ""
for i, part in enumerate(path_parts):
if part == "projects" and i + 1 < len(path_parts):
project_id = path_parts[i + 1]
break
return region, project_id
def get_ai_api_info(api_key, api_base_url: str = None) -> AiApiInfo:
"""
Get the GCP Vertex or default AI API client info based on the API key and URL.
"""
region, project_id, credentials = None, None, None
# Check if AI model to be used via GCP Vertex API
parsed_api_url = urlparse(api_base_url)
if parsed_api_url.hostname and parsed_api_url.hostname.endswith(".googleapis.com"):
region, project_id = get_gcp_project_info(parsed_api_url)
credentials = get_gcp_credentials(api_key)
if credentials:
api_key = None
return AiApiInfo(region=region, project=project_id, credentials=credentials, api_key=api_key)
def get_openai_client(api_key: str, api_base_url: str) -> Union[openai.OpenAI, openai.AzureOpenAI]:
"""Get OpenAI or AzureOpenAI client based on the API Base URL"""
parsed_url = urlparse(api_base_url)