Support access to Gemini models via GCP Vertex AI

Author: Debanjum
Date:   2025-03-23 14:22:50 +05:30
Parent: 603c4bf2df
Commit: da33c7d83c

4 changed files with 47 additions and 9 deletions
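In short: every Gemini call site now threads an optional api_base_url from the chat model's API config down to a new get_gemini_client helper, which picks the auth mode from what is configured. A minimal sketch of that toggle, using the helper defined in the second file below (the key and URL values are placeholders):

    # With an API key, the client talks to the public Gemini API; with no key,
    # get_gemini_client sets vertexai=True and routes via GCP Vertex AI.
    client = get_gemini_client(api_key="AIza...", api_base_url=None)        # Gemini API
    client = get_gemini_client(api_key=None, api_base_url=vertex_base_url)  # Vertex AI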

Changed file 1 of 4:

@@ -34,6 +34,7 @@ def extract_questions_gemini(
     model: Optional[str] = "gemini-2.0-flash",
     conversation_log={},
     api_key=None,
+    api_base_url=None,
     temperature=0.6,
     max_tokens=None,
     location_data: LocationData = None,
@@ -97,7 +98,13 @@ def extract_questions_gemini(
     messages.append(ChatMessage(content=system_prompt, role="system"))
 
     response = gemini_send_message_to_model(
-        messages, api_key, model, response_type="json_object", temperature=temperature, tracer=tracer
+        messages,
+        api_key,
+        model,
+        api_base_url=api_base_url,
+        response_type="json_object",
+        temperature=temperature,
+        tracer=tracer,
     )
 
     # Extract, Clean Message from Gemini's Response
@@ -120,6 +127,7 @@ def gemini_send_message_to_model(
     messages,
     api_key,
     model,
+    api_base_url=None,
     response_type="text",
     response_schema=None,
     temperature=0.6,
@@ -144,6 +152,7 @@ def gemini_send_message_to_model(
         system_prompt=system_prompt,
         model_name=model,
         api_key=api_key,
+        api_base_url=api_base_url,
         temperature=temperature,
         model_kwargs=model_kwargs,
         tracer=tracer,
@@ -158,6 +167,7 @@ def converse_gemini(
     conversation_log={},
     model: Optional[str] = "gemini-2.0-flash",
     api_key: Optional[str] = None,
+    api_base_url: Optional[str] = None,
     temperature: float = 0.6,
     completion_func=None,
     conversation_commands=[ConversationCommand.Default],
@@ -249,6 +259,7 @@ def converse_gemini(
         model_name=model,
         temperature=temperature,
         api_key=api_key,
+        api_base_url=api_base_url,
         system_prompt=system_prompt,
         completion_func=completion_func,
         tracer=tracer,
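For reference, a hedged sketch of how a caller invokes the widened gemini_send_message_to_model signature above (argument values are illustrative; only the parameter names come from the diff):

    response = gemini_send_message_to_model(
        messages,                   # list of ChatMessage built by the caller
        api_key,                    # may be None when routing via Vertex AI
        "gemini-2.0-flash",
        api_base_url=api_base_url,  # forwarded down to get_gemini_client
        response_type="json_object",
        temperature=0.6,
        tracer=tracer,
    )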

Changed file 2 of 4:

@@ -23,6 +23,7 @@ from khoj.processor.conversation.utils import (
     get_image_from_url,
 )
 from khoj.utils.helpers import (
+    get_ai_api_info,
     get_chat_usage_metrics,
     is_none_or_empty,
     is_promptrace_enabled,
@@ -52,6 +53,17 @@ SAFETY_SETTINGS = [
 ]
 
 
+def get_gemini_client(api_key, api_base_url=None) -> genai.Client:
+    api_info = get_ai_api_info(api_key, api_base_url)
+    return genai.Client(
+        location=api_info.region,
+        project=api_info.project,
+        credentials=api_info.credentials,
+        api_key=api_info.api_key,
+        vertexai=api_info.api_key is None,
+    )
+
+
 @retry(
     wait=wait_random_exponential(min=1, max=10),
     stop=stop_after_attempt(2),
@@ -59,9 +71,9 @@ SAFETY_SETTINGS = [
     reraise=True,
 )
 def gemini_completion_with_backoff(
-    messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, tracer={}
+    messages, system_prompt, model_name, temperature=0, api_key=None, api_base_url=None, model_kwargs=None, tracer={}
 ) -> str:
-    client = genai.Client(api_key=api_key)
+    client = get_gemini_client(api_key, api_base_url)
     seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
     config = gtypes.GenerateContentConfig(
         system_instruction=system_prompt,
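get_ai_api_info itself is not part of this diff, so its return shape can only be inferred from the fields accessed in get_gemini_client above. A hypothetical sketch of that shape (the class name and defaults are assumptions; only the field names are confirmed by the code):

    from dataclasses import dataclass
    from typing import Any, Optional

    @dataclass
    class AiApiInfo:  # hypothetical name; the diff only shows the accessed fields
        api_key: Optional[str]         # set when using the public Gemini API
        region: Optional[str] = None   # GCP region, used for Vertex AI requests
        project: Optional[str] = None  # GCP project, used for Vertex AI requests
        credentials: Any = None        # e.g. a google-auth credentials object

With api_key left unset, genai.Client(vertexai=True, project=..., location=..., credentials=...) sends requests through Vertex AI; the google-genai SDK supports both backends through this one constructor, which is what lets a single helper cover both paths.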
@@ -115,6 +127,7 @@ def gemini_chat_completion_with_backoff(
     model_name,
     temperature,
     api_key,
+    api_base_url,
     system_prompt,
     completion_func=None,
     model_kwargs=None,
@@ -123,17 +136,25 @@ def gemini_chat_completion_with_backoff(
     g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
     t = Thread(
         target=gemini_llm_thread,
-        args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs, tracer),
+        args=(g, messages, system_prompt, model_name, temperature, api_key, api_base_url, model_kwargs, tracer),
     )
     t.start()
     return g
 
 
 def gemini_llm_thread(
-    g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None, tracer: dict = {}
+    g,
+    messages,
+    system_prompt,
+    model_name,
+    temperature,
+    api_key,
+    api_base_url=None,
+    model_kwargs=None,
+    tracer: dict = {},
 ):
     try:
-        client = genai.Client(api_key=api_key)
+        client = get_gemini_client(api_key, api_base_url)
         seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
         config = gtypes.GenerateContentConfig(
             system_instruction=system_prompt,
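Putting the module together, a hedged end-to-end use of the Vertex AI path (assumes ambient GCP credentials resolvable by get_ai_api_info; all values are placeholders):

    # api_key=None makes get_gemini_client construct a Vertex AI client.
    text = gemini_completion_with_backoff(
        messages=messages,
        system_prompt="You are a helpful assistant.",
        model_name="gemini-2.0-flash",
        api_key=None,
        api_base_url=vertex_base_url,
    )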

Changed file 3 of 4:

@@ -481,12 +481,14 @@ async def extract_references_and_questions(
         )
     elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
         api_key = chat_model.ai_model_api.api_key
+        api_base_url = chat_model.ai_model_api.api_base_url
         chat_model_name = chat_model.name
         inferred_queries = extract_questions_gemini(
             defiltered_query,
             query_images=query_images,
             model=chat_model_name,
             api_key=api_key,
+            api_base_url=api_base_url,
             conversation_log=meta_log,
             location_data=location_data,
             max_tokens=chat_model.max_prompt_size,
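Both router files read the same two fields off the model's AiModelApi record. A hypothetical sketch of a config that would exercise the Vertex AI path (the constructor arguments are assumptions; the diff only confirms the .api_key and .api_base_url attributes and the GOOGLE model type):

    ai_model_api = AiModelApi(api_key=None, api_base_url=vertex_base_url)
    chat_model = ChatModel(
        name="gemini-2.0-flash",
        model_type=ChatModel.ModelType.GOOGLE,
        ai_model_api=ai_model_api,
    )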

Changed file 4 of 4:

@@ -1245,6 +1245,7 @@ async def send_message_to_model_wrapper(
         )
     elif model_type == ChatModel.ModelType.GOOGLE:
         api_key = chat_model.ai_model_api.api_key
+        api_base_url = chat_model.ai_model_api.api_base_url
         truncated_messages = generate_chatml_messages_with_context(
             user_message=query,
             context_message=context,
@@ -1264,6 +1265,7 @@ async def send_message_to_model_wrapper(
             model=chat_model_name,
             response_type=response_type,
             response_schema=response_schema,
+            api_base_url=api_base_url,
             tracer=tracer,
         )
     else:
@@ -1330,7 +1332,7 @@ def send_message_to_model_wrapper_sync(
             query_files=query_files,
         )
 
-        openai_response = send_message_to_model(
+        return send_message_to_model(
             messages=truncated_messages,
             api_key=api_key,
             api_base_url=api_base_url,
@@ -1340,8 +1342,6 @@ def send_message_to_model_wrapper_sync(
             tracer=tracer,
         )
 
-        return openai_response
-
     elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
         api_key = chat_model.ai_model_api.api_key
         api_base_url = chat_model.ai_model_api.api_base_url
@@ -1367,6 +1367,7 @@ def send_message_to_model_wrapper_sync(
 
     elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
         api_key = chat_model.ai_model_api.api_key
+        api_base_url = chat_model.ai_model_api.api_base_url
         truncated_messages = generate_chatml_messages_with_context(
             user_message=message,
             system_message=system_message,
@@ -1381,6 +1382,7 @@ def send_message_to_model_wrapper_sync(
         return gemini_send_message_to_model(
             messages=truncated_messages,
             api_key=api_key,
+            api_base_url=api_base_url,
             model=chat_model_name,
             response_type=response_type,
             response_schema=response_schema,
@@ -1542,6 +1544,7 @@ def generate_chat_response(
         )
     elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
         api_key = chat_model.ai_model_api.api_key
+        api_base_url = chat_model.ai_model_api.api_base_url
         chat_response = converse_gemini(
             compiled_references,
             query_to_run,
@@ -1550,6 +1553,7 @@ def generate_chat_response(
             meta_log,
             model=chat_model.name,
             api_key=api_key,
+            api_base_url=api_base_url,
             completion_func=partial_completion,
             conversation_commands=conversation_commands,
             max_prompt_size=chat_model.max_prompt_size,
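The same pattern repeats across these router entry points: read both fields off chat_model.ai_model_api, then forward them. Condensed, the streaming chat call chain after this commit looks like this (function names from the diff; control flow simplified):

    api_key = chat_model.ai_model_api.api_key
    api_base_url = chat_model.ai_model_api.api_base_url
    chat_response = converse_gemini(..., api_key=api_key, api_base_url=api_base_url)
    #  -> gemini_chat_completion_with_backoff(..., api_key, api_base_url, ...)
    #  -> Thread(target=gemini_llm_thread, args=(..., api_key, api_base_url, ...))
    #  -> get_gemini_client(api_key, api_base_url)  # Vertex AI iff api_key is None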