Cache Google AI API client for reuse

This commit is contained in:
Debanjum
2025-03-23 14:33:54 +05:30
parent da33c7d83c
commit 7153d27528
2 changed files with 15 additions and 6 deletions

View File

@@ -3,6 +3,7 @@ import os
import random
from copy import deepcopy
from threading import Thread
from typing import Dict
from google import genai
from google.genai import errors as gerrors
@@ -31,6 +32,7 @@ from khoj.utils.helpers import (
logger = logging.getLogger(__name__)
gemini_clients: Dict[str, genai.Client] = {}
MAX_OUTPUT_TOKENS_GEMINI = 8192
SAFETY_SETTINGS = [
@@ -73,7 +75,11 @@ def get_gemini_client(api_key, api_base_url=None) -> genai.Client:
def gemini_completion_with_backoff(
messages, system_prompt, model_name, temperature=0, api_key=None, api_base_url=None, model_kwargs=None, tracer={}
) -> str:
client = get_gemini_client(api_key, api_base_url)
client = gemini_clients.get(api_key)
if not client:
client = get_gemini_client(api_key, api_base_url)
gemini_clients[api_key] = client
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
config = gtypes.GenerateContentConfig(
system_instruction=system_prompt,
@@ -154,7 +160,11 @@ def gemini_llm_thread(
tracer: dict = {},
):
try:
client = get_gemini_client(api_key, api_base_url)
client = gemini_clients.get(api_key)
if not client:
client = get_gemini_client(api_key, api_base_url)
gemini_clients[api_key] = client
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
config = gtypes.GenerateContentConfig(
system_instruction=system_prompt,

View File

@@ -55,7 +55,7 @@ def completion_with_backoff(
tracer: dict = {},
) -> str:
client_key = f"{openai_api_key}--{api_base_url}"
client: openai.OpenAI | None = openai_clients.get(client_key)
client = openai_clients.get(client_key)
if not client:
client = get_openai_client(openai_api_key, api_base_url)
openai_clients[client_key] = client
@@ -150,9 +150,8 @@ def llm_thread(
):
try:
client_key = f"{openai_api_key}--{api_base_url}"
if client_key in openai_clients:
client = openai_clients[client_key]
else:
client = openai_clients.get(client_key)
if not client:
client = get_openai_client(openai_api_key, api_base_url)
openai_clients[client_key] = client