diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index 7f18b079..f8df542f 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -34,6 +34,7 @@ def extract_questions_gemini( model: Optional[str] = "gemini-2.0-flash", conversation_log={}, api_key=None, + api_base_url=None, temperature=0.6, max_tokens=None, location_data: LocationData = None, @@ -97,7 +98,13 @@ def extract_questions_gemini( messages.append(ChatMessage(content=system_prompt, role="system")) response = gemini_send_message_to_model( - messages, api_key, model, response_type="json_object", temperature=temperature, tracer=tracer + messages, + api_key, + model, + api_base_url=api_base_url, + response_type="json_object", + temperature=temperature, + tracer=tracer, ) # Extract, Clean Message from Gemini's Response @@ -120,6 +127,7 @@ def gemini_send_message_to_model( messages, api_key, model, + api_base_url=None, response_type="text", response_schema=None, temperature=0.6, @@ -144,6 +152,7 @@ def gemini_send_message_to_model( system_prompt=system_prompt, model_name=model, api_key=api_key, + api_base_url=api_base_url, temperature=temperature, model_kwargs=model_kwargs, tracer=tracer, @@ -158,6 +167,7 @@ def converse_gemini( conversation_log={}, model: Optional[str] = "gemini-2.0-flash", api_key: Optional[str] = None, + api_base_url: Optional[str] = None, temperature: float = 0.6, completion_func=None, conversation_commands=[ConversationCommand.Default], @@ -249,6 +259,7 @@ def converse_gemini( model_name=model, temperature=temperature, api_key=api_key, + api_base_url=api_base_url, system_prompt=system_prompt, completion_func=completion_func, tracer=tracer, diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py index c8f8c4ba..3b116488 100644 --- a/src/khoj/processor/conversation/google/utils.py +++ b/src/khoj/processor/conversation/google/utils.py @@ -23,6 +23,7 @@ from khoj.processor.conversation.utils import ( get_image_from_url, ) from khoj.utils.helpers import ( + get_ai_api_info, get_chat_usage_metrics, is_none_or_empty, is_promptrace_enabled, @@ -52,6 +53,17 @@ SAFETY_SETTINGS = [ ] +def get_gemini_client(api_key, api_base_url=None) -> genai.Client: + api_info = get_ai_api_info(api_key, api_base_url) + return genai.Client( + location=api_info.region, + project=api_info.project, + credentials=api_info.credentials, + api_key=api_info.api_key, + vertexai=api_info.api_key is None, + ) + + @retry( wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(2), @@ -59,9 +71,9 @@ SAFETY_SETTINGS = [ reraise=True, ) def gemini_completion_with_backoff( - messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, tracer={} + messages, system_prompt, model_name, temperature=0, api_key=None, api_base_url=None, model_kwargs=None, tracer={} ) -> str: - client = genai.Client(api_key=api_key) + client = get_gemini_client(api_key, api_base_url) seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None config = gtypes.GenerateContentConfig( system_instruction=system_prompt, @@ -115,6 +127,7 @@ def gemini_chat_completion_with_backoff( model_name, temperature, api_key, + api_base_url, system_prompt, completion_func=None, model_kwargs=None, @@ -123,17 +136,25 @@ def gemini_chat_completion_with_backoff( g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func) t = Thread( target=gemini_llm_thread, - args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs, tracer), + args=(g, messages, system_prompt, model_name, temperature, api_key, api_base_url, model_kwargs, tracer), ) t.start() return g def gemini_llm_thread( - g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None, tracer: dict = {} + g, + messages, + system_prompt, + model_name, + temperature, + api_key, + api_base_url=None, + model_kwargs=None, + tracer: dict = {}, ): try: - client = genai.Client(api_key=api_key) + client = get_gemini_client(api_key, api_base_url) seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None config = gtypes.GenerateContentConfig( system_instruction=system_prompt, diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 98b271d9..01fc0a94 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -481,12 +481,14 @@ async def extract_references_and_questions( ) elif chat_model.model_type == ChatModel.ModelType.GOOGLE: api_key = chat_model.ai_model_api.api_key + api_base_url = chat_model.ai_model_api.api_base_url chat_model_name = chat_model.name inferred_queries = extract_questions_gemini( defiltered_query, query_images=query_images, model=chat_model_name, api_key=api_key, + api_base_url=api_base_url, conversation_log=meta_log, location_data=location_data, max_tokens=chat_model.max_prompt_size, diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 6cbcc250..cf7dd582 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1245,6 +1245,7 @@ async def send_message_to_model_wrapper( ) elif model_type == ChatModel.ModelType.GOOGLE: api_key = chat_model.ai_model_api.api_key + api_base_url = chat_model.ai_model_api.api_base_url truncated_messages = generate_chatml_messages_with_context( user_message=query, context_message=context, @@ -1264,6 +1265,7 @@ async def send_message_to_model_wrapper( model=chat_model_name, response_type=response_type, response_schema=response_schema, + api_base_url=api_base_url, tracer=tracer, ) else: @@ -1330,7 +1332,7 @@ def send_message_to_model_wrapper_sync( query_files=query_files, ) - openai_response = send_message_to_model( + return send_message_to_model( messages=truncated_messages, api_key=api_key, api_base_url=api_base_url, @@ -1340,8 +1342,6 @@ def send_message_to_model_wrapper_sync( tracer=tracer, ) - return openai_response - elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC: api_key = chat_model.ai_model_api.api_key api_base_url = chat_model.ai_model_api.api_base_url @@ -1367,6 +1367,7 @@ def send_message_to_model_wrapper_sync( elif chat_model.model_type == ChatModel.ModelType.GOOGLE: api_key = chat_model.ai_model_api.api_key + api_base_url = chat_model.ai_model_api.api_base_url truncated_messages = generate_chatml_messages_with_context( user_message=message, system_message=system_message, @@ -1381,6 +1382,7 @@ def send_message_to_model_wrapper_sync( return gemini_send_message_to_model( messages=truncated_messages, api_key=api_key, + api_base_url=api_base_url, model=chat_model_name, response_type=response_type, response_schema=response_schema, @@ -1542,6 +1544,7 @@ def generate_chat_response( ) elif chat_model.model_type == ChatModel.ModelType.GOOGLE: api_key = chat_model.ai_model_api.api_key + api_base_url = chat_model.ai_model_api.api_base_url chat_response = converse_gemini( compiled_references, query_to_run, @@ -1550,6 +1553,7 @@ def generate_chat_response( meta_log, model=chat_model.name, api_key=api_key, + api_base_url=api_base_url, completion_func=partial_completion, conversation_commands=conversation_commands, max_prompt_size=chat_model.max_prompt_size,