mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Cache Google AI API client for reuse
This commit is contained in:
@@ -3,6 +3,7 @@ import os
|
||||
import random
|
||||
from copy import deepcopy
|
||||
from threading import Thread
|
||||
from typing import Dict
|
||||
|
||||
from google import genai
|
||||
from google.genai import errors as gerrors
|
||||
@@ -31,6 +32,7 @@ from khoj.utils.helpers import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
gemini_clients: Dict[str, genai.Client] = {}
|
||||
|
||||
MAX_OUTPUT_TOKENS_GEMINI = 8192
|
||||
SAFETY_SETTINGS = [
|
||||
@@ -73,7 +75,11 @@ def get_gemini_client(api_key, api_base_url=None) -> genai.Client:
|
||||
def gemini_completion_with_backoff(
|
||||
messages, system_prompt, model_name, temperature=0, api_key=None, api_base_url=None, model_kwargs=None, tracer={}
|
||||
) -> str:
|
||||
client = get_gemini_client(api_key, api_base_url)
|
||||
client = gemini_clients.get(api_key)
|
||||
if not client:
|
||||
client = get_gemini_client(api_key, api_base_url)
|
||||
gemini_clients[api_key] = client
|
||||
|
||||
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
|
||||
config = gtypes.GenerateContentConfig(
|
||||
system_instruction=system_prompt,
|
||||
@@ -154,7 +160,11 @@ def gemini_llm_thread(
|
||||
tracer: dict = {},
|
||||
):
|
||||
try:
|
||||
client = get_gemini_client(api_key, api_base_url)
|
||||
client = gemini_clients.get(api_key)
|
||||
if not client:
|
||||
client = get_gemini_client(api_key, api_base_url)
|
||||
gemini_clients[api_key] = client
|
||||
|
||||
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
|
||||
config = gtypes.GenerateContentConfig(
|
||||
system_instruction=system_prompt,
|
||||
|
||||
@@ -55,7 +55,7 @@ def completion_with_backoff(
|
||||
tracer: dict = {},
|
||||
) -> str:
|
||||
client_key = f"{openai_api_key}--{api_base_url}"
|
||||
client: openai.OpenAI | None = openai_clients.get(client_key)
|
||||
client = openai_clients.get(client_key)
|
||||
if not client:
|
||||
client = get_openai_client(openai_api_key, api_base_url)
|
||||
openai_clients[client_key] = client
|
||||
@@ -150,9 +150,8 @@ def llm_thread(
|
||||
):
|
||||
try:
|
||||
client_key = f"{openai_api_key}--{api_base_url}"
|
||||
if client_key in openai_clients:
|
||||
client = openai_clients[client_key]
|
||||
else:
|
||||
client = openai_clients.get(client_key)
|
||||
if not client:
|
||||
client = get_openai_client(openai_api_key, api_base_url)
|
||||
openai_clients[client_key] = client
|
||||
|
||||
|
||||
Reference in New Issue
Block a user