Do not use max prompt size to limit Gemini max output tokens

We should start disambiguating the max input size from the max output
size. Max prompt size should only be used to cap the input context sent
to an LLM.

If required, max_output_tokens should be set as a separate new field.
This commit is contained in:
Debanjum Singh Solanky
2024-10-06 14:01:13 -07:00
parent bdc36fec5d
commit 452e360175
2 changed files with 7 additions and 17 deletions
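
As context for the change below, here is a minimal sketch, not taken from this commit, of the two limits the commit message distinguishes: max_prompt_size bounds how much input context is sent to the LLM, while the output cap is a separate setting (a fixed constant in this diff). The helper name below is an assumption.

    # Illustration only: keep the input and output limits apart.
    MAX_OUTPUT_TOKENS_GEMINI = 8192  # output cap, mirroring the constant in the diff

    def build_generation_kwargs(temperature: float, max_output_tokens: int = MAX_OUTPUT_TOKENS_GEMINI) -> dict:
        # Output-side settings handed to the Gemini client. Input-side trimming
        # driven by max_prompt_size happens before the prompt is built, not here.
        return {"temperature": temperature, "max_output_tokens": max_output_tokens}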

View File

@@ -79,7 +79,6 @@ def extract_questions_gemini(
         model_name=model,
         temperature=temperature,
         api_key=api_key,
-        max_tokens=max_tokens,
         model_kwargs=model_kwargs,
     )
@@ -218,5 +217,4 @@ def converse_gemini(
         api_key=api_key,
         system_prompt=system_prompt,
         completion_func=completion_func,
-        max_prompt_size=max_prompt_size,
     )
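
For contrast with the removals above, a rough sketch of what max_prompt_size is still meant to govern: trimming the input context before the request is sent. The helper and the crude token estimate are assumptions, not code from this repository.

    def truncate_to_prompt_size(messages: list[str], max_prompt_size: int) -> list[str]:
        # Keep the most recent messages that fit within the input token budget.
        kept: list[str] = []
        used = 0
        for message in reversed(messages):
            tokens = len(message.split())  # rough stand-in for a real tokenizer
            if used + tokens > max_prompt_size:
                break
            kept.append(message)
            used += tokens
        return list(reversed(kept))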

View File

@@ -4,10 +4,7 @@ from threading import Thread
 import google.generativeai as genai
 from google.generativeai.types.answer_types import FinishReason
-from google.generativeai.types.generation_types import (
-    GenerateContentResponse,
-    StopCandidateException,
-)
+from google.generativeai.types.generation_types import StopCandidateException
 from google.generativeai.types.safety_types import (
     HarmBlockThreshold,
     HarmCategory,
@@ -26,7 +23,7 @@ from khoj.processor.conversation.utils import ThreadedGenerator
 logger = logging.getLogger(__name__)
-DEFAULT_MAX_TOKENS_GEMINI = 8192
+MAX_OUTPUT_TOKENS_GEMINI = 8192
 @retry(
@@ -36,13 +33,12 @@ DEFAULT_MAX_TOKENS_GEMINI = 8192
     reraise=True,
 )
 def gemini_completion_with_backoff(
-    messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None
+    messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None
 ) -> str:
     genai.configure(api_key=api_key)
-    max_tokens = max_tokens or DEFAULT_MAX_TOKENS_GEMINI
     model_kwargs = model_kwargs or dict()
     model_kwargs["temperature"] = temperature
-    model_kwargs["max_output_tokens"] = max_tokens
+    model_kwargs["max_output_tokens"] = MAX_OUTPUT_TOKENS_GEMINI
     model = genai.GenerativeModel(
         model_name,
         generation_config=model_kwargs,
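
As a standalone reference for the hunk above, this is roughly how the kwargs assembled in gemini_completion_with_backoff reach the client library. The model name and API key below are placeholders; the calls follow the google.generativeai usage visible in the diff.

    import google.generativeai as genai

    genai.configure(api_key="YOUR_API_KEY")  # placeholder key
    generation_config = {
        "temperature": 0.2,
        "max_output_tokens": 8192,  # fixed output cap, as MAX_OUTPUT_TOKENS_GEMINI above
    }
    model = genai.GenerativeModel("gemini-1.5-flash", generation_config=generation_config)
    response = model.generate_content("Summarize this change in one sentence.")
    print(response.text)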
@@ -88,28 +84,24 @@ def gemini_chat_completion_with_backoff(
     temperature,
     api_key,
     system_prompt,
-    max_prompt_size=None,
     completion_func=None,
     model_kwargs=None,
 ):
     g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
     t = Thread(
         target=gemini_llm_thread,
-        args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs),
+        args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs),
     )
     t.start()
     return g


-def gemini_llm_thread(
-    g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None
-):
+def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None):
     try:
         genai.configure(api_key=api_key)
-        max_tokens = max_prompt_size or DEFAULT_MAX_TOKENS_GEMINI
         model_kwargs = model_kwargs or dict()
         model_kwargs["temperature"] = temperature
-        model_kwargs["max_output_tokens"] = max_tokens
+        model_kwargs["max_output_tokens"] = MAX_OUTPUT_TOKENS_GEMINI
         model_kwargs["stop_sequences"] = ["Notes:\n["]
         model = genai.GenerativeModel(
             model_name,
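
If an explicit output cap is ever needed again, the commit message suggests adding it as its own field rather than reusing max_prompt_size. A hypothetical shape of that follow-up, sketching only the parameter handling (the helper name and the new parameter are assumptions):

    MAX_OUTPUT_TOKENS_GEMINI = 8192  # mirrors the constant from the diff

    def build_model_kwargs(temperature=0, model_kwargs=None, max_output_tokens=None):
        # max_output_tokens is the hypothetical separate field: an explicit value
        # wins, otherwise fall back to the module-level output cap.
        model_kwargs = model_kwargs or dict()
        model_kwargs["temperature"] = temperature
        model_kwargs["max_output_tokens"] = max_output_tokens or MAX_OUTPUT_TOKENS_GEMINI
        return model_kwargs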