Do not use max prompt size to limit Gemini max output tokens
We should start disambiguating the max input size from the max output size. Max prompt size should only be used to cap the input context sent to an LLM. If required, max_output_tokens should be set as a separate, new field.
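To make the distinction concrete, here is a minimal sketch of the intended separation (illustrative only, not code from this commit; trim_messages, its character-based sizing, and generation_config are assumptions):

MAX_OUTPUT_TOKENS_GEMINI = 8192  # output-side cap, matching the constant this commit introduces


def trim_messages(messages: list[str], max_prompt_size: int) -> list[str]:
    # Hypothetical input-side truncation: drop the oldest messages until the
    # prompt fits the model's input context window. Sizes are counted in
    # characters here for simplicity; real code would count tokens.
    while len(messages) > 1 and sum(len(m) for m in messages) > max_prompt_size:
        messages = messages[1:]
    return messages


def generation_config(temperature: float = 0.0) -> dict:
    # Only output-side settings belong in the generation config; the input
    # limit never leaks into max_output_tokens.
    return {"temperature": temperature, "max_output_tokens": MAX_OUTPUT_TOKENS_GEMINI}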
@@ -79,7 +79,6 @@ def extract_questions_gemini(
         model_name=model,
         temperature=temperature,
         api_key=api_key,
-        max_tokens=max_tokens,
         model_kwargs=model_kwargs,
     )
 
@@ -218,5 +217,4 @@ def converse_gemini(
         api_key=api_key,
         system_prompt=system_prompt,
         completion_func=completion_func,
-        max_prompt_size=max_prompt_size,
     )
@@ -4,10 +4,7 @@ from threading import Thread
 
 import google.generativeai as genai
 from google.generativeai.types.answer_types import FinishReason
-from google.generativeai.types.generation_types import (
-    GenerateContentResponse,
-    StopCandidateException,
-)
+from google.generativeai.types.generation_types import StopCandidateException
 from google.generativeai.types.safety_types import (
     HarmBlockThreshold,
     HarmCategory,
@@ -26,7 +23,7 @@ from khoj.processor.conversation.utils import ThreadedGenerator
 logger = logging.getLogger(__name__)
 
 
-DEFAULT_MAX_TOKENS_GEMINI = 8192
+MAX_OUTPUT_TOKENS_GEMINI = 8192
 
 
 @retry(
@@ -36,13 +33,12 @@ DEFAULT_MAX_TOKENS_GEMINI = 8192
     reraise=True,
 )
 def gemini_completion_with_backoff(
-    messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None
+    messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None
 ) -> str:
     genai.configure(api_key=api_key)
-    max_tokens = max_tokens or DEFAULT_MAX_TOKENS_GEMINI
     model_kwargs = model_kwargs or dict()
     model_kwargs["temperature"] = temperature
-    model_kwargs["max_output_tokens"] = max_tokens
+    model_kwargs["max_output_tokens"] = MAX_OUTPUT_TOKENS_GEMINI
     model = genai.GenerativeModel(
         model_name,
         generation_config=model_kwargs,
@@ -88,28 +84,24 @@ def gemini_chat_completion_with_backoff(
     temperature,
     api_key,
     system_prompt,
-    max_prompt_size=None,
     completion_func=None,
     model_kwargs=None,
 ):
     g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
     t = Thread(
         target=gemini_llm_thread,
-        args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs),
+        args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs),
     )
     t.start()
     return g
 
 
-def gemini_llm_thread(
-    g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None
-):
+def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None):
     try:
         genai.configure(api_key=api_key)
-        max_tokens = max_prompt_size or DEFAULT_MAX_TOKENS_GEMINI
         model_kwargs = model_kwargs or dict()
         model_kwargs["temperature"] = temperature
-        model_kwargs["max_output_tokens"] = max_tokens
+        model_kwargs["max_output_tokens"] = MAX_OUTPUT_TOKENS_GEMINI
         model_kwargs["stop_sequences"] = ["Notes:\n["]
         model = genai.GenerativeModel(
             model_name,
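The commit message leaves room to reintroduce an output cap as its own field. A minimal sketch of what that could look like, assuming a hypothetical max_output_tokens parameter (not part of this commit) and the google-generativeai API already used above:

import google.generativeai as genai

MAX_OUTPUT_TOKENS_GEMINI = 8192


def gemini_completion_with_backoff(
    messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_output_tokens=None
) -> str:
    # max_output_tokens is a hypothetical, separate output-side field; it no
    # longer shadows or reuses max_prompt_size.
    genai.configure(api_key=api_key)
    model_kwargs = model_kwargs or dict()
    model_kwargs["temperature"] = temperature
    # A caller-supplied cap wins; otherwise fall back to the module default.
    model_kwargs["max_output_tokens"] = max_output_tokens or MAX_OUTPUT_TOKENS_GEMINI
    model = genai.GenerativeModel(model_name, generation_config=model_kwargs, system_instruction=system_prompt)
    response = model.generate_content(messages)
    return response.text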