From 452e360175240f00a973e7c577dee49eef8cab1e Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Sun, 6 Oct 2024 14:01:13 -0700
Subject: [PATCH] Do not use max prompt size to limit Gemini max output tokens

We should start disambiguating the max input from the output size. Max
prompt size should only be used for the max input context to an LLM.
If required, max_output_tokens should be set as a separate new field.
---
 .../conversation/google/gemini_chat.py     |  2 --
 .../processor/conversation/google/utils.py | 22 ++++++-------------
 2 files changed, 7 insertions(+), 17 deletions(-)

diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py
index cd55b0ff..84d16dbd 100644
--- a/src/khoj/processor/conversation/google/gemini_chat.py
+++ b/src/khoj/processor/conversation/google/gemini_chat.py
@@ -79,7 +79,6 @@ def extract_questions_gemini(
         model_name=model,
         temperature=temperature,
         api_key=api_key,
-        max_tokens=max_tokens,
         model_kwargs=model_kwargs,
     )
 
@@ -218,5 +217,4 @@ def converse_gemini(
         api_key=api_key,
         system_prompt=system_prompt,
         completion_func=completion_func,
-        max_prompt_size=max_prompt_size,
     )
diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py
index 63b8b610..4af724f6 100644
--- a/src/khoj/processor/conversation/google/utils.py
+++ b/src/khoj/processor/conversation/google/utils.py
@@ -4,10 +4,7 @@ from threading import Thread
 
 import google.generativeai as genai
 from google.generativeai.types.answer_types import FinishReason
-from google.generativeai.types.generation_types import (
-    GenerateContentResponse,
-    StopCandidateException,
-)
+from google.generativeai.types.generation_types import StopCandidateException
 from google.generativeai.types.safety_types import (
     HarmBlockThreshold,
     HarmCategory,
@@ -26,7 +23,7 @@ from khoj.processor.conversation.utils import ThreadedGenerator
 
 logger = logging.getLogger(__name__)
 
-DEFAULT_MAX_TOKENS_GEMINI = 8192
+MAX_OUTPUT_TOKENS_GEMINI = 8192
 
 
 @retry(
@@ -36,13 +33,12 @@ DEFAULT_MAX_TOKENS_GEMINI = 8192
     reraise=True,
 )
 def gemini_completion_with_backoff(
-    messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None
+    messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None
 ) -> str:
     genai.configure(api_key=api_key)
-    max_tokens = max_tokens or DEFAULT_MAX_TOKENS_GEMINI
     model_kwargs = model_kwargs or dict()
     model_kwargs["temperature"] = temperature
-    model_kwargs["max_output_tokens"] = max_tokens
+    model_kwargs["max_output_tokens"] = MAX_OUTPUT_TOKENS_GEMINI
     model = genai.GenerativeModel(
         model_name,
         generation_config=model_kwargs,
@@ -88,28 +84,24 @@ def gemini_chat_completion_with_backoff(
     temperature,
     api_key,
     system_prompt,
-    max_prompt_size=None,
     completion_func=None,
     model_kwargs=None,
 ):
     g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
     t = Thread(
         target=gemini_llm_thread,
-        args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs),
+        args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs),
     )
     t.start()
     return g
 
 
-def gemini_llm_thread(
-    g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None
-):
+def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None):
     try:
         genai.configure(api_key=api_key)
-        max_tokens = max_prompt_size or DEFAULT_MAX_TOKENS_GEMINI
         model_kwargs = model_kwargs or dict()
         model_kwargs["temperature"] = temperature
-        model_kwargs["max_output_tokens"] = max_tokens
+        model_kwargs["max_output_tokens"] = MAX_OUTPUT_TOKENS_GEMINI
         model_kwargs["stop_sequences"] = ["Notes:\n["]
         model = genai.GenerativeModel(
             model_name,
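
Note: the commit message leaves max_output_tokens as a possible follow-up
rather than implementing it here. Below is a minimal sketch, not part of
this patch, of what that separate field could look like, assuming the same
google.generativeai client used above; the max_output_tokens parameter, its
fallback to MAX_OUTPUT_TOKENS_GEMINI, and the direct generate_content call
are illustrative choices, not the committed code.

    import google.generativeai as genai

    MAX_OUTPUT_TOKENS_GEMINI = 8192

    def gemini_completion_with_backoff(
        messages,
        system_prompt,
        model_name,
        temperature=0,
        api_key=None,
        model_kwargs=None,
        max_output_tokens=None,  # hypothetical new field: caps output length only
    ) -> str:
        genai.configure(api_key=api_key)
        model_kwargs = model_kwargs or dict()
        model_kwargs["temperature"] = temperature
        # The output cap is set independently here; max_prompt_size stays an
        # input-side limit, applied upstream when truncating chat history.
        model_kwargs["max_output_tokens"] = max_output_tokens or MAX_OUTPUT_TOKENS_GEMINI
        model = genai.GenerativeModel(
            model_name,
            generation_config=model_kwargs,
            system_instruction=system_prompt,  # needs an SDK version with system_instruction
        )
        return model.generate_content(messages).text

This keeps the two limits decoupled: callers that need a tighter output
budget pass max_output_tokens explicitly, while everyone else inherits the
8K default without max_prompt_size ever leaking into generation config.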