diff --git a/documentation/docs/advanced/gcp-vertex.md b/documentation/docs/advanced/gcp-vertex.md new file mode 100644 index 00000000..7069fc33 --- /dev/null +++ b/documentation/docs/advanced/gcp-vertex.md @@ -0,0 +1,26 @@ +# Google Vertex AI +:::info +This is only helpful for self-hosted users. If you're using [Khoj Cloud](https://app.khoj.dev), you can directly use any of the pre-configured AI models. +::: + +Khoj can use Google's Gemini and Anthropic's Claude family of AI models from [Vertex AI](https://cloud.google.com/vertex-ai) on Google Cloud. Explore Anthropic and Gemini AI models available on Vertex AI's [Model Garden](https://console.cloud.google.com/vertex-ai/model-garden). + +## Setup +1. Follow [these instructions](https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude#before_you_begin) to use models on GCP Vertex AI. +2. Create [Service Account](https://console.cloud.google.com/apis/credentials/serviceaccountkey) credentials. + - Download the credentials keyfile in json format. + - Base64 encode the credentials json keyfile. For example by running the following command from your terminal: + `base64 -i <path/to/credentials-keyfile.json>` +3. Create a new [AI Model API](http://localhost:42110/server/admin/database/aimodelapi/add) on your Khoj admin panel. + - **Name**: `Google Vertex` (or whatever friendly name you prefer). + - **Api Key**: `base64 encoded json keyfile` from step 2. + - **Api Base Url**: `https://{MODEL_GCP_REGION}-aiplatform.googleapis.com/v1/projects/{YOUR_GCP_PROJECT_ID}` + - MODEL_GCP_REGION: A region the AI model is available in. For example `us-east5` works for [Claude](https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude#regions). + - YOUR_GCP_PROJECT_ID: Get your project id from the [Google cloud dashboard](https://console.cloud.google.com/home/dashboard) +4. Create a new [Chat Model](http://localhost:42110/server/admin/database/chatmodel/add) on your Khoj admin panel. 
+ - **Name**: `claude-3-7-sonnet@20250219`. Any Claude or Gemini model on Vertex's Model Garden should work. + - **Model Type**: `Anthropic` or `Google` + - **Ai Model API**: *the Google Vertex Ai Model API you created in step 3* + - **Max prompt size**: `60000` (replace with the max prompt size of your model) + - **Tokenizer**: *Do not set* +5. Select the chat model on [your settings page](http://localhost:42110/settings) and start a conversation. diff --git a/src/khoj/database/migrations/0087_alter_aimodelapi_api_key.py b/src/khoj/database/migrations/0087_alter_aimodelapi_api_key.py new file mode 100644 index 00000000..071f471f --- /dev/null +++ b/src/khoj/database/migrations/0087_alter_aimodelapi_api_key.py @@ -0,0 +1,17 @@ +# Generated by Django 5.0.13 on 2025-03-23 04:42 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0086_alter_texttoimagemodelconfig_model_type"), + ] + + operations = [ + migrations.AlterField( + model_name="aimodelapi", + name="api_key", + field=models.CharField(max_length=4000), + ), + ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index f9196f80..44dcac27 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -188,7 +188,7 @@ class Subscription(DbBaseModel): class AiModelApi(DbBaseModel): name = models.CharField(max_length=200) - api_key = models.CharField(max_length=200) + api_key = models.CharField(max_length=4000) api_base_url = models.URLField(max_length=200, default=None, blank=True, null=True) def __str__(self): diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index 8b16ac02..01de4b16 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -34,6 +34,7 @@ def extract_questions_anthropic( model: Optional[str] 
= "claude-3-7-sonnet-latest", conversation_log={}, api_key=None, + api_base_url=None, temperature=0.7, location_data: LocationData = None, user: KhojUser = None, @@ -102,6 +103,7 @@ def extract_questions_anthropic( model_name=model, temperature=temperature, api_key=api_key, + api_base_url=api_base_url, response_type="json_object", tracer=tracer, ) @@ -122,7 +124,9 @@ def extract_questions_anthropic( return questions -def anthropic_send_message_to_model(messages, api_key, model, response_type="text", deepthought=False, tracer={}): +def anthropic_send_message_to_model( + messages, api_key, api_base_url, model, response_type="text", deepthought=False, tracer={} +): """ Send message to model """ @@ -134,6 +138,7 @@ def anthropic_send_message_to_model(messages, api_key, model, response_type="tex system_prompt=system_prompt, model_name=model, api_key=api_key, + api_base_url=api_base_url, response_type=response_type, deepthought=deepthought, tracer=tracer, @@ -148,6 +153,7 @@ def converse_anthropic( conversation_log={}, model: Optional[str] = "claude-3-7-sonnet-latest", api_key: Optional[str] = None, + api_base_url: Optional[str] = None, completion_func=None, conversation_commands=[ConversationCommand.Default], max_prompt_size=None, @@ -238,6 +244,7 @@ def converse_anthropic( model_name=model, temperature=0, api_key=api_key, + api_base_url=api_base_url, system_prompt=system_prompt, completion_func=completion_func, max_prompt_size=max_prompt_size, diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py index 8744cfe4..fac99e04 100644 --- a/src/khoj/processor/conversation/anthropic/utils.py +++ b/src/khoj/processor/conversation/anthropic/utils.py @@ -19,6 +19,7 @@ from khoj.processor.conversation.utils import ( get_image_from_url, ) from khoj.utils.helpers import ( + get_ai_api_info, get_chat_usage_metrics, is_none_or_empty, is_promptrace_enabled, @@ -26,13 +27,25 @@ from khoj.utils.helpers import ( logger = 
logging.getLogger(__name__) -anthropic_clients: Dict[str, anthropic.Anthropic] = {} - +anthropic_clients: Dict[str, anthropic.Anthropic | anthropic.AnthropicVertex] = {} DEFAULT_MAX_TOKENS_ANTHROPIC = 8000 MAX_REASONING_TOKENS_ANTHROPIC = 12000 +def get_anthropic_client(api_key, api_base_url=None) -> anthropic.Anthropic | anthropic.AnthropicVertex: + api_info = get_ai_api_info(api_key, api_base_url) + if api_info.api_key: + client = anthropic.Anthropic(api_key=api_info.api_key) + else: + client = anthropic.AnthropicVertex( + region=api_info.region, + project_id=api_info.project, + credentials=api_info.credentials, + ) + return client + + @retry( wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(2), @@ -45,17 +58,17 @@ def anthropic_completion_with_backoff( model_name: str, temperature=0, api_key=None, + api_base_url: str = None, model_kwargs=None, max_tokens=None, response_type="text", deepthought=False, tracer={}, ) -> str: - if api_key not in anthropic_clients: - client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key) + client = anthropic_clients.get(api_key) + if not client: + client = get_anthropic_client(api_key, api_base_url) anthropic_clients[api_key] = client - else: - client = anthropic_clients[api_key] formatted_messages = [{"role": message.role, "content": message.content} for message in messages] aggregated_response = "" @@ -115,6 +128,7 @@ def anthropic_chat_completion_with_backoff( model_name, temperature, api_key, + api_base_url, system_prompt, max_prompt_size=None, completion_func=None, @@ -132,6 +146,7 @@ def anthropic_chat_completion_with_backoff( model_name, temperature, api_key, + api_base_url, max_prompt_size, deepthought, model_kwargs, @@ -149,17 +164,17 @@ def anthropic_llm_thread( model_name, temperature, api_key, + api_base_url=None, max_prompt_size=None, deepthought=False, model_kwargs=None, tracer={}, ): try: - if api_key not in anthropic_clients: - client: anthropic.Anthropic = 
anthropic.Anthropic(api_key=api_key) + client = anthropic_clients.get(api_key) + if not client: + client = get_anthropic_client(api_key, api_base_url) anthropic_clients[api_key] = client - else: - client: anthropic.Anthropic = anthropic_clients[api_key] model_kwargs = model_kwargs or dict() max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index 7f18b079..f8df542f 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -34,6 +34,7 @@ def extract_questions_gemini( model: Optional[str] = "gemini-2.0-flash", conversation_log={}, api_key=None, + api_base_url=None, temperature=0.6, max_tokens=None, location_data: LocationData = None, @@ -97,7 +98,13 @@ def extract_questions_gemini( messages.append(ChatMessage(content=system_prompt, role="system")) response = gemini_send_message_to_model( - messages, api_key, model, response_type="json_object", temperature=temperature, tracer=tracer + messages, + api_key, + model, + api_base_url=api_base_url, + response_type="json_object", + temperature=temperature, + tracer=tracer, ) # Extract, Clean Message from Gemini's Response @@ -120,6 +127,7 @@ def gemini_send_message_to_model( messages, api_key, model, + api_base_url=None, response_type="text", response_schema=None, temperature=0.6, @@ -144,6 +152,7 @@ def gemini_send_message_to_model( system_prompt=system_prompt, model_name=model, api_key=api_key, + api_base_url=api_base_url, temperature=temperature, model_kwargs=model_kwargs, tracer=tracer, @@ -158,6 +167,7 @@ def converse_gemini( conversation_log={}, model: Optional[str] = "gemini-2.0-flash", api_key: Optional[str] = None, + api_base_url: Optional[str] = None, temperature: float = 0.6, completion_func=None, conversation_commands=[ConversationCommand.Default], @@ -249,6 +259,7 @@ def converse_gemini( model_name=model, temperature=temperature, 
api_key=api_key, + api_base_url=api_base_url, system_prompt=system_prompt, completion_func=completion_func, tracer=tracer, diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py index c8f8c4ba..b3bdd5a3 100644 --- a/src/khoj/processor/conversation/google/utils.py +++ b/src/khoj/processor/conversation/google/utils.py @@ -3,6 +3,7 @@ import os import random from copy import deepcopy from threading import Thread +from typing import Dict from google import genai from google.genai import errors as gerrors @@ -23,6 +24,7 @@ from khoj.processor.conversation.utils import ( get_image_from_url, ) from khoj.utils.helpers import ( + get_ai_api_info, get_chat_usage_metrics, is_none_or_empty, is_promptrace_enabled, @@ -30,6 +32,7 @@ from khoj.utils.helpers import ( logger = logging.getLogger(__name__) +gemini_clients: Dict[str, genai.Client] = {} MAX_OUTPUT_TOKENS_GEMINI = 8192 SAFETY_SETTINGS = [ @@ -52,6 +55,17 @@ SAFETY_SETTINGS = [ ] +def get_gemini_client(api_key, api_base_url=None) -> genai.Client: + api_info = get_ai_api_info(api_key, api_base_url) + return genai.Client( + location=api_info.region, + project=api_info.project, + credentials=api_info.credentials, + api_key=api_info.api_key, + vertexai=api_info.api_key is None, + ) + + @retry( wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(2), @@ -59,9 +73,13 @@ SAFETY_SETTINGS = [ reraise=True, ) def gemini_completion_with_backoff( - messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, tracer={} + messages, system_prompt, model_name, temperature=0, api_key=None, api_base_url=None, model_kwargs=None, tracer={} ) -> str: - client = genai.Client(api_key=api_key) + client = gemini_clients.get(api_key) + if not client: + client = get_gemini_client(api_key, api_base_url) + gemini_clients[api_key] = client + seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None config = 
gtypes.GenerateContentConfig( system_instruction=system_prompt, @@ -115,6 +133,7 @@ def gemini_chat_completion_with_backoff( model_name, temperature, api_key, + api_base_url, system_prompt, completion_func=None, model_kwargs=None, @@ -123,17 +142,29 @@ def gemini_chat_completion_with_backoff( g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func) t = Thread( target=gemini_llm_thread, - args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs, tracer), + args=(g, messages, system_prompt, model_name, temperature, api_key, api_base_url, model_kwargs, tracer), ) t.start() return g def gemini_llm_thread( - g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None, tracer: dict = {} + g, + messages, + system_prompt, + model_name, + temperature, + api_key, + api_base_url=None, + model_kwargs=None, + tracer: dict = {}, ): try: - client = genai.Client(api_key=api_key) + client = gemini_clients.get(api_key) + if not client: + client = get_gemini_client(api_key, api_base_url) + gemini_clients[api_key] = client + seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None config = gtypes.GenerateContentConfig( system_instruction=system_prompt, diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index 5ca66d68..c664d882 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -55,7 +55,7 @@ def completion_with_backoff( tracer: dict = {}, ) -> str: client_key = f"{openai_api_key}--{api_base_url}" - client: openai.OpenAI | None = openai_clients.get(client_key) + client = openai_clients.get(client_key) if not client: client = get_openai_client(openai_api_key, api_base_url) openai_clients[client_key] = client @@ -150,9 +150,8 @@ def llm_thread( ): try: client_key = f"{openai_api_key}--{api_base_url}" - if client_key in openai_clients: - client = 
openai_clients[client_key] - else: + client = openai_clients.get(client_key) + if not client: client = get_openai_client(openai_api_key, api_base_url) openai_clients[client_key] = client diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index ca8287a1..01fc0a94 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -463,12 +463,14 @@ async def extract_references_and_questions( ) elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC: api_key = chat_model.ai_model_api.api_key + api_base_url = chat_model.ai_model_api.api_base_url chat_model_name = chat_model.name inferred_queries = extract_questions_anthropic( defiltered_query, query_images=query_images, model=chat_model_name, api_key=api_key, + api_base_url=api_base_url, conversation_log=meta_log, location_data=location_data, user=user, @@ -479,12 +481,14 @@ async def extract_references_and_questions( ) elif chat_model.model_type == ChatModel.ModelType.GOOGLE: api_key = chat_model.ai_model_api.api_key + api_base_url = chat_model.ai_model_api.api_base_url chat_model_name = chat_model.name inferred_queries = extract_questions_gemini( defiltered_query, query_images=query_images, model=chat_model_name, api_key=api_key, + api_base_url=api_base_url, conversation_log=meta_log, location_data=location_data, max_tokens=chat_model.max_prompt_size, diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index caca0eaa..cf7dd582 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1220,6 +1220,7 @@ async def send_message_to_model_wrapper( ) elif model_type == ChatModel.ModelType.ANTHROPIC: api_key = chat_model.ai_model_api.api_key + api_base_url = chat_model.ai_model_api.api_base_url truncated_messages = generate_chatml_messages_with_context( user_message=query, context_message=context, @@ -1239,10 +1240,12 @@ async def send_message_to_model_wrapper( model=chat_model_name, response_type=response_type, deepthought=deepthought, + api_base_url=api_base_url, 
tracer=tracer, ) elif model_type == ChatModel.ModelType.GOOGLE: api_key = chat_model.ai_model_api.api_key + api_base_url = chat_model.ai_model_api.api_base_url truncated_messages = generate_chatml_messages_with_context( user_message=query, context_message=context, @@ -1262,6 +1265,7 @@ async def send_message_to_model_wrapper( model=chat_model_name, response_type=response_type, response_schema=response_schema, + api_base_url=api_base_url, tracer=tracer, ) else: @@ -1328,7 +1332,7 @@ def send_message_to_model_wrapper_sync( query_files=query_files, ) - openai_response = send_message_to_model( + return send_message_to_model( messages=truncated_messages, api_key=api_key, api_base_url=api_base_url, @@ -1338,10 +1342,9 @@ def send_message_to_model_wrapper_sync( tracer=tracer, ) - return openai_response - elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC: api_key = chat_model.ai_model_api.api_key + api_base_url = chat_model.ai_model_api.api_base_url truncated_messages = generate_chatml_messages_with_context( user_message=message, system_message=system_message, @@ -1356,6 +1359,7 @@ def send_message_to_model_wrapper_sync( return anthropic_send_message_to_model( messages=truncated_messages, api_key=api_key, + api_base_url=api_base_url, model=chat_model_name, response_type=response_type, tracer=tracer, @@ -1363,6 +1367,7 @@ def send_message_to_model_wrapper_sync( elif chat_model.model_type == ChatModel.ModelType.GOOGLE: api_key = chat_model.ai_model_api.api_key + api_base_url = chat_model.ai_model_api.api_base_url truncated_messages = generate_chatml_messages_with_context( user_message=message, system_message=system_message, @@ -1377,6 +1382,7 @@ def send_message_to_model_wrapper_sync( return gemini_send_message_to_model( messages=truncated_messages, api_key=api_key, + api_base_url=api_base_url, model=chat_model_name, response_type=response_type, response_schema=response_schema, @@ -1510,6 +1516,7 @@ def generate_chat_response( elif chat_model.model_type == 
ChatModel.ModelType.ANTHROPIC: api_key = chat_model.ai_model_api.api_key + api_base_url = chat_model.ai_model_api.api_base_url chat_response = converse_anthropic( compiled_references, query_to_run, @@ -1519,6 +1526,7 @@ def generate_chat_response( conversation_log=meta_log, model=chat_model.name, api_key=api_key, + api_base_url=api_base_url, completion_func=partial_completion, conversation_commands=conversation_commands, max_prompt_size=chat_model.max_prompt_size, @@ -1536,6 +1544,7 @@ def generate_chat_response( ) elif chat_model.model_type == ChatModel.ModelType.GOOGLE: api_key = chat_model.ai_model_api.api_key + api_base_url = chat_model.ai_model_api.api_base_url chat_response = converse_gemini( compiled_references, query_to_run, @@ -1544,6 +1553,7 @@ def generate_chat_response( meta_log, model=chat_model.name, api_key=api_key, + api_base_url=api_base_url, completion_func=partial_completion, conversation_commands=conversation_commands, max_prompt_size=chat_model.max_prompt_size, diff --git a/src/khoj/utils/constants.py b/src/khoj/utils/constants.py index b3d14f18..4f5c1cc8 100644 --- a/src/khoj/utils/constants.py +++ b/src/khoj/utils/constants.py @@ -49,8 +49,10 @@ model_to_cost: Dict[str, Dict[str, float]] = { "gemini-2.0-flash": {"input": 0.10, "output": 0.40}, # Anthropic Pricing: https://www.anthropic.com/pricing#anthropic-api_ "claude-3-5-haiku-20241022": {"input": 1.0, "output": 5.0}, + "claude-3-5-haiku@20241022": {"input": 1.0, "output": 5.0}, "claude-3-5-sonnet-20241022": {"input": 3.0, "output": 15.0}, "claude-3-5-sonnet-latest": {"input": 3.0, "output": 15.0}, "claude-3-7-sonnet-20250219": {"input": 3.0, "output": 15.0}, + "claude-3-7-sonnet@20250219": {"input": 3.0, "output": 15.0}, "claude-3-7-sonnet-latest": {"input": 3.0, "output": 15.0}, } diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index 6698d3bb..a3042daf 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -1,5 +1,6 @@ from __future__ import 
annotations # to avoid quoting type hints +import base64 import copy import datetime import io @@ -19,15 +20,18 @@ from itertools import islice from os import path from pathlib import Path from time import perf_counter -from typing import TYPE_CHECKING, Any, Optional, Union -from urllib.parse import urlparse +from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple, Union +from urllib.parse import ParseResult, urlparse import openai import psutil +import pyjson5 import requests import torch from asgiref.sync import sync_to_async from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email +from google.auth.credentials import Credentials +from google.oauth2 import service_account from magika import Magika from PIL import Image from pytz import country_names, country_timezones @@ -618,6 +622,58 @@ def get_chat_usage_metrics( } +class AiApiInfo(NamedTuple): + region: str + project: str + credentials: Credentials + api_key: str + + +def get_gcp_credentials(credentials_b64: str) -> Optional[Credentials]: + """ + Get GCP credentials from base64 encoded service account credentials json keyfile + """ + credentials_json = base64.b64decode(credentials_b64).decode("utf-8") + credentials_dict = pyjson5.loads(credentials_json) + credentials = service_account.Credentials.from_service_account_info(credentials_dict) + return credentials.with_scopes(scopes=["https://www.googleapis.com/auth/cloud-platform"]) + + +def get_gcp_project_info(parsed_api_url: ParseResult) -> Tuple[str, str]: + """ + Extract region, project id from GCP API url + API url is of form https://{REGION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}... 
+ """ + # Extract region from hostname + hostname = parsed_api_url.netloc + region = hostname.split("-aiplatform")[0] if "-aiplatform" in hostname else "" + + # Extract project ID from path (e.g., "/v1/projects/my-project/...") + path_parts = parsed_api_url.path.split("/") + project_id = "" + for i, part in enumerate(path_parts): + if part == "projects" and i + 1 < len(path_parts): + project_id = path_parts[i + 1] + break + + return region, project_id + + +def get_ai_api_info(api_key, api_base_url: str = None) -> AiApiInfo: + """ + Get the GCP Vertex or default AI API client info based on the API key and URL. + """ + region, project_id, credentials = None, None, None + # Check if AI model to be used via GCP Vertex API + parsed_api_url = urlparse(api_base_url) + if parsed_api_url.hostname and parsed_api_url.hostname.endswith(".googleapis.com"): + region, project_id = get_gcp_project_info(parsed_api_url) + credentials = get_gcp_credentials(api_key) + if credentials: + api_key = None + return AiApiInfo(region=region, project=project_id, credentials=credentials, api_key=api_key) + + def get_openai_client(api_key: str, api_base_url: str) -> Union[openai.OpenAI, openai.AzureOpenAI]: """Get OpenAI or AzureOpenAI client based on the API Base URL""" parsed_url = urlparse(api_base_url)