Cache Google AI API client for reuse

This commit is contained in:
Debanjum
2025-03-23 14:33:54 +05:30
parent da33c7d83c
commit 7153d27528
2 changed files with 15 additions and 6 deletions

View File

@@ -3,6 +3,7 @@ import os
 import random
 from copy import deepcopy
 from threading import Thread
+from typing import Dict
 from google import genai
 from google.genai import errors as gerrors
@@ -31,6 +32,7 @@ from khoj.utils.helpers import (
 logger = logging.getLogger(__name__)
+gemini_clients: Dict[str, genai.Client] = {}
 MAX_OUTPUT_TOKENS_GEMINI = 8192
 SAFETY_SETTINGS = [
@@ -73,7 +75,11 @@ def get_gemini_client(api_key, api_base_url=None) -> genai.Client:
 def gemini_completion_with_backoff(
     messages, system_prompt, model_name, temperature=0, api_key=None, api_base_url=None, model_kwargs=None, tracer={}
 ) -> str:
-    client = get_gemini_client(api_key, api_base_url)
+    client = gemini_clients.get(api_key)
+    if not client:
+        client = get_gemini_client(api_key, api_base_url)
+        gemini_clients[api_key] = client
     seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
     config = gtypes.GenerateContentConfig(
         system_instruction=system_prompt,
@@ -154,7 +160,11 @@ def gemini_llm_thread(
     tracer: dict = {},
 ):
     try:
-        client = get_gemini_client(api_key, api_base_url)
+        client = gemini_clients.get(api_key)
+        if not client:
+            client = get_gemini_client(api_key, api_base_url)
+            gemini_clients[api_key] = client
         seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
         config = gtypes.GenerateContentConfig(
             system_instruction=system_prompt,

View File

@@ -55,7 +55,7 @@ def completion_with_backoff(
     tracer: dict = {},
 ) -> str:
     client_key = f"{openai_api_key}--{api_base_url}"
-    client: openai.OpenAI | None = openai_clients.get(client_key)
+    client = openai_clients.get(client_key)
     if not client:
         client = get_openai_client(openai_api_key, api_base_url)
         openai_clients[client_key] = client
@@ -150,9 +150,8 @@ def llm_thread(
 ):
     try:
         client_key = f"{openai_api_key}--{api_base_url}"
-        if client_key in openai_clients:
-            client = openai_clients[client_key]
-        else:
+        client = openai_clients.get(client_key)
+        if not client:
             client = get_openai_client(openai_api_key, api_base_url)
             openai_clients[client_key] = client