Handle msg truncation when question is larger than max prompt size

Notice and truncate the question itself at this point
This commit is contained in:
Debanjum Singh Solanky
2024-03-31 15:37:29 +05:30
parent c6487f2e48
commit 4228965c9b
4 changed files with 29 additions and 8 deletions

View File

@@ -21,6 +21,7 @@ def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf"):
# Check if the model is already downloaded
model_path = load_model_from_cache(repo_id, filename)
chat_model = None
try:
if model_path:
chat_model = Llama(model_path, **kwargs)

View File

@@ -101,8 +101,3 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None, model_
chat(messages=messages)
g.close()
def extract_summaries(metadata):
    """Return all session summaries from *metadata* joined together, each prefixed by a newline."""
    return "".join("\n" + session["summary"] for session in metadata)

View File

@@ -232,12 +232,17 @@ def truncate_messages(
original_question = "\n".join(messages[0].content.split("\n")[-1:]) if type(messages[0].content) == str else ""
original_question = f"\n{original_question}"
original_question_tokens = len(encoder.encode(original_question))
remaining_tokens = max_prompt_size - original_question_tokens - system_message_tokens
truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip()
remaining_tokens = max_prompt_size - system_message_tokens
if remaining_tokens > original_question_tokens:
remaining_tokens -= original_question_tokens
truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip()
messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)]
else:
truncated_message = encoder.decode(encoder.encode(original_question)[:remaining_tokens]).strip()
messages = [ChatMessage(content=truncated_message, role=messages[0].role)]
logger.debug(
f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
)
messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)]
return messages + [system_message] if system_message else messages