Handle truncation when single long non-system chat message

Previously was assuming the system prompt is being always passed as
the first message. So expected there to be at least 2 messages in logs.

This broke chat actors querying with single long non system message.

A more robust way to extract system prompt is via the message role
instead
This commit is contained in:
Debanjum Singh Solanky
2024-03-15 14:52:29 +05:30
parent ec0c35b7ed
commit ecddf98430
2 changed files with 67 additions and 29 deletions

View File

@@ -199,19 +199,26 @@ def truncate_messages(
f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
)
system_message = messages.pop()
assert type(system_message.content) == str
system_message_tokens = len(encoder.encode(system_message.content))
# Extract system message from messages
system_message = None
for idx, message in enumerate(messages):
if message.role == "system":
system_message = messages.pop(idx)
break
system_message_tokens = (
len(encoder.encode(system_message.content)) if system_message and type(system_message.content) == str else 0
)
tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
# Drop older messages until under max supported prompt size by model
while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1:
messages.pop()
assert type(system_message.content) == str
tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
# Truncate current message if still over max supported prompt size by model
if (tokens + system_message_tokens) > max_prompt_size:
assert type(system_message.content) == str
current_message = "\n".join(messages[0].content.split("\n")[:-1]) if type(messages[0].content) == str else ""
original_question = "\n".join(messages[0].content.split("\n")[-1:]) if type(messages[0].content) == str else ""
original_question = f"\n{original_question}"
@@ -223,7 +230,7 @@ def truncate_messages(
)
messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)]
return messages + [system_message]
return messages + [system_message] if system_message else messages
def reciprocal_conversation_to_chatml(message_pair):