diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py index 3b116488..b3bdd5a3 100644 --- a/src/khoj/processor/conversation/google/utils.py +++ b/src/khoj/processor/conversation/google/utils.py @@ -3,6 +3,7 @@ import os import random from copy import deepcopy from threading import Thread +from typing import Dict from google import genai from google.genai import errors as gerrors @@ -31,6 +32,7 @@ from khoj.utils.helpers import ( logger = logging.getLogger(__name__) +gemini_clients: Dict[str, genai.Client] = {} MAX_OUTPUT_TOKENS_GEMINI = 8192 SAFETY_SETTINGS = [ @@ -73,7 +75,11 @@ def get_gemini_client(api_key, api_base_url=None) -> genai.Client: def gemini_completion_with_backoff( messages, system_prompt, model_name, temperature=0, api_key=None, api_base_url=None, model_kwargs=None, tracer={} ) -> str: - client = get_gemini_client(api_key, api_base_url) + client = gemini_clients.get(api_key) + if not client: + client = get_gemini_client(api_key, api_base_url) + gemini_clients[api_key] = client + seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None config = gtypes.GenerateContentConfig( system_instruction=system_prompt, @@ -154,7 +160,11 @@ def gemini_llm_thread( tracer: dict = {}, ): try: - client = get_gemini_client(api_key, api_base_url) + client = gemini_clients.get(api_key) + if not client: + client = get_gemini_client(api_key, api_base_url) + gemini_clients[api_key] = client + seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None config = gtypes.GenerateContentConfig( system_instruction=system_prompt, diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index 5ca66d68..c664d882 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -55,7 +55,7 @@ def completion_with_backoff( tracer: dict = {}, ) -> str: client_key = f"{openai_api_key}--{api_base_url}" - client: openai.OpenAI | None = openai_clients.get(client_key) + client = openai_clients.get(client_key) if not client: client = get_openai_client(openai_api_key, api_base_url) openai_clients[client_key] = client @@ -150,9 +150,8 @@ def llm_thread( ): try: client_key = f"{openai_api_key}--{api_base_url}" - if client_key in openai_clients: - client = openai_clients[client_key] - else: + client = openai_clients.get(client_key) + if not client: client = get_openai_client(openai_api_key, api_base_url) openai_clients[client_key] = client