From 603c4bf2dfe007b1f88171f29b8a32e12ba70675 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Sat, 22 Mar 2025 12:36:05 +0530 Subject: [PATCH] Support access to Anthropic models via GCP Vertex AI Enable configuring a Khoj AI model API for Vertex AI using GCP credentials. Specifically use the api key & api base url fields of the AI Model API associated with the current chat model to extract gcp region, gcp project id & credentials. This helps create an AnthropicVertex client. The api key field should contain the GCP service account keyfile as a base64 encoded string. The api base url field should be of the form `https://{MODEL_GCP_REGION}-aiplatform.googleapis.com/v1/projects/{YOUR_GCP_PROJECT_ID}` Accepting GCP credentials via the AI model API makes it easy to use across local and cloud environments, as it bypasses the need for a separate service account key file on the Khoj server. --- .../conversation/anthropic/anthropic_chat.py | 9 ++- .../processor/conversation/anthropic/utils.py | 35 +++++++---- src/khoj/routers/api.py | 2 + src/khoj/routers/helpers.py | 6 ++ src/khoj/utils/constants.py | 2 + src/khoj/utils/helpers.py | 60 ++++++++++++++++++- 6 files changed, 101 insertions(+), 13 deletions(-) diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index 8b16ac02..01de4b16 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -34,6 +34,7 @@ def extract_questions_anthropic( model: Optional[str] = "claude-3-7-sonnet-latest", conversation_log={}, api_key=None, + api_base_url=None, temperature=0.7, location_data: LocationData = None, user: KhojUser = None, @@ -102,6 +103,7 @@ def extract_questions_anthropic( model_name=model, temperature=temperature, api_key=api_key, + api_base_url=api_base_url, response_type="json_object", tracer=tracer, ) @@ -122,7 +124,9 @@ def extract_questions_anthropic( return 
questions -def anthropic_send_message_to_model(messages, api_key, model, response_type="text", deepthought=False, tracer={}): +def anthropic_send_message_to_model( + messages, api_key, api_base_url, model, response_type="text", deepthought=False, tracer={} +): """ Send message to model """ @@ -134,6 +138,7 @@ def anthropic_send_message_to_model(messages, api_key, model, response_type="tex system_prompt=system_prompt, model_name=model, api_key=api_key, + api_base_url=api_base_url, response_type=response_type, deepthought=deepthought, tracer=tracer, @@ -148,6 +153,7 @@ def converse_anthropic( conversation_log={}, model: Optional[str] = "claude-3-7-sonnet-latest", api_key: Optional[str] = None, + api_base_url: Optional[str] = None, completion_func=None, conversation_commands=[ConversationCommand.Default], max_prompt_size=None, @@ -238,6 +244,7 @@ def converse_anthropic( model_name=model, temperature=0, api_key=api_key, + api_base_url=api_base_url, system_prompt=system_prompt, completion_func=completion_func, max_prompt_size=max_prompt_size, diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py index 8744cfe4..fac99e04 100644 --- a/src/khoj/processor/conversation/anthropic/utils.py +++ b/src/khoj/processor/conversation/anthropic/utils.py @@ -19,6 +19,7 @@ from khoj.processor.conversation.utils import ( get_image_from_url, ) from khoj.utils.helpers import ( + get_ai_api_info, get_chat_usage_metrics, is_none_or_empty, is_promptrace_enabled, @@ -26,13 +27,25 @@ from khoj.utils.helpers import ( logger = logging.getLogger(__name__) -anthropic_clients: Dict[str, anthropic.Anthropic] = {} - +anthropic_clients: Dict[str, anthropic.Anthropic | anthropic.AnthropicVertex] = {} DEFAULT_MAX_TOKENS_ANTHROPIC = 8000 MAX_REASONING_TOKENS_ANTHROPIC = 12000 +def get_anthropic_client(api_key, api_base_url=None) -> anthropic.Anthropic | anthropic.AnthropicVertex: + api_info = get_ai_api_info(api_key, api_base_url) + if 
api_info.api_key: + client = anthropic.Anthropic(api_key=api_info.api_key) + else: + client = anthropic.AnthropicVertex( + region=api_info.region, + project_id=api_info.project, + credentials=api_info.credentials, + ) + return client + + @retry( wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(2), @@ -45,17 +58,17 @@ def anthropic_completion_with_backoff( model_name: str, temperature=0, api_key=None, + api_base_url: str = None, model_kwargs=None, max_tokens=None, response_type="text", deepthought=False, tracer={}, ) -> str: - if api_key not in anthropic_clients: - client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key) + client = anthropic_clients.get(api_key) + if not client: + client = get_anthropic_client(api_key, api_base_url) anthropic_clients[api_key] = client - else: - client = anthropic_clients[api_key] formatted_messages = [{"role": message.role, "content": message.content} for message in messages] aggregated_response = "" @@ -115,6 +128,7 @@ def anthropic_chat_completion_with_backoff( model_name, temperature, api_key, + api_base_url, system_prompt, max_prompt_size=None, completion_func=None, @@ -132,6 +146,7 @@ def anthropic_chat_completion_with_backoff( model_name, temperature, api_key, + api_base_url, max_prompt_size, deepthought, model_kwargs, @@ -149,17 +164,17 @@ def anthropic_llm_thread( model_name, temperature, api_key, + api_base_url=None, max_prompt_size=None, deepthought=False, model_kwargs=None, tracer={}, ): try: - if api_key not in anthropic_clients: - client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key) + client = anthropic_clients.get(api_key) + if not client: + client = get_anthropic_client(api_key, api_base_url) anthropic_clients[api_key] = client - else: - client: anthropic.Anthropic = anthropic_clients[api_key] model_kwargs = model_kwargs or dict() max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index ca8287a1..98b271d9 100644 --- 
a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -463,12 +463,14 @@ async def extract_references_and_questions( ) elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC: api_key = chat_model.ai_model_api.api_key + api_base_url = chat_model.ai_model_api.api_base_url chat_model_name = chat_model.name inferred_queries = extract_questions_anthropic( defiltered_query, query_images=query_images, model=chat_model_name, api_key=api_key, + api_base_url=api_base_url, conversation_log=meta_log, location_data=location_data, user=user, diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index caca0eaa..6cbcc250 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1220,6 +1220,7 @@ async def send_message_to_model_wrapper( ) elif model_type == ChatModel.ModelType.ANTHROPIC: api_key = chat_model.ai_model_api.api_key + api_base_url = chat_model.ai_model_api.api_base_url truncated_messages = generate_chatml_messages_with_context( user_message=query, context_message=context, @@ -1239,6 +1240,7 @@ async def send_message_to_model_wrapper( model=chat_model_name, response_type=response_type, deepthought=deepthought, + api_base_url=api_base_url, tracer=tracer, ) elif model_type == ChatModel.ModelType.GOOGLE: @@ -1342,6 +1344,7 @@ def send_message_to_model_wrapper_sync( elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC: api_key = chat_model.ai_model_api.api_key + api_base_url = chat_model.ai_model_api.api_base_url truncated_messages = generate_chatml_messages_with_context( user_message=message, system_message=system_message, @@ -1356,6 +1359,7 @@ def send_message_to_model_wrapper_sync( return anthropic_send_message_to_model( messages=truncated_messages, api_key=api_key, + api_base_url=api_base_url, model=chat_model_name, response_type=response_type, tracer=tracer, @@ -1510,6 +1514,7 @@ def generate_chat_response( elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC: api_key = chat_model.ai_model_api.api_key + 
api_base_url = chat_model.ai_model_api.api_base_url chat_response = converse_anthropic( compiled_references, query_to_run, @@ -1519,6 +1524,7 @@ def generate_chat_response( conversation_log=meta_log, model=chat_model.name, api_key=api_key, + api_base_url=api_base_url, completion_func=partial_completion, conversation_commands=conversation_commands, max_prompt_size=chat_model.max_prompt_size, diff --git a/src/khoj/utils/constants.py b/src/khoj/utils/constants.py index b3d14f18..4f5c1cc8 100644 --- a/src/khoj/utils/constants.py +++ b/src/khoj/utils/constants.py @@ -49,8 +49,10 @@ model_to_cost: Dict[str, Dict[str, float]] = { "gemini-2.0-flash": {"input": 0.10, "output": 0.40}, # Anthropic Pricing: https://www.anthropic.com/pricing#anthropic-api_ "claude-3-5-haiku-20241022": {"input": 1.0, "output": 5.0}, + "claude-3-5-haiku@20241022": {"input": 1.0, "output": 5.0}, "claude-3-5-sonnet-20241022": {"input": 3.0, "output": 15.0}, "claude-3-5-sonnet-latest": {"input": 3.0, "output": 15.0}, "claude-3-7-sonnet-20250219": {"input": 3.0, "output": 15.0}, + "claude-3-7-sonnet@20250219": {"input": 3.0, "output": 15.0}, "claude-3-7-sonnet-latest": {"input": 3.0, "output": 15.0}, } diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index 6698d3bb..a3042daf 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -1,5 +1,6 @@ from __future__ import annotations # to avoid quoting type hints +import base64 import copy import datetime import io @@ -19,15 +20,18 @@ from itertools import islice from os import path from pathlib import Path from time import perf_counter -from typing import TYPE_CHECKING, Any, Optional, Union -from urllib.parse import urlparse +from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple, Union +from urllib.parse import ParseResult, urlparse import openai import psutil +import pyjson5 import requests import torch from asgiref.sync import sync_to_async from email_validator import EmailNotValidError, 
EmailUndeliverableError, validate_email +from google.auth.credentials import Credentials +from google.oauth2 import service_account from magika import Magika from PIL import Image from pytz import country_names, country_timezones @@ -618,6 +622,58 @@ def get_chat_usage_metrics( } +class AiApiInfo(NamedTuple): + region: str + project: str + credentials: Credentials + api_key: str + + +def get_gcp_credentials(credentials_b64: str) -> Optional[Credentials]: + """ + Get GCP credentials from base64 encoded service account credentials json keyfile + """ + credentials_json = base64.b64decode(credentials_b64).decode("utf-8") + credentials_dict = pyjson5.loads(credentials_json) + credentials = service_account.Credentials.from_service_account_info(credentials_dict) + return credentials.with_scopes(scopes=["https://www.googleapis.com/auth/cloud-platform"]) + + +def get_gcp_project_info(parsed_api_url: ParseResult) -> Tuple[str, str]: + """ + Extract region, project id from GCP API url + API url is of form https://{REGION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}... + """ + # Extract region from hostname + hostname = parsed_api_url.netloc + region = hostname.split("-aiplatform")[0] if "-aiplatform" in hostname else "" + + # Extract project ID from path (e.g., "/v1/projects/my-project/...") + path_parts = parsed_api_url.path.split("/") + project_id = "" + for i, part in enumerate(path_parts): + if part == "projects" and i + 1 < len(path_parts): + project_id = path_parts[i + 1] + break + + return region, project_id + + +def get_ai_api_info(api_key, api_base_url: str = None) -> AiApiInfo: + """ + Get the GCP Vertex or default AI API client info based on the API key and URL. 
+ """ + region, project_id, credentials = None, None, None + # Check if AI model to be used via GCP Vertex API + parsed_api_url = urlparse(api_base_url) + if parsed_api_url.hostname and parsed_api_url.hostname.endswith(".googleapis.com"): + region, project_id = get_gcp_project_info(parsed_api_url) + credentials = get_gcp_credentials(api_key) + if credentials: + api_key = None + return AiApiInfo(region=region, project=project_id, credentials=credentials, api_key=api_key) + + def get_openai_client(api_key: str, api_base_url: str) -> Union[openai.OpenAI, openai.AzureOpenAI]: """Get OpenAI or AzureOpenAI client based on the API Base URL""" parsed_url = urlparse(api_base_url)