mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-06 13:22:12 +00:00
Support access to Anthropic models via GCP Vertex AI
Enable configuring a Khoj AI model API for Vertex AI using GCP credentials. Specifically, use the api key & api base url fields of the AI Model API associated with the current chat model to extract the GCP region, GCP project id & credentials. These are used to create an AnthropicVertex client. The api key field should contain the GCP service account keyfile as a base64 encoded string. The api base url field should be of the form `https://{MODEL_GCP_REGION}-aiplatform.googleapis.com/v1/projects/{YOUR_GCP_PROJECT_ID}`. Accepting GCP credentials via the AI model API makes it easy to use across local and cloud environments, as it bypasses the need for a separate service account key file on the Khoj server.
This commit is contained in:
@@ -34,6 +34,7 @@ def extract_questions_anthropic(
|
|||||||
model: Optional[str] = "claude-3-7-sonnet-latest",
|
model: Optional[str] = "claude-3-7-sonnet-latest",
|
||||||
conversation_log={},
|
conversation_log={},
|
||||||
api_key=None,
|
api_key=None,
|
||||||
|
api_base_url=None,
|
||||||
temperature=0.7,
|
temperature=0.7,
|
||||||
location_data: LocationData = None,
|
location_data: LocationData = None,
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
@@ -102,6 +103,7 @@ def extract_questions_anthropic(
|
|||||||
model_name=model,
|
model_name=model,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
api_base_url=api_base_url,
|
||||||
response_type="json_object",
|
response_type="json_object",
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
@@ -122,7 +124,9 @@ def extract_questions_anthropic(
|
|||||||
return questions
|
return questions
|
||||||
|
|
||||||
|
|
||||||
def anthropic_send_message_to_model(messages, api_key, model, response_type="text", deepthought=False, tracer={}):
|
def anthropic_send_message_to_model(
|
||||||
|
messages, api_key, api_base_url, model, response_type="text", deepthought=False, tracer={}
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Send message to model
|
Send message to model
|
||||||
"""
|
"""
|
||||||
@@ -134,6 +138,7 @@ def anthropic_send_message_to_model(messages, api_key, model, response_type="tex
|
|||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
model_name=model,
|
model_name=model,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
api_base_url=api_base_url,
|
||||||
response_type=response_type,
|
response_type=response_type,
|
||||||
deepthought=deepthought,
|
deepthought=deepthought,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
@@ -148,6 +153,7 @@ def converse_anthropic(
|
|||||||
conversation_log={},
|
conversation_log={},
|
||||||
model: Optional[str] = "claude-3-7-sonnet-latest",
|
model: Optional[str] = "claude-3-7-sonnet-latest",
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
|
api_base_url: Optional[str] = None,
|
||||||
completion_func=None,
|
completion_func=None,
|
||||||
conversation_commands=[ConversationCommand.Default],
|
conversation_commands=[ConversationCommand.Default],
|
||||||
max_prompt_size=None,
|
max_prompt_size=None,
|
||||||
@@ -238,6 +244,7 @@ def converse_anthropic(
|
|||||||
model_name=model,
|
model_name=model,
|
||||||
temperature=0,
|
temperature=0,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
api_base_url=api_base_url,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
completion_func=completion_func,
|
completion_func=completion_func,
|
||||||
max_prompt_size=max_prompt_size,
|
max_prompt_size=max_prompt_size,
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from khoj.processor.conversation.utils import (
|
|||||||
get_image_from_url,
|
get_image_from_url,
|
||||||
)
|
)
|
||||||
from khoj.utils.helpers import (
|
from khoj.utils.helpers import (
|
||||||
|
get_ai_api_info,
|
||||||
get_chat_usage_metrics,
|
get_chat_usage_metrics,
|
||||||
is_none_or_empty,
|
is_none_or_empty,
|
||||||
is_promptrace_enabled,
|
is_promptrace_enabled,
|
||||||
@@ -26,13 +27,25 @@ from khoj.utils.helpers import (
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
anthropic_clients: Dict[str, anthropic.Anthropic] = {}
|
anthropic_clients: Dict[str, anthropic.Anthropic | anthropic.AnthropicVertex] = {}
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_MAX_TOKENS_ANTHROPIC = 8000
|
DEFAULT_MAX_TOKENS_ANTHROPIC = 8000
|
||||||
MAX_REASONING_TOKENS_ANTHROPIC = 12000
|
MAX_REASONING_TOKENS_ANTHROPIC = 12000
|
||||||
|
|
||||||
|
|
||||||
|
def get_anthropic_client(api_key, api_base_url=None) -> anthropic.Anthropic | anthropic.AnthropicVertex:
    """
    Build an Anthropic client for either the direct API or GCP Vertex AI.

    The api key / base url pair is resolved via get_ai_api_info: when a plain
    API key is present a standard ``anthropic.Anthropic`` client is returned,
    otherwise the resolved GCP region, project id and service-account
    credentials are used to construct an ``anthropic.AnthropicVertex`` client.
    """
    resolved = get_ai_api_info(api_key, api_base_url)
    # A surviving api_key means "not Vertex": use the plain Anthropic API.
    if resolved.api_key:
        return anthropic.Anthropic(api_key=resolved.api_key)
    return anthropic.AnthropicVertex(
        region=resolved.region,
        project_id=resolved.project,
        credentials=resolved.credentials,
    )
|
||||||
|
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
wait=wait_random_exponential(min=1, max=10),
|
wait=wait_random_exponential(min=1, max=10),
|
||||||
stop=stop_after_attempt(2),
|
stop=stop_after_attempt(2),
|
||||||
@@ -45,17 +58,17 @@ def anthropic_completion_with_backoff(
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
temperature=0,
|
temperature=0,
|
||||||
api_key=None,
|
api_key=None,
|
||||||
|
api_base_url: str = None,
|
||||||
model_kwargs=None,
|
model_kwargs=None,
|
||||||
max_tokens=None,
|
max_tokens=None,
|
||||||
response_type="text",
|
response_type="text",
|
||||||
deepthought=False,
|
deepthought=False,
|
||||||
tracer={},
|
tracer={},
|
||||||
) -> str:
|
) -> str:
|
||||||
if api_key not in anthropic_clients:
|
client = anthropic_clients.get(api_key)
|
||||||
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key)
|
if not client:
|
||||||
|
client = get_anthropic_client(api_key, api_base_url)
|
||||||
anthropic_clients[api_key] = client
|
anthropic_clients[api_key] = client
|
||||||
else:
|
|
||||||
client = anthropic_clients[api_key]
|
|
||||||
|
|
||||||
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
||||||
aggregated_response = ""
|
aggregated_response = ""
|
||||||
@@ -115,6 +128,7 @@ def anthropic_chat_completion_with_backoff(
|
|||||||
model_name,
|
model_name,
|
||||||
temperature,
|
temperature,
|
||||||
api_key,
|
api_key,
|
||||||
|
api_base_url,
|
||||||
system_prompt,
|
system_prompt,
|
||||||
max_prompt_size=None,
|
max_prompt_size=None,
|
||||||
completion_func=None,
|
completion_func=None,
|
||||||
@@ -132,6 +146,7 @@ def anthropic_chat_completion_with_backoff(
|
|||||||
model_name,
|
model_name,
|
||||||
temperature,
|
temperature,
|
||||||
api_key,
|
api_key,
|
||||||
|
api_base_url,
|
||||||
max_prompt_size,
|
max_prompt_size,
|
||||||
deepthought,
|
deepthought,
|
||||||
model_kwargs,
|
model_kwargs,
|
||||||
@@ -149,17 +164,17 @@ def anthropic_llm_thread(
|
|||||||
model_name,
|
model_name,
|
||||||
temperature,
|
temperature,
|
||||||
api_key,
|
api_key,
|
||||||
|
api_base_url=None,
|
||||||
max_prompt_size=None,
|
max_prompt_size=None,
|
||||||
deepthought=False,
|
deepthought=False,
|
||||||
model_kwargs=None,
|
model_kwargs=None,
|
||||||
tracer={},
|
tracer={},
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
if api_key not in anthropic_clients:
|
client = anthropic_clients.get(api_key)
|
||||||
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key)
|
if not client:
|
||||||
|
client = get_anthropic_client(api_key, api_base_url)
|
||||||
anthropic_clients[api_key] = client
|
anthropic_clients[api_key] = client
|
||||||
else:
|
|
||||||
client: anthropic.Anthropic = anthropic_clients[api_key]
|
|
||||||
|
|
||||||
model_kwargs = model_kwargs or dict()
|
model_kwargs = model_kwargs or dict()
|
||||||
max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC
|
max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC
|
||||||
|
|||||||
@@ -463,12 +463,14 @@ async def extract_references_and_questions(
|
|||||||
)
|
)
|
||||||
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
|
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
|
||||||
api_key = chat_model.ai_model_api.api_key
|
api_key = chat_model.ai_model_api.api_key
|
||||||
|
api_base_url = chat_model.ai_model_api.api_base_url
|
||||||
chat_model_name = chat_model.name
|
chat_model_name = chat_model.name
|
||||||
inferred_queries = extract_questions_anthropic(
|
inferred_queries = extract_questions_anthropic(
|
||||||
defiltered_query,
|
defiltered_query,
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
model=chat_model_name,
|
model=chat_model_name,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
api_base_url=api_base_url,
|
||||||
conversation_log=meta_log,
|
conversation_log=meta_log,
|
||||||
location_data=location_data,
|
location_data=location_data,
|
||||||
user=user,
|
user=user,
|
||||||
|
|||||||
@@ -1220,6 +1220,7 @@ async def send_message_to_model_wrapper(
|
|||||||
)
|
)
|
||||||
elif model_type == ChatModel.ModelType.ANTHROPIC:
|
elif model_type == ChatModel.ModelType.ANTHROPIC:
|
||||||
api_key = chat_model.ai_model_api.api_key
|
api_key = chat_model.ai_model_api.api_key
|
||||||
|
api_base_url = chat_model.ai_model_api.api_base_url
|
||||||
truncated_messages = generate_chatml_messages_with_context(
|
truncated_messages = generate_chatml_messages_with_context(
|
||||||
user_message=query,
|
user_message=query,
|
||||||
context_message=context,
|
context_message=context,
|
||||||
@@ -1239,6 +1240,7 @@ async def send_message_to_model_wrapper(
|
|||||||
model=chat_model_name,
|
model=chat_model_name,
|
||||||
response_type=response_type,
|
response_type=response_type,
|
||||||
deepthought=deepthought,
|
deepthought=deepthought,
|
||||||
|
api_base_url=api_base_url,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
elif model_type == ChatModel.ModelType.GOOGLE:
|
elif model_type == ChatModel.ModelType.GOOGLE:
|
||||||
@@ -1342,6 +1344,7 @@ def send_message_to_model_wrapper_sync(
|
|||||||
|
|
||||||
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
|
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
|
||||||
api_key = chat_model.ai_model_api.api_key
|
api_key = chat_model.ai_model_api.api_key
|
||||||
|
api_base_url = chat_model.ai_model_api.api_base_url
|
||||||
truncated_messages = generate_chatml_messages_with_context(
|
truncated_messages = generate_chatml_messages_with_context(
|
||||||
user_message=message,
|
user_message=message,
|
||||||
system_message=system_message,
|
system_message=system_message,
|
||||||
@@ -1356,6 +1359,7 @@ def send_message_to_model_wrapper_sync(
|
|||||||
return anthropic_send_message_to_model(
|
return anthropic_send_message_to_model(
|
||||||
messages=truncated_messages,
|
messages=truncated_messages,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
api_base_url=api_base_url,
|
||||||
model=chat_model_name,
|
model=chat_model_name,
|
||||||
response_type=response_type,
|
response_type=response_type,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
@@ -1510,6 +1514,7 @@ def generate_chat_response(
|
|||||||
|
|
||||||
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
|
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
|
||||||
api_key = chat_model.ai_model_api.api_key
|
api_key = chat_model.ai_model_api.api_key
|
||||||
|
api_base_url = chat_model.ai_model_api.api_base_url
|
||||||
chat_response = converse_anthropic(
|
chat_response = converse_anthropic(
|
||||||
compiled_references,
|
compiled_references,
|
||||||
query_to_run,
|
query_to_run,
|
||||||
@@ -1519,6 +1524,7 @@ def generate_chat_response(
|
|||||||
conversation_log=meta_log,
|
conversation_log=meta_log,
|
||||||
model=chat_model.name,
|
model=chat_model.name,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
api_base_url=api_base_url,
|
||||||
completion_func=partial_completion,
|
completion_func=partial_completion,
|
||||||
conversation_commands=conversation_commands,
|
conversation_commands=conversation_commands,
|
||||||
max_prompt_size=chat_model.max_prompt_size,
|
max_prompt_size=chat_model.max_prompt_size,
|
||||||
|
|||||||
@@ -49,8 +49,10 @@ model_to_cost: Dict[str, Dict[str, float]] = {
|
|||||||
"gemini-2.0-flash": {"input": 0.10, "output": 0.40},
|
"gemini-2.0-flash": {"input": 0.10, "output": 0.40},
|
||||||
# Anthropic Pricing: https://www.anthropic.com/pricing#anthropic-api_
|
# Anthropic Pricing: https://www.anthropic.com/pricing#anthropic-api_
|
||||||
"claude-3-5-haiku-20241022": {"input": 1.0, "output": 5.0},
|
"claude-3-5-haiku-20241022": {"input": 1.0, "output": 5.0},
|
||||||
|
"claude-3-5-haiku@20241022": {"input": 1.0, "output": 5.0},
|
||||||
"claude-3-5-sonnet-20241022": {"input": 3.0, "output": 15.0},
|
"claude-3-5-sonnet-20241022": {"input": 3.0, "output": 15.0},
|
||||||
"claude-3-5-sonnet-latest": {"input": 3.0, "output": 15.0},
|
"claude-3-5-sonnet-latest": {"input": 3.0, "output": 15.0},
|
||||||
"claude-3-7-sonnet-20250219": {"input": 3.0, "output": 15.0},
|
"claude-3-7-sonnet-20250219": {"input": 3.0, "output": 15.0},
|
||||||
|
"claude-3-7-sonnet@20250219": {"input": 3.0, "output": 15.0},
|
||||||
"claude-3-7-sonnet-latest": {"input": 3.0, "output": 15.0},
|
"claude-3-7-sonnet-latest": {"input": 3.0, "output": 15.0},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations # to avoid quoting type hints
|
from __future__ import annotations # to avoid quoting type hints
|
||||||
|
|
||||||
|
import base64
|
||||||
import copy
|
import copy
|
||||||
import datetime
|
import datetime
|
||||||
import io
|
import io
|
||||||
@@ -19,15 +20,18 @@ from itertools import islice
|
|||||||
from os import path
|
from os import path
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from time import perf_counter
|
from time import perf_counter
|
||||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple, Union
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import ParseResult, urlparse
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
import psutil
|
import psutil
|
||||||
|
import pyjson5
|
||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email
|
from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email
|
||||||
|
from google.auth.credentials import Credentials
|
||||||
|
from google.oauth2 import service_account
|
||||||
from magika import Magika
|
from magika import Magika
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pytz import country_names, country_timezones
|
from pytz import country_names, country_timezones
|
||||||
@@ -618,6 +622,58 @@ def get_chat_usage_metrics(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class AiApiInfo(NamedTuple):
    """Resolved connection info for an AI model API: either a plain API key or GCP Vertex AI details."""

    # GCP region hosting the Vertex AI endpoint (None when not using Vertex)
    region: str
    # GCP project id extracted from the API base url (None when not using Vertex)
    project: str
    # GCP service-account credentials (None when authenticating via api_key)
    credentials: Credentials
    # Plain API key; set to None when GCP credentials are used instead
    api_key: str
|
||||||
|
|
||||||
|
|
||||||
|
def get_gcp_credentials(credentials_b64: str) -> Optional[Credentials]:
    """
    Get GCP credentials from a base64 encoded service account credentials json keyfile.

    Returns None when the input is not a decodable base64-encoded JSON keyfile
    (e.g. a plain API key was supplied instead), matching the declared Optional
    return type and the caller's `if credentials:` fallback to API-key auth.
    """
    try:
        credentials_json = base64.b64decode(credentials_b64).decode("utf-8")
        credentials_dict = pyjson5.loads(credentials_json)
        credentials = service_account.Credentials.from_service_account_info(credentials_dict)
    except Exception:
        # Not a valid service account keyfile; let the caller fall back to API-key auth
        return None
    return credentials.with_scopes(scopes=["https://www.googleapis.com/auth/cloud-platform"])
|
||||||
|
|
||||||
|
|
||||||
|
def get_gcp_project_info(parsed_api_url: ParseResult) -> Tuple[str, str]:
    """
    Extract the GCP region and project id from a parsed Vertex AI API url.

    The url is of the form
    https://{REGION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}...
    Returns an empty string for any component that cannot be found.
    """
    # Region is the hostname prefix before "-aiplatform" (e.g. "us-east5")
    host = parsed_api_url.netloc
    region = host.partition("-aiplatform")[0] if "-aiplatform" in host else ""

    # Project id is the path segment immediately following "projects"
    project_id = ""
    segments = parsed_api_url.path.split("/")
    for segment, successor in zip(segments, segments[1:]):
        if segment == "projects":
            project_id = successor
            break

    return region, project_id
|
||||||
|
|
||||||
|
|
||||||
|
def get_ai_api_info(api_key, api_base_url: Optional[str] = None) -> AiApiInfo:
    """
    Get the GCP Vertex or default AI API client info based on the API key and URL.

    When the api base url points at a *.googleapis.com host, the GCP region,
    project id and service-account credentials are extracted; on success the
    plain api_key is cleared so callers authenticate via credentials instead.
    """
    region, project_id, credentials = None, None, None
    # Check if AI model is to be used via the GCP Vertex API.
    # Coerce None to "" so urlparse always returns a str ParseResult
    # (urlparse(None) goes through the bytes path and yields ParseResultBytes).
    parsed_api_url = urlparse(api_base_url or "")
    if parsed_api_url.hostname and parsed_api_url.hostname.endswith(".googleapis.com"):
        region, project_id = get_gcp_project_info(parsed_api_url)
        credentials = get_gcp_credentials(api_key)
        if credentials:
            # Vertex clients authenticate via credentials, not the API key
            api_key = None
    return AiApiInfo(region=region, project=project_id, credentials=credentials, api_key=api_key)
|
||||||
|
|
||||||
|
|
||||||
def get_openai_client(api_key: str, api_base_url: str) -> Union[openai.OpenAI, openai.AzureOpenAI]:
|
def get_openai_client(api_key: str, api_base_url: str) -> Union[openai.OpenAI, openai.AzureOpenAI]:
|
||||||
"""Get OpenAI or AzureOpenAI client based on the API Base URL"""
|
"""Get OpenAI or AzureOpenAI client based on the API Base URL"""
|
||||||
parsed_url = urlparse(api_base_url)
|
parsed_url = urlparse(api_base_url)
|
||||||
|
|||||||
Reference in New Issue
Block a user