mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-04 05:39:06 +00:00
Support access to Gemini models via GCP Vertex AI
This commit is contained in:
@@ -34,6 +34,7 @@ def extract_questions_gemini(
|
||||
model: Optional[str] = "gemini-2.0-flash",
|
||||
conversation_log={},
|
||||
api_key=None,
|
||||
api_base_url=None,
|
||||
temperature=0.6,
|
||||
max_tokens=None,
|
||||
location_data: LocationData = None,
|
||||
@@ -97,7 +98,13 @@ def extract_questions_gemini(
|
||||
messages.append(ChatMessage(content=system_prompt, role="system"))
|
||||
|
||||
response = gemini_send_message_to_model(
|
||||
messages, api_key, model, response_type="json_object", temperature=temperature, tracer=tracer
|
||||
messages,
|
||||
api_key,
|
||||
model,
|
||||
api_base_url=api_base_url,
|
||||
response_type="json_object",
|
||||
temperature=temperature,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
# Extract, Clean Message from Gemini's Response
|
||||
@@ -120,6 +127,7 @@ def gemini_send_message_to_model(
|
||||
messages,
|
||||
api_key,
|
||||
model,
|
||||
api_base_url=None,
|
||||
response_type="text",
|
||||
response_schema=None,
|
||||
temperature=0.6,
|
||||
@@ -144,6 +152,7 @@ def gemini_send_message_to_model(
|
||||
system_prompt=system_prompt,
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
temperature=temperature,
|
||||
model_kwargs=model_kwargs,
|
||||
tracer=tracer,
|
||||
@@ -158,6 +167,7 @@ def converse_gemini(
|
||||
conversation_log={},
|
||||
model: Optional[str] = "gemini-2.0-flash",
|
||||
api_key: Optional[str] = None,
|
||||
api_base_url: Optional[str] = None,
|
||||
temperature: float = 0.6,
|
||||
completion_func=None,
|
||||
conversation_commands=[ConversationCommand.Default],
|
||||
@@ -249,6 +259,7 @@ def converse_gemini(
|
||||
model_name=model,
|
||||
temperature=temperature,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
system_prompt=system_prompt,
|
||||
completion_func=completion_func,
|
||||
tracer=tracer,
|
||||
|
||||
@@ -23,6 +23,7 @@ from khoj.processor.conversation.utils import (
|
||||
get_image_from_url,
|
||||
)
|
||||
from khoj.utils.helpers import (
|
||||
get_ai_api_info,
|
||||
get_chat_usage_metrics,
|
||||
is_none_or_empty,
|
||||
is_promptrace_enabled,
|
||||
@@ -52,6 +53,17 @@ SAFETY_SETTINGS = [
|
||||
]
|
||||
|
||||
|
||||
def get_gemini_client(api_key, api_base_url=None) -> genai.Client:
|
||||
api_info = get_ai_api_info(api_key, api_base_url)
|
||||
return genai.Client(
|
||||
location=api_info.region,
|
||||
project=api_info.project,
|
||||
credentials=api_info.credentials,
|
||||
api_key=api_info.api_key,
|
||||
vertexai=api_info.api_key is None,
|
||||
)
|
||||
|
||||
|
||||
@retry(
|
||||
wait=wait_random_exponential(min=1, max=10),
|
||||
stop=stop_after_attempt(2),
|
||||
@@ -59,9 +71,9 @@ SAFETY_SETTINGS = [
|
||||
reraise=True,
|
||||
)
|
||||
def gemini_completion_with_backoff(
|
||||
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, tracer={}
|
||||
messages, system_prompt, model_name, temperature=0, api_key=None, api_base_url=None, model_kwargs=None, tracer={}
|
||||
) -> str:
|
||||
client = genai.Client(api_key=api_key)
|
||||
client = get_gemini_client(api_key, api_base_url)
|
||||
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
|
||||
config = gtypes.GenerateContentConfig(
|
||||
system_instruction=system_prompt,
|
||||
@@ -115,6 +127,7 @@ def gemini_chat_completion_with_backoff(
|
||||
model_name,
|
||||
temperature,
|
||||
api_key,
|
||||
api_base_url,
|
||||
system_prompt,
|
||||
completion_func=None,
|
||||
model_kwargs=None,
|
||||
@@ -123,17 +136,25 @@ def gemini_chat_completion_with_backoff(
|
||||
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
||||
t = Thread(
|
||||
target=gemini_llm_thread,
|
||||
args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs, tracer),
|
||||
args=(g, messages, system_prompt, model_name, temperature, api_key, api_base_url, model_kwargs, tracer),
|
||||
)
|
||||
t.start()
|
||||
return g
|
||||
|
||||
|
||||
def gemini_llm_thread(
|
||||
g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None, tracer: dict = {}
|
||||
g,
|
||||
messages,
|
||||
system_prompt,
|
||||
model_name,
|
||||
temperature,
|
||||
api_key,
|
||||
api_base_url=None,
|
||||
model_kwargs=None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
try:
|
||||
client = genai.Client(api_key=api_key)
|
||||
client = get_gemini_client(api_key, api_base_url)
|
||||
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
|
||||
config = gtypes.GenerateContentConfig(
|
||||
system_instruction=system_prompt,
|
||||
|
||||
@@ -481,12 +481,14 @@ async def extract_references_and_questions(
|
||||
)
|
||||
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
api_base_url = chat_model.ai_model_api.api_base_url
|
||||
chat_model_name = chat_model.name
|
||||
inferred_queries = extract_questions_gemini(
|
||||
defiltered_query,
|
||||
query_images=query_images,
|
||||
model=chat_model_name,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
conversation_log=meta_log,
|
||||
location_data=location_data,
|
||||
max_tokens=chat_model.max_prompt_size,
|
||||
|
||||
@@ -1245,6 +1245,7 @@ async def send_message_to_model_wrapper(
|
||||
)
|
||||
elif model_type == ChatModel.ModelType.GOOGLE:
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
api_base_url = chat_model.ai_model_api.api_base_url
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=query,
|
||||
context_message=context,
|
||||
@@ -1264,6 +1265,7 @@ async def send_message_to_model_wrapper(
|
||||
model=chat_model_name,
|
||||
response_type=response_type,
|
||||
response_schema=response_schema,
|
||||
api_base_url=api_base_url,
|
||||
tracer=tracer,
|
||||
)
|
||||
else:
|
||||
@@ -1330,7 +1332,7 @@ def send_message_to_model_wrapper_sync(
|
||||
query_files=query_files,
|
||||
)
|
||||
|
||||
openai_response = send_message_to_model(
|
||||
return send_message_to_model(
|
||||
messages=truncated_messages,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
@@ -1340,8 +1342,6 @@ def send_message_to_model_wrapper_sync(
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
return openai_response
|
||||
|
||||
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
api_base_url = chat_model.ai_model_api.api_base_url
|
||||
@@ -1367,6 +1367,7 @@ def send_message_to_model_wrapper_sync(
|
||||
|
||||
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
api_base_url = chat_model.ai_model_api.api_base_url
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=message,
|
||||
system_message=system_message,
|
||||
@@ -1381,6 +1382,7 @@ def send_message_to_model_wrapper_sync(
|
||||
return gemini_send_message_to_model(
|
||||
messages=truncated_messages,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
model=chat_model_name,
|
||||
response_type=response_type,
|
||||
response_schema=response_schema,
|
||||
@@ -1542,6 +1544,7 @@ def generate_chat_response(
|
||||
)
|
||||
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
api_base_url = chat_model.ai_model_api.api_base_url
|
||||
chat_response = converse_gemini(
|
||||
compiled_references,
|
||||
query_to_run,
|
||||
@@ -1550,6 +1553,7 @@ def generate_chat_response(
|
||||
meta_log,
|
||||
model=chat_model.name,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
completion_func=partial_completion,
|
||||
conversation_commands=conversation_commands,
|
||||
max_prompt_size=chat_model.max_prompt_size,
|
||||
|
||||
Reference in New Issue
Block a user