mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-08 05:39:13 +00:00
Set seed for Google Gemini models using KHOJ_LLM_SEED env variable
This env var was already being used to set seed for OpenAI and Offline models
This commit is contained in:
@@ -1,4 +1,5 @@
 import logging
+import os
 import random
 from copy import deepcopy
 from threading import Thread
@@ -61,6 +62,7 @@ def gemini_completion_with_backoff(
     messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, tracer={}
 ) -> str:
     client = genai.Client(api_key=api_key)
+    seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
     config = gtypes.GenerateContentConfig(
         system_instruction=system_prompt,
         temperature=temperature,
@@ -68,6 +70,7 @@ def gemini_completion_with_backoff(
         safety_settings=SAFETY_SETTINGS,
         response_mime_type=model_kwargs.get("response_mime_type", "text/plain") if model_kwargs else "text/plain",
         response_schema=model_kwargs.get("response_schema", None) if model_kwargs else None,
+        seed=seed,
     )

     formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages]
@@ -131,12 +134,14 @@ def gemini_llm_thread(
 ):
     try:
         client = genai.Client(api_key=api_key)
+        seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
         config = gtypes.GenerateContentConfig(
             system_instruction=system_prompt,
             temperature=temperature,
             max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
             stop_sequences=["Notes:\n["],
             safety_settings=SAFETY_SETTINGS,
+            seed=seed,
         )

         aggregated_response = ""
Reference in New Issue
Block a user