Move message truncation logic into a separate function. Add unit tests with factory boy.
@@ -97,23 +97,33 @@ def generate_chatml_messages_with_context(
     messages = user_chatml_message + rest_backnforths + system_chatml_message
 
     # Truncate oldest messages from conversation history until under max supported prompt size by model
+    messages = truncate_message(messages, max_prompt_size[model_name], model_name)
+
+    # Return message in chronological order
+    return messages[::-1]
+
+
+def truncate_message(messages, max_prompt_size, model_name):
+    """Truncate messages to fit within max prompt size supported by model"""
     encoder = tiktoken.encoding_for_model(model_name)
     tokens = sum([len(encoder.encode(message.content)) for message in messages])
-    while tokens > max_prompt_size[model_name] and len(messages) > 1:
-        logger.info(f"num tokens: {tokens}")
+    while tokens > max_prompt_size and len(messages) > 1:
         messages.pop()
         tokens = sum([len(encoder.encode(message.content)) for message in messages])
 
     # Truncate last message if still over max supported prompt size by model
-    if tokens > max_prompt_size[model_name]:
-        last_message = messages[-1]
-        truncated_message = encoder.decode(encoder.encode(last_message.content)[:max_prompt_size[model_name]]).strip()
+    if tokens > max_prompt_size:
+        last_message = '\n'.join(messages[-1].content.split("\n")[:-1])
+        original_question = '\n'.join(messages[-1].content.split("\n")[-1:])
+        original_question_tokens = len(encoder.encode(original_question))
+        remaining_tokens = max_prompt_size - original_question_tokens
+        truncated_message = encoder.decode(encoder.encode(last_message)[:remaining_tokens]).strip()
         logger.debug(
-            f"Truncate last message to fit within max prompt size of {max_prompt_size[model_name]} supported by {model_name} model:\n {truncated_message}"
+            f"Truncate last message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
         )
-        messages = [ChatMessage(content=truncated_message, role=last_message.role)]
+        messages = [ChatMessage(content=truncated_message + original_question, role=messages[-1].role)]
 
-    # Return message in chronological order
-    return messages[::-1]
+    return messages
 
 
 def reciprocal_conversation_to_chatml(message_pair):
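
The unit tests with factory_boy that the commit message mentions are not part of this hunk. Purely as a hypothetical illustration (the factory, test name, and token budget below are mine, not the commit's), a factory_boy setup for exercising truncate_message could look like:

# Hypothetical sketch only: the factory_boy tests this commit adds are
# not shown in this hunk, and these names are illustrative, not khoj's.
import factory
from langchain.schema import ChatMessage

from khoj.processor.conversation.utils import truncate_message  # path assumed


class ChatMessageFactory(factory.Factory):
    class Meta:
        model = ChatMessage

    # Long generated bodies so the token budget is comfortably exceeded
    content = factory.Faker("paragraph", nb_sentences=50)
    role = "user"


def test_truncate_message_drops_oldest_messages():
    messages = ChatMessageFactory.build_batch(10)
    truncated = truncate_message(messages, max_prompt_size=500, model_name="gpt-3.5-turbo")
    # Ten long messages cannot fit in 500 tokens, so some must be dropped
    assert len(truncated) < len(messages)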