mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-07 13:23:15 +00:00
Handle truncation when single long non-system chat message
Previously was assuming the system prompt is being always passed as the first message. So expected there to be at least 2 messages in logs. This broke chat actors querying with single long non system message. A more robust way to extract system prompt is via the message role instead
This commit is contained in:
@@ -199,19 +199,26 @@ def truncate_messages(
|
||||
f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
|
||||
)
|
||||
|
||||
system_message = messages.pop()
|
||||
assert type(system_message.content) == str
|
||||
system_message_tokens = len(encoder.encode(system_message.content))
|
||||
# Extract system message from messages
|
||||
system_message = None
|
||||
for idx, message in enumerate(messages):
|
||||
if message.role == "system":
|
||||
system_message = messages.pop(idx)
|
||||
break
|
||||
|
||||
system_message_tokens = (
|
||||
len(encoder.encode(system_message.content)) if system_message and type(system_message.content) == str else 0
|
||||
)
|
||||
|
||||
tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
|
||||
|
||||
# Drop older messages until under max supported prompt size by model
|
||||
while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1:
|
||||
messages.pop()
|
||||
assert type(system_message.content) == str
|
||||
tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
|
||||
|
||||
# Truncate current message if still over max supported prompt size by model
|
||||
if (tokens + system_message_tokens) > max_prompt_size:
|
||||
assert type(system_message.content) == str
|
||||
current_message = "\n".join(messages[0].content.split("\n")[:-1]) if type(messages[0].content) == str else ""
|
||||
original_question = "\n".join(messages[0].content.split("\n")[-1:]) if type(messages[0].content) == str else ""
|
||||
original_question = f"\n{original_question}"
|
||||
@@ -223,7 +230,7 @@ def truncate_messages(
|
||||
)
|
||||
messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)]
|
||||
|
||||
return messages + [system_message]
|
||||
return messages + [system_message] if system_message else messages
|
||||
|
||||
|
||||
def reciprocal_conversation_to_chatml(message_pair):
|
||||
|
||||
Reference in New Issue
Block a user