Support access to Gemini models via GCP Vertex AI

This commit is contained in:
Debanjum
2025-03-23 14:22:50 +05:30
parent 603c4bf2df
commit da33c7d83c
4 changed files with 47 additions and 9 deletions

View File

@@ -34,6 +34,7 @@ def extract_questions_gemini(
model: Optional[str] = "gemini-2.0-flash",
conversation_log={},
api_key=None,
api_base_url=None,
temperature=0.6,
max_tokens=None,
location_data: LocationData = None,
@@ -97,7 +98,13 @@ def extract_questions_gemini(
messages.append(ChatMessage(content=system_prompt, role="system"))
response = gemini_send_message_to_model(
messages, api_key, model, response_type="json_object", temperature=temperature, tracer=tracer
messages,
api_key,
model,
api_base_url=api_base_url,
response_type="json_object",
temperature=temperature,
tracer=tracer,
)
# Extract, Clean Message from Gemini's Response
@@ -120,6 +127,7 @@ def gemini_send_message_to_model(
messages,
api_key,
model,
api_base_url=None,
response_type="text",
response_schema=None,
temperature=0.6,
@@ -144,6 +152,7 @@ def gemini_send_message_to_model(
system_prompt=system_prompt,
model_name=model,
api_key=api_key,
api_base_url=api_base_url,
temperature=temperature,
model_kwargs=model_kwargs,
tracer=tracer,
@@ -158,6 +167,7 @@ def converse_gemini(
conversation_log={},
model: Optional[str] = "gemini-2.0-flash",
api_key: Optional[str] = None,
api_base_url: Optional[str] = None,
temperature: float = 0.6,
completion_func=None,
conversation_commands=[ConversationCommand.Default],
@@ -249,6 +259,7 @@ def converse_gemini(
model_name=model,
temperature=temperature,
api_key=api_key,
api_base_url=api_base_url,
system_prompt=system_prompt,
completion_func=completion_func,
tracer=tracer,

View File

@@ -23,6 +23,7 @@ from khoj.processor.conversation.utils import (
get_image_from_url,
)
from khoj.utils.helpers import (
get_ai_api_info,
get_chat_usage_metrics,
is_none_or_empty,
is_promptrace_enabled,
@@ -52,6 +53,17 @@ SAFETY_SETTINGS = [
]
def get_gemini_client(api_key, api_base_url=None) -> genai.Client:
api_info = get_ai_api_info(api_key, api_base_url)
return genai.Client(
location=api_info.region,
project=api_info.project,
credentials=api_info.credentials,
api_key=api_info.api_key,
vertexai=api_info.api_key is None,
)
@retry(
wait=wait_random_exponential(min=1, max=10),
stop=stop_after_attempt(2),
@@ -59,9 +71,9 @@ SAFETY_SETTINGS = [
reraise=True,
)
def gemini_completion_with_backoff(
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, tracer={}
messages, system_prompt, model_name, temperature=0, api_key=None, api_base_url=None, model_kwargs=None, tracer={}
) -> str:
client = genai.Client(api_key=api_key)
client = get_gemini_client(api_key, api_base_url)
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
config = gtypes.GenerateContentConfig(
system_instruction=system_prompt,
@@ -115,6 +127,7 @@ def gemini_chat_completion_with_backoff(
model_name,
temperature,
api_key,
api_base_url,
system_prompt,
completion_func=None,
model_kwargs=None,
@@ -123,17 +136,25 @@ def gemini_chat_completion_with_backoff(
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
t = Thread(
target=gemini_llm_thread,
args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs, tracer),
args=(g, messages, system_prompt, model_name, temperature, api_key, api_base_url, model_kwargs, tracer),
)
t.start()
return g
def gemini_llm_thread(
g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None, tracer: dict = {}
g,
messages,
system_prompt,
model_name,
temperature,
api_key,
api_base_url=None,
model_kwargs=None,
tracer: dict = {},
):
try:
client = genai.Client(api_key=api_key)
client = get_gemini_client(api_key, api_base_url)
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
config = gtypes.GenerateContentConfig(
system_instruction=system_prompt,

View File

@@ -481,12 +481,14 @@ async def extract_references_and_questions(
)
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
chat_model_name = chat_model.name
inferred_queries = extract_questions_gemini(
defiltered_query,
query_images=query_images,
model=chat_model_name,
api_key=api_key,
api_base_url=api_base_url,
conversation_log=meta_log,
location_data=location_data,
max_tokens=chat_model.max_prompt_size,

View File

@@ -1245,6 +1245,7 @@ async def send_message_to_model_wrapper(
)
elif model_type == ChatModel.ModelType.GOOGLE:
api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
truncated_messages = generate_chatml_messages_with_context(
user_message=query,
context_message=context,
@@ -1264,6 +1265,7 @@ async def send_message_to_model_wrapper(
model=chat_model_name,
response_type=response_type,
response_schema=response_schema,
api_base_url=api_base_url,
tracer=tracer,
)
else:
@@ -1330,7 +1332,7 @@ def send_message_to_model_wrapper_sync(
query_files=query_files,
)
openai_response = send_message_to_model(
return send_message_to_model(
messages=truncated_messages,
api_key=api_key,
api_base_url=api_base_url,
@@ -1340,8 +1342,6 @@ def send_message_to_model_wrapper_sync(
tracer=tracer,
)
return openai_response
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
@@ -1367,6 +1367,7 @@ def send_message_to_model_wrapper_sync(
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
truncated_messages = generate_chatml_messages_with_context(
user_message=message,
system_message=system_message,
@@ -1381,6 +1382,7 @@ def send_message_to_model_wrapper_sync(
return gemini_send_message_to_model(
messages=truncated_messages,
api_key=api_key,
api_base_url=api_base_url,
model=chat_model_name,
response_type=response_type,
response_schema=response_schema,
@@ -1542,6 +1544,7 @@ def generate_chat_response(
)
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
chat_response = converse_gemini(
compiled_references,
query_to_run,
@@ -1550,6 +1553,7 @@ def generate_chat_response(
meta_log,
model=chat_model.name,
api_key=api_key,
api_base_url=api_base_url,
completion_func=partial_completion,
conversation_commands=conversation_commands,
max_prompt_size=chat_model.max_prompt_size,