diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py
index cd55b0ff..84d16dbd 100644
--- a/src/khoj/processor/conversation/google/gemini_chat.py
+++ b/src/khoj/processor/conversation/google/gemini_chat.py
@@ -79,7 +79,6 @@ def extract_questions_gemini(
         model_name=model,
         temperature=temperature,
         api_key=api_key,
-        max_tokens=max_tokens,
         model_kwargs=model_kwargs,
     )

@@ -218,5 +217,4 @@ def converse_gemini(
         api_key=api_key,
         system_prompt=system_prompt,
         completion_func=completion_func,
-        max_prompt_size=max_prompt_size,
     )
diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py
index 63b8b610..4af724f6 100644
--- a/src/khoj/processor/conversation/google/utils.py
+++ b/src/khoj/processor/conversation/google/utils.py
@@ -4,10 +4,7 @@ from threading import Thread

 import google.generativeai as genai
 from google.generativeai.types.answer_types import FinishReason
-from google.generativeai.types.generation_types import (
-    GenerateContentResponse,
-    StopCandidateException,
-)
+from google.generativeai.types.generation_types import StopCandidateException
 from google.generativeai.types.safety_types import (
     HarmBlockThreshold,
     HarmCategory,
@@ -26,7 +23,7 @@ from khoj.processor.conversation.utils import ThreadedGenerator

 logger = logging.getLogger(__name__)

-DEFAULT_MAX_TOKENS_GEMINI = 8192
+MAX_OUTPUT_TOKENS_GEMINI = 8192


 @retry(
@@ -36,13 +33,12 @@ DEFAULT_MAX_TOKENS_GEMINI = 8192
     reraise=True,
 )
 def gemini_completion_with_backoff(
-    messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None
+    messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None
 ) -> str:
     genai.configure(api_key=api_key)
-    max_tokens = max_tokens or DEFAULT_MAX_TOKENS_GEMINI
     model_kwargs = model_kwargs or dict()
     model_kwargs["temperature"] = temperature
-    model_kwargs["max_output_tokens"] = max_tokens
+    model_kwargs["max_output_tokens"] = MAX_OUTPUT_TOKENS_GEMINI
     model = genai.GenerativeModel(
         model_name,
         generation_config=model_kwargs,
@@ -88,28 +84,24 @@ def gemini_chat_completion_with_backoff(
     temperature,
     api_key,
     system_prompt,
-    max_prompt_size=None,
     completion_func=None,
     model_kwargs=None,
 ):
     g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
     t = Thread(
         target=gemini_llm_thread,
-        args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs),
+        args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs),
     )
     t.start()
     return g


-def gemini_llm_thread(
-    g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None
-):
+def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None):
     try:
         genai.configure(api_key=api_key)
-        max_tokens = max_prompt_size or DEFAULT_MAX_TOKENS_GEMINI
         model_kwargs = model_kwargs or dict()
         model_kwargs["temperature"] = temperature
-        model_kwargs["max_output_tokens"] = max_tokens
+        model_kwargs["max_output_tokens"] = MAX_OUTPUT_TOKENS_GEMINI
         model_kwargs["stop_sequences"] = ["Notes:\n["]
         model = genai.GenerativeModel(
             model_name,
             generation_config=model_kwargs,
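
For context, here is a minimal sketch of a call site after this change. It assumes `messages` are the `langchain` `ChatMessage` objects khoj's conversation pipeline passes around (an assumption; the prompt text, model id, and environment variable name below are hypothetical placeholders). The point is that callers no longer pass `max_tokens` or `max_prompt_size`: the output cap is pinned inside the helpers to `MAX_OUTPUT_TOKENS_GEMINI` (8192).

```python
import os

from langchain.schema import ChatMessage  # assumed message type

from khoj.processor.conversation.google.utils import gemini_completion_with_backoff

# Hypothetical conversation payload; real callers build this from chat history.
messages = [ChatMessage(content="What is the capital of France?", role="user")]

response = gemini_completion_with_backoff(
    messages=messages,
    system_prompt="You are a helpful personal assistant.",  # hypothetical prompt
    model_name="gemini-1.5-flash",  # hypothetical model id
    temperature=0.2,
    api_key=os.environ["GEMINI_API_KEY"],  # hypothetical env var name
    # No max_tokens kwarg anymore: max_output_tokens is fixed to
    # MAX_OUTPUT_TOKENS_GEMINI (8192) inside the helper.
)
print(response)
```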