mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-07 13:23:15 +00:00
Support access to Gemini models via GCP Vertex AI
This commit is contained in:
@@ -34,6 +34,7 @@ def extract_questions_gemini(
|
|||||||
model: Optional[str] = "gemini-2.0-flash",
|
model: Optional[str] = "gemini-2.0-flash",
|
||||||
conversation_log={},
|
conversation_log={},
|
||||||
api_key=None,
|
api_key=None,
|
||||||
|
api_base_url=None,
|
||||||
temperature=0.6,
|
temperature=0.6,
|
||||||
max_tokens=None,
|
max_tokens=None,
|
||||||
location_data: LocationData = None,
|
location_data: LocationData = None,
|
||||||
@@ -97,7 +98,13 @@ def extract_questions_gemini(
|
|||||||
messages.append(ChatMessage(content=system_prompt, role="system"))
|
messages.append(ChatMessage(content=system_prompt, role="system"))
|
||||||
|
|
||||||
response = gemini_send_message_to_model(
|
response = gemini_send_message_to_model(
|
||||||
messages, api_key, model, response_type="json_object", temperature=temperature, tracer=tracer
|
messages,
|
||||||
|
api_key,
|
||||||
|
model,
|
||||||
|
api_base_url=api_base_url,
|
||||||
|
response_type="json_object",
|
||||||
|
temperature=temperature,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract, Clean Message from Gemini's Response
|
# Extract, Clean Message from Gemini's Response
|
||||||
@@ -120,6 +127,7 @@ def gemini_send_message_to_model(
|
|||||||
messages,
|
messages,
|
||||||
api_key,
|
api_key,
|
||||||
model,
|
model,
|
||||||
|
api_base_url=None,
|
||||||
response_type="text",
|
response_type="text",
|
||||||
response_schema=None,
|
response_schema=None,
|
||||||
temperature=0.6,
|
temperature=0.6,
|
||||||
@@ -144,6 +152,7 @@ def gemini_send_message_to_model(
|
|||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
model_name=model,
|
model_name=model,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
api_base_url=api_base_url,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
model_kwargs=model_kwargs,
|
model_kwargs=model_kwargs,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
@@ -158,6 +167,7 @@ def converse_gemini(
|
|||||||
conversation_log={},
|
conversation_log={},
|
||||||
model: Optional[str] = "gemini-2.0-flash",
|
model: Optional[str] = "gemini-2.0-flash",
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
|
api_base_url: Optional[str] = None,
|
||||||
temperature: float = 0.6,
|
temperature: float = 0.6,
|
||||||
completion_func=None,
|
completion_func=None,
|
||||||
conversation_commands=[ConversationCommand.Default],
|
conversation_commands=[ConversationCommand.Default],
|
||||||
@@ -249,6 +259,7 @@ def converse_gemini(
|
|||||||
model_name=model,
|
model_name=model,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
api_base_url=api_base_url,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
completion_func=completion_func,
|
completion_func=completion_func,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from khoj.processor.conversation.utils import (
|
|||||||
get_image_from_url,
|
get_image_from_url,
|
||||||
)
|
)
|
||||||
from khoj.utils.helpers import (
|
from khoj.utils.helpers import (
|
||||||
|
get_ai_api_info,
|
||||||
get_chat_usage_metrics,
|
get_chat_usage_metrics,
|
||||||
is_none_or_empty,
|
is_none_or_empty,
|
||||||
is_promptrace_enabled,
|
is_promptrace_enabled,
|
||||||
@@ -52,6 +53,17 @@ SAFETY_SETTINGS = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_gemini_client(api_key, api_base_url=None) -> genai.Client:
|
||||||
|
api_info = get_ai_api_info(api_key, api_base_url)
|
||||||
|
return genai.Client(
|
||||||
|
location=api_info.region,
|
||||||
|
project=api_info.project,
|
||||||
|
credentials=api_info.credentials,
|
||||||
|
api_key=api_info.api_key,
|
||||||
|
vertexai=api_info.api_key is None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
wait=wait_random_exponential(min=1, max=10),
|
wait=wait_random_exponential(min=1, max=10),
|
||||||
stop=stop_after_attempt(2),
|
stop=stop_after_attempt(2),
|
||||||
@@ -59,9 +71,9 @@ SAFETY_SETTINGS = [
|
|||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
def gemini_completion_with_backoff(
|
def gemini_completion_with_backoff(
|
||||||
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, tracer={}
|
messages, system_prompt, model_name, temperature=0, api_key=None, api_base_url=None, model_kwargs=None, tracer={}
|
||||||
) -> str:
|
) -> str:
|
||||||
client = genai.Client(api_key=api_key)
|
client = get_gemini_client(api_key, api_base_url)
|
||||||
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
|
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
|
||||||
config = gtypes.GenerateContentConfig(
|
config = gtypes.GenerateContentConfig(
|
||||||
system_instruction=system_prompt,
|
system_instruction=system_prompt,
|
||||||
@@ -115,6 +127,7 @@ def gemini_chat_completion_with_backoff(
|
|||||||
model_name,
|
model_name,
|
||||||
temperature,
|
temperature,
|
||||||
api_key,
|
api_key,
|
||||||
|
api_base_url,
|
||||||
system_prompt,
|
system_prompt,
|
||||||
completion_func=None,
|
completion_func=None,
|
||||||
model_kwargs=None,
|
model_kwargs=None,
|
||||||
@@ -123,17 +136,25 @@ def gemini_chat_completion_with_backoff(
|
|||||||
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
||||||
t = Thread(
|
t = Thread(
|
||||||
target=gemini_llm_thread,
|
target=gemini_llm_thread,
|
||||||
args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs, tracer),
|
args=(g, messages, system_prompt, model_name, temperature, api_key, api_base_url, model_kwargs, tracer),
|
||||||
)
|
)
|
||||||
t.start()
|
t.start()
|
||||||
return g
|
return g
|
||||||
|
|
||||||
|
|
||||||
def gemini_llm_thread(
|
def gemini_llm_thread(
|
||||||
g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None, tracer: dict = {}
|
g,
|
||||||
|
messages,
|
||||||
|
system_prompt,
|
||||||
|
model_name,
|
||||||
|
temperature,
|
||||||
|
api_key,
|
||||||
|
api_base_url=None,
|
||||||
|
model_kwargs=None,
|
||||||
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
client = genai.Client(api_key=api_key)
|
client = get_gemini_client(api_key, api_base_url)
|
||||||
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
|
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
|
||||||
config = gtypes.GenerateContentConfig(
|
config = gtypes.GenerateContentConfig(
|
||||||
system_instruction=system_prompt,
|
system_instruction=system_prompt,
|
||||||
|
|||||||
@@ -481,12 +481,14 @@ async def extract_references_and_questions(
|
|||||||
)
|
)
|
||||||
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
|
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
|
||||||
api_key = chat_model.ai_model_api.api_key
|
api_key = chat_model.ai_model_api.api_key
|
||||||
|
api_base_url = chat_model.ai_model_api.api_base_url
|
||||||
chat_model_name = chat_model.name
|
chat_model_name = chat_model.name
|
||||||
inferred_queries = extract_questions_gemini(
|
inferred_queries = extract_questions_gemini(
|
||||||
defiltered_query,
|
defiltered_query,
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
model=chat_model_name,
|
model=chat_model_name,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
api_base_url=api_base_url,
|
||||||
conversation_log=meta_log,
|
conversation_log=meta_log,
|
||||||
location_data=location_data,
|
location_data=location_data,
|
||||||
max_tokens=chat_model.max_prompt_size,
|
max_tokens=chat_model.max_prompt_size,
|
||||||
|
|||||||
@@ -1245,6 +1245,7 @@ async def send_message_to_model_wrapper(
|
|||||||
)
|
)
|
||||||
elif model_type == ChatModel.ModelType.GOOGLE:
|
elif model_type == ChatModel.ModelType.GOOGLE:
|
||||||
api_key = chat_model.ai_model_api.api_key
|
api_key = chat_model.ai_model_api.api_key
|
||||||
|
api_base_url = chat_model.ai_model_api.api_base_url
|
||||||
truncated_messages = generate_chatml_messages_with_context(
|
truncated_messages = generate_chatml_messages_with_context(
|
||||||
user_message=query,
|
user_message=query,
|
||||||
context_message=context,
|
context_message=context,
|
||||||
@@ -1264,6 +1265,7 @@ async def send_message_to_model_wrapper(
|
|||||||
model=chat_model_name,
|
model=chat_model_name,
|
||||||
response_type=response_type,
|
response_type=response_type,
|
||||||
response_schema=response_schema,
|
response_schema=response_schema,
|
||||||
|
api_base_url=api_base_url,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -1330,7 +1332,7 @@ def send_message_to_model_wrapper_sync(
|
|||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
)
|
)
|
||||||
|
|
||||||
openai_response = send_message_to_model(
|
return send_message_to_model(
|
||||||
messages=truncated_messages,
|
messages=truncated_messages,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
api_base_url=api_base_url,
|
api_base_url=api_base_url,
|
||||||
@@ -1340,8 +1342,6 @@ def send_message_to_model_wrapper_sync(
|
|||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
return openai_response
|
|
||||||
|
|
||||||
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
|
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
|
||||||
api_key = chat_model.ai_model_api.api_key
|
api_key = chat_model.ai_model_api.api_key
|
||||||
api_base_url = chat_model.ai_model_api.api_base_url
|
api_base_url = chat_model.ai_model_api.api_base_url
|
||||||
@@ -1367,6 +1367,7 @@ def send_message_to_model_wrapper_sync(
|
|||||||
|
|
||||||
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
|
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
|
||||||
api_key = chat_model.ai_model_api.api_key
|
api_key = chat_model.ai_model_api.api_key
|
||||||
|
api_base_url = chat_model.ai_model_api.api_base_url
|
||||||
truncated_messages = generate_chatml_messages_with_context(
|
truncated_messages = generate_chatml_messages_with_context(
|
||||||
user_message=message,
|
user_message=message,
|
||||||
system_message=system_message,
|
system_message=system_message,
|
||||||
@@ -1381,6 +1382,7 @@ def send_message_to_model_wrapper_sync(
|
|||||||
return gemini_send_message_to_model(
|
return gemini_send_message_to_model(
|
||||||
messages=truncated_messages,
|
messages=truncated_messages,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
api_base_url=api_base_url,
|
||||||
model=chat_model_name,
|
model=chat_model_name,
|
||||||
response_type=response_type,
|
response_type=response_type,
|
||||||
response_schema=response_schema,
|
response_schema=response_schema,
|
||||||
@@ -1542,6 +1544,7 @@ def generate_chat_response(
|
|||||||
)
|
)
|
||||||
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
|
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
|
||||||
api_key = chat_model.ai_model_api.api_key
|
api_key = chat_model.ai_model_api.api_key
|
||||||
|
api_base_url = chat_model.ai_model_api.api_base_url
|
||||||
chat_response = converse_gemini(
|
chat_response = converse_gemini(
|
||||||
compiled_references,
|
compiled_references,
|
||||||
query_to_run,
|
query_to_run,
|
||||||
@@ -1550,6 +1553,7 @@ def generate_chat_response(
|
|||||||
meta_log,
|
meta_log,
|
||||||
model=chat_model.name,
|
model=chat_model.name,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
api_base_url=api_base_url,
|
||||||
completion_func=partial_completion,
|
completion_func=partial_completion,
|
||||||
conversation_commands=conversation_commands,
|
conversation_commands=conversation_commands,
|
||||||
max_prompt_size=chat_model.max_prompt_size,
|
max_prompt_size=chat_model.max_prompt_size,
|
||||||
|
|||||||
Reference in New Issue
Block a user