mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-06 13:22:12 +00:00
Cache Google AI API client for reuse
This commit is contained in:
@@ -3,6 +3,7 @@ import os
|
|||||||
import random
|
import random
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
from google import genai
|
from google import genai
|
||||||
from google.genai import errors as gerrors
|
from google.genai import errors as gerrors
|
||||||
@@ -31,6 +32,7 @@ from khoj.utils.helpers import (
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
gemini_clients: Dict[str, genai.Client] = {}
|
||||||
|
|
||||||
MAX_OUTPUT_TOKENS_GEMINI = 8192
|
MAX_OUTPUT_TOKENS_GEMINI = 8192
|
||||||
SAFETY_SETTINGS = [
|
SAFETY_SETTINGS = [
|
||||||
@@ -73,7 +75,11 @@ def get_gemini_client(api_key, api_base_url=None) -> genai.Client:
|
|||||||
def gemini_completion_with_backoff(
|
def gemini_completion_with_backoff(
|
||||||
messages, system_prompt, model_name, temperature=0, api_key=None, api_base_url=None, model_kwargs=None, tracer={}
|
messages, system_prompt, model_name, temperature=0, api_key=None, api_base_url=None, model_kwargs=None, tracer={}
|
||||||
) -> str:
|
) -> str:
|
||||||
client = get_gemini_client(api_key, api_base_url)
|
client = gemini_clients.get(api_key)
|
||||||
|
if not client:
|
||||||
|
client = get_gemini_client(api_key, api_base_url)
|
||||||
|
gemini_clients[api_key] = client
|
||||||
|
|
||||||
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
|
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
|
||||||
config = gtypes.GenerateContentConfig(
|
config = gtypes.GenerateContentConfig(
|
||||||
system_instruction=system_prompt,
|
system_instruction=system_prompt,
|
||||||
@@ -154,7 +160,11 @@ def gemini_llm_thread(
|
|||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
client = get_gemini_client(api_key, api_base_url)
|
client = gemini_clients.get(api_key)
|
||||||
|
if not client:
|
||||||
|
client = get_gemini_client(api_key, api_base_url)
|
||||||
|
gemini_clients[api_key] = client
|
||||||
|
|
||||||
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
|
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
|
||||||
config = gtypes.GenerateContentConfig(
|
config = gtypes.GenerateContentConfig(
|
||||||
system_instruction=system_prompt,
|
system_instruction=system_prompt,
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ def completion_with_backoff(
|
|||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
) -> str:
|
) -> str:
|
||||||
client_key = f"{openai_api_key}--{api_base_url}"
|
client_key = f"{openai_api_key}--{api_base_url}"
|
||||||
client: openai.OpenAI | None = openai_clients.get(client_key)
|
client = openai_clients.get(client_key)
|
||||||
if not client:
|
if not client:
|
||||||
client = get_openai_client(openai_api_key, api_base_url)
|
client = get_openai_client(openai_api_key, api_base_url)
|
||||||
openai_clients[client_key] = client
|
openai_clients[client_key] = client
|
||||||
@@ -150,9 +150,8 @@ def llm_thread(
|
|||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
client_key = f"{openai_api_key}--{api_base_url}"
|
client_key = f"{openai_api_key}--{api_base_url}"
|
||||||
if client_key in openai_clients:
|
client = openai_clients.get(client_key)
|
||||||
client = openai_clients[client_key]
|
if not client:
|
||||||
else:
|
|
||||||
client = get_openai_client(openai_api_key, api_base_url)
|
client = get_openai_client(openai_api_key, api_base_url)
|
||||||
openai_clients[client_key] = client
|
openai_clients[client_key] = client
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user