Set seed for Google Gemini models using KHOJ_LLM_SEED env variable

This env var was already being used to set the seed for OpenAI and Offline
models.
This commit is contained in:
Debanjum
2025-03-22 08:59:31 +05:30
parent 6cc5a10b09
commit 5fff05add3

View File

@@ -1,4 +1,5 @@
import logging
import os
import random
from copy import deepcopy
from threading import Thread
@@ -61,6 +62,7 @@ def gemini_completion_with_backoff(
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, tracer={}
) -> str:
client = genai.Client(api_key=api_key)
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
config = gtypes.GenerateContentConfig(
system_instruction=system_prompt,
temperature=temperature,
@@ -68,6 +70,7 @@ def gemini_completion_with_backoff(
safety_settings=SAFETY_SETTINGS,
response_mime_type=model_kwargs.get("response_mime_type", "text/plain") if model_kwargs else "text/plain",
response_schema=model_kwargs.get("response_schema", None) if model_kwargs else None,
seed=seed,
)
formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages]
@@ -131,12 +134,14 @@ def gemini_llm_thread(
):
try:
client = genai.Client(api_key=api_key)
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
config = gtypes.GenerateContentConfig(
system_instruction=system_prompt,
temperature=temperature,
max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
stop_sequences=["Notes:\n["],
safety_settings=SAFETY_SETTINGS,
seed=seed,
)
aggregated_response = ""