mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Support access to Anthropic models via GCP Vertex AI
Enable configuring a Khoj AI model API for Vertex AI using GCP credentials. Specifically use the api key & api base url fields of the AI Model API associated with the current chat model to extract gcp region, gcp project id & credentials. This helps create a AnthropicVertex client. The api key field should contain the GCP service account keyfile as a base64 encoded string. The api base url field should be of the form `https://{MODEL_GCP_REGION}-aiplatform.googleapis.com/v1/projects/{YOUR_GCP_PROJECT_ID}` Accepting GCP credentials via the AI model API makes it easy to use across local and cloud environments. As it bypasses the need for a separate service account key file on the Khoj server.
This commit is contained in:
@@ -34,6 +34,7 @@ def extract_questions_anthropic(
|
||||
model: Optional[str] = "claude-3-7-sonnet-latest",
|
||||
conversation_log={},
|
||||
api_key=None,
|
||||
api_base_url=None,
|
||||
temperature=0.7,
|
||||
location_data: LocationData = None,
|
||||
user: KhojUser = None,
|
||||
@@ -102,6 +103,7 @@ def extract_questions_anthropic(
|
||||
model_name=model,
|
||||
temperature=temperature,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
response_type="json_object",
|
||||
tracer=tracer,
|
||||
)
|
||||
@@ -122,7 +124,9 @@ def extract_questions_anthropic(
|
||||
return questions
|
||||
|
||||
|
||||
def anthropic_send_message_to_model(messages, api_key, model, response_type="text", deepthought=False, tracer={}):
|
||||
def anthropic_send_message_to_model(
|
||||
messages, api_key, api_base_url, model, response_type="text", deepthought=False, tracer={}
|
||||
):
|
||||
"""
|
||||
Send message to model
|
||||
"""
|
||||
@@ -134,6 +138,7 @@ def anthropic_send_message_to_model(messages, api_key, model, response_type="tex
|
||||
system_prompt=system_prompt,
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
response_type=response_type,
|
||||
deepthought=deepthought,
|
||||
tracer=tracer,
|
||||
@@ -148,6 +153,7 @@ def converse_anthropic(
|
||||
conversation_log={},
|
||||
model: Optional[str] = "claude-3-7-sonnet-latest",
|
||||
api_key: Optional[str] = None,
|
||||
api_base_url: Optional[str] = None,
|
||||
completion_func=None,
|
||||
conversation_commands=[ConversationCommand.Default],
|
||||
max_prompt_size=None,
|
||||
@@ -238,6 +244,7 @@ def converse_anthropic(
|
||||
model_name=model,
|
||||
temperature=0,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
system_prompt=system_prompt,
|
||||
completion_func=completion_func,
|
||||
max_prompt_size=max_prompt_size,
|
||||
|
||||
@@ -19,6 +19,7 @@ from khoj.processor.conversation.utils import (
|
||||
get_image_from_url,
|
||||
)
|
||||
from khoj.utils.helpers import (
|
||||
get_ai_api_info,
|
||||
get_chat_usage_metrics,
|
||||
is_none_or_empty,
|
||||
is_promptrace_enabled,
|
||||
@@ -26,13 +27,25 @@ from khoj.utils.helpers import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
anthropic_clients: Dict[str, anthropic.Anthropic] = {}
|
||||
|
||||
anthropic_clients: Dict[str, anthropic.Anthropic | anthropic.AnthropicVertex] = {}
|
||||
|
||||
DEFAULT_MAX_TOKENS_ANTHROPIC = 8000
|
||||
MAX_REASONING_TOKENS_ANTHROPIC = 12000
|
||||
|
||||
|
||||
def get_anthropic_client(api_key, api_base_url=None) -> anthropic.Anthropic | anthropic.AnthropicVertex:
|
||||
api_info = get_ai_api_info(api_key, api_base_url)
|
||||
if api_info.api_key:
|
||||
client = anthropic.Anthropic(api_key=api_info.api_key)
|
||||
else:
|
||||
client = anthropic.AnthropicVertex(
|
||||
region=api_info.region,
|
||||
project_id=api_info.project,
|
||||
credentials=api_info.credentials,
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
@retry(
|
||||
wait=wait_random_exponential(min=1, max=10),
|
||||
stop=stop_after_attempt(2),
|
||||
@@ -45,17 +58,17 @@ def anthropic_completion_with_backoff(
|
||||
model_name: str,
|
||||
temperature=0,
|
||||
api_key=None,
|
||||
api_base_url: str = None,
|
||||
model_kwargs=None,
|
||||
max_tokens=None,
|
||||
response_type="text",
|
||||
deepthought=False,
|
||||
tracer={},
|
||||
) -> str:
|
||||
if api_key not in anthropic_clients:
|
||||
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key)
|
||||
client = anthropic_clients.get(api_key)
|
||||
if not client:
|
||||
client = get_anthropic_client(api_key, api_base_url)
|
||||
anthropic_clients[api_key] = client
|
||||
else:
|
||||
client = anthropic_clients[api_key]
|
||||
|
||||
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
||||
aggregated_response = ""
|
||||
@@ -115,6 +128,7 @@ def anthropic_chat_completion_with_backoff(
|
||||
model_name,
|
||||
temperature,
|
||||
api_key,
|
||||
api_base_url,
|
||||
system_prompt,
|
||||
max_prompt_size=None,
|
||||
completion_func=None,
|
||||
@@ -132,6 +146,7 @@ def anthropic_chat_completion_with_backoff(
|
||||
model_name,
|
||||
temperature,
|
||||
api_key,
|
||||
api_base_url,
|
||||
max_prompt_size,
|
||||
deepthought,
|
||||
model_kwargs,
|
||||
@@ -149,17 +164,17 @@ def anthropic_llm_thread(
|
||||
model_name,
|
||||
temperature,
|
||||
api_key,
|
||||
api_base_url=None,
|
||||
max_prompt_size=None,
|
||||
deepthought=False,
|
||||
model_kwargs=None,
|
||||
tracer={},
|
||||
):
|
||||
try:
|
||||
if api_key not in anthropic_clients:
|
||||
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key)
|
||||
client = anthropic_clients.get(api_key)
|
||||
if not client:
|
||||
client = get_anthropic_client(api_key, api_base_url)
|
||||
anthropic_clients[api_key] = client
|
||||
else:
|
||||
client: anthropic.Anthropic = anthropic_clients[api_key]
|
||||
|
||||
model_kwargs = model_kwargs or dict()
|
||||
max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC
|
||||
|
||||
@@ -463,12 +463,14 @@ async def extract_references_and_questions(
|
||||
)
|
||||
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
api_base_url = chat_model.ai_model_api.api_base_url
|
||||
chat_model_name = chat_model.name
|
||||
inferred_queries = extract_questions_anthropic(
|
||||
defiltered_query,
|
||||
query_images=query_images,
|
||||
model=chat_model_name,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
conversation_log=meta_log,
|
||||
location_data=location_data,
|
||||
user=user,
|
||||
|
||||
@@ -1220,6 +1220,7 @@ async def send_message_to_model_wrapper(
|
||||
)
|
||||
elif model_type == ChatModel.ModelType.ANTHROPIC:
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
api_base_url = chat_model.ai_model_api.api_base_url
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=query,
|
||||
context_message=context,
|
||||
@@ -1239,6 +1240,7 @@ async def send_message_to_model_wrapper(
|
||||
model=chat_model_name,
|
||||
response_type=response_type,
|
||||
deepthought=deepthought,
|
||||
api_base_url=api_base_url,
|
||||
tracer=tracer,
|
||||
)
|
||||
elif model_type == ChatModel.ModelType.GOOGLE:
|
||||
@@ -1342,6 +1344,7 @@ def send_message_to_model_wrapper_sync(
|
||||
|
||||
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
api_base_url = chat_model.ai_model_api.api_base_url
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=message,
|
||||
system_message=system_message,
|
||||
@@ -1356,6 +1359,7 @@ def send_message_to_model_wrapper_sync(
|
||||
return anthropic_send_message_to_model(
|
||||
messages=truncated_messages,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
model=chat_model_name,
|
||||
response_type=response_type,
|
||||
tracer=tracer,
|
||||
@@ -1510,6 +1514,7 @@ def generate_chat_response(
|
||||
|
||||
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
api_base_url = chat_model.ai_model_api.api_base_url
|
||||
chat_response = converse_anthropic(
|
||||
compiled_references,
|
||||
query_to_run,
|
||||
@@ -1519,6 +1524,7 @@ def generate_chat_response(
|
||||
conversation_log=meta_log,
|
||||
model=chat_model.name,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
completion_func=partial_completion,
|
||||
conversation_commands=conversation_commands,
|
||||
max_prompt_size=chat_model.max_prompt_size,
|
||||
|
||||
@@ -49,8 +49,10 @@ model_to_cost: Dict[str, Dict[str, float]] = {
|
||||
"gemini-2.0-flash": {"input": 0.10, "output": 0.40},
|
||||
# Anthropic Pricing: https://www.anthropic.com/pricing#anthropic-api_
|
||||
"claude-3-5-haiku-20241022": {"input": 1.0, "output": 5.0},
|
||||
"claude-3-5-haiku@20241022": {"input": 1.0, "output": 5.0},
|
||||
"claude-3-5-sonnet-20241022": {"input": 3.0, "output": 15.0},
|
||||
"claude-3-5-sonnet-latest": {"input": 3.0, "output": 15.0},
|
||||
"claude-3-7-sonnet-20250219": {"input": 3.0, "output": 15.0},
|
||||
"claude-3-7-sonnet@20250219": {"input": 3.0, "output": 15.0},
|
||||
"claude-3-7-sonnet-latest": {"input": 3.0, "output": 15.0},
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations # to avoid quoting type hints
|
||||
|
||||
import base64
|
||||
import copy
|
||||
import datetime
|
||||
import io
|
||||
@@ -19,15 +20,18 @@ from itertools import islice
|
||||
from os import path
|
||||
from pathlib import Path
|
||||
from time import perf_counter
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple, Union
|
||||
from urllib.parse import ParseResult, urlparse
|
||||
|
||||
import openai
|
||||
import psutil
|
||||
import pyjson5
|
||||
import requests
|
||||
import torch
|
||||
from asgiref.sync import sync_to_async
|
||||
from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email
|
||||
from google.auth.credentials import Credentials
|
||||
from google.oauth2 import service_account
|
||||
from magika import Magika
|
||||
from PIL import Image
|
||||
from pytz import country_names, country_timezones
|
||||
@@ -618,6 +622,58 @@ def get_chat_usage_metrics(
|
||||
}
|
||||
|
||||
|
||||
class AiApiInfo(NamedTuple):
|
||||
region: str
|
||||
project: str
|
||||
credentials: Credentials
|
||||
api_key: str
|
||||
|
||||
|
||||
def get_gcp_credentials(credentials_b64: str) -> Optional[Credentials]:
|
||||
"""
|
||||
Get GCP credentials from base64 encoded service account credentials json keyfile
|
||||
"""
|
||||
credentials_json = base64.b64decode(credentials_b64).decode("utf-8")
|
||||
credentials_dict = pyjson5.loads(credentials_json)
|
||||
credentials = service_account.Credentials.from_service_account_info(credentials_dict)
|
||||
return credentials.with_scopes(scopes=["https://www.googleapis.com/auth/cloud-platform"])
|
||||
|
||||
|
||||
def get_gcp_project_info(parsed_api_url: ParseResult) -> Tuple[str, str]:
|
||||
"""
|
||||
Extract region, project id from GCP API url
|
||||
API url is of form https://{REGION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}...
|
||||
"""
|
||||
# Extract region from hostname
|
||||
hostname = parsed_api_url.netloc
|
||||
region = hostname.split("-aiplatform")[0] if "-aiplatform" in hostname else ""
|
||||
|
||||
# Extract project ID from path (e.g., "/v1/projects/my-project/...")
|
||||
path_parts = parsed_api_url.path.split("/")
|
||||
project_id = ""
|
||||
for i, part in enumerate(path_parts):
|
||||
if part == "projects" and i + 1 < len(path_parts):
|
||||
project_id = path_parts[i + 1]
|
||||
break
|
||||
|
||||
return region, project_id
|
||||
|
||||
|
||||
def get_ai_api_info(api_key, api_base_url: str = None) -> AiApiInfo:
|
||||
"""
|
||||
Get the GCP Vertex or default AI API client info based on the API key and URL.
|
||||
"""
|
||||
region, project_id, credentials = None, None, None
|
||||
# Check if AI model to be used via GCP Vertex API
|
||||
parsed_api_url = urlparse(api_base_url)
|
||||
if parsed_api_url.hostname and parsed_api_url.hostname.endswith(".googleapis.com"):
|
||||
region, project_id = get_gcp_project_info(parsed_api_url)
|
||||
credentials = get_gcp_credentials(api_key)
|
||||
if credentials:
|
||||
api_key = None
|
||||
return AiApiInfo(region=region, project=project_id, credentials=credentials, api_key=api_key)
|
||||
|
||||
|
||||
def get_openai_client(api_key: str, api_base_url: str) -> Union[openai.OpenAI, openai.AzureOpenAI]:
|
||||
"""Get OpenAI or AzureOpenAI client based on the API Base URL"""
|
||||
parsed_url = urlparse(api_base_url)
|
||||
|
||||
Reference in New Issue
Block a user