mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-07 21:29:13 +00:00
Merge pull request #224 from debanjum/fix/message-exceeds-prompt-size
Pass truncated message as string in ChatMessage when exceeding max prompt size
This commit is contained in:
@@ -78,6 +78,8 @@ dev = [
|
|||||||
"black >= 23.1.0",
|
"black >= 23.1.0",
|
||||||
"pre-commit >= 3.0.4",
|
"pre-commit >= 3.0.4",
|
||||||
"freezegun >= 1.2.0",
|
"freezegun >= 1.2.0",
|
||||||
|
"factory-boy==3.2.1",
|
||||||
|
"Faker==18.10.1",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.hatch.version]
|
[tool.hatch.version]
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ def completion_with_backoff(**kwargs):
|
|||||||
kwargs["openai_api_key"] = kwargs.get("api_key")
|
kwargs["openai_api_key"] = kwargs.get("api_key")
|
||||||
else:
|
else:
|
||||||
kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY")
|
kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY")
|
||||||
llm = OpenAI(**kwargs, request_timeout=10, max_retries=1)
|
llm = OpenAI(**kwargs, request_timeout=20, max_retries=1)
|
||||||
return llm(prompt)
|
return llm(prompt)
|
||||||
|
|
||||||
|
|
||||||
@@ -97,25 +97,35 @@ def generate_chatml_messages_with_context(
|
|||||||
messages = user_chatml_message + rest_backnforths + system_chatml_message
|
messages = user_chatml_message + rest_backnforths + system_chatml_message
|
||||||
|
|
||||||
# Truncate oldest messages from conversation history until under max supported prompt size by model
|
# Truncate oldest messages from conversation history until under max supported prompt size by model
|
||||||
encoder = tiktoken.encoding_for_model(model_name)
|
messages = truncate_messages(messages, max_prompt_size[model_name], model_name)
|
||||||
tokens = sum([len(encoder.encode(content)) for message in messages for content in message.content])
|
|
||||||
while tokens > max_prompt_size[model_name] and len(messages) > 1:
|
|
||||||
messages.pop()
|
|
||||||
tokens = sum([len(encoder.encode(content)) for message in messages for content in message.content])
|
|
||||||
|
|
||||||
# Truncate last message if still over max supported prompt size by model
|
|
||||||
if tokens > max_prompt_size[model_name]:
|
|
||||||
last_message = messages[-1]
|
|
||||||
truncated_message = encoder.decode(encoder.encode(last_message.content))
|
|
||||||
logger.debug(
|
|
||||||
f"Truncate last message to fit within max prompt size of {max_prompt_size[model_name]} supported by {model_name} model:\n {truncated_message}"
|
|
||||||
)
|
|
||||||
messages = [ChatMessage(content=[truncated_message], role=last_message.role)]
|
|
||||||
|
|
||||||
# Return message in chronological order
|
# Return message in chronological order
|
||||||
return messages[::-1]
|
return messages[::-1]
|
||||||
|
|
||||||
|
|
||||||
|
def truncate_messages(messages, max_prompt_size, model_name):
|
||||||
|
"""Truncate messages to fit within max prompt size supported by model"""
|
||||||
|
encoder = tiktoken.encoding_for_model(model_name)
|
||||||
|
tokens = sum([len(encoder.encode(message.content)) for message in messages])
|
||||||
|
while tokens > max_prompt_size and len(messages) > 1:
|
||||||
|
messages.pop()
|
||||||
|
tokens = sum([len(encoder.encode(message.content)) for message in messages])
|
||||||
|
|
||||||
|
# Truncate last message if still over max supported prompt size by model
|
||||||
|
if tokens > max_prompt_size:
|
||||||
|
last_message = "\n".join(messages[-1].content.split("\n")[:-1])
|
||||||
|
original_question = "\n".join(messages[-1].content.split("\n")[-1:])
|
||||||
|
original_question_tokens = len(encoder.encode(original_question))
|
||||||
|
remaining_tokens = max_prompt_size - original_question_tokens
|
||||||
|
truncated_message = encoder.decode(encoder.encode(last_message)[:remaining_tokens]).strip()
|
||||||
|
logger.debug(
|
||||||
|
f"Truncate last message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
|
||||||
|
)
|
||||||
|
messages = [ChatMessage(content=truncated_message + original_question, role=messages[-1].role)]
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
|
||||||
def reciprocal_conversation_to_chatml(message_pair):
|
def reciprocal_conversation_to_chatml(message_pair):
|
||||||
"""Convert a single back and forth between user and assistant to chatml format"""
|
"""Convert a single back and forth between user and assistant to chatml format"""
|
||||||
return [ChatMessage(content=message, role=role) for message, role in zip(message_pair, ["user", "assistant"])]
|
return [ChatMessage(content=message, role=role) for message, role in zip(message_pair, ["user", "assistant"])]
|
||||||
|
|||||||
65
tests/test_conversation_utils.py
Normal file
65
tests/test_conversation_utils.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
from khoj.processor.conversation import utils
|
||||||
|
from langchain.schema import ChatMessage
|
||||||
|
import factory
|
||||||
|
import tiktoken
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMessageFactory(factory.Factory):
|
||||||
|
class Meta:
|
||||||
|
model = ChatMessage
|
||||||
|
|
||||||
|
content = factory.Faker("paragraph")
|
||||||
|
role = factory.Faker("name")
|
||||||
|
|
||||||
|
|
||||||
|
class TestTruncateMessage:
|
||||||
|
max_prompt_size = 4096
|
||||||
|
model_name = "gpt-3.5-turbo"
|
||||||
|
encoder = tiktoken.encoding_for_model(model_name)
|
||||||
|
|
||||||
|
def test_truncate_message_all_small(self):
|
||||||
|
chat_messages = ChatMessageFactory.build_batch(500)
|
||||||
|
tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
|
||||||
|
|
||||||
|
prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
|
||||||
|
tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
|
||||||
|
|
||||||
|
# The original object has been modified. Verify certain properties
|
||||||
|
assert len(chat_messages) < 500
|
||||||
|
assert len(chat_messages) > 1
|
||||||
|
assert prompt == chat_messages
|
||||||
|
assert tokens <= self.max_prompt_size
|
||||||
|
|
||||||
|
def test_truncate_message_first_large(self):
|
||||||
|
chat_messages = ChatMessageFactory.build_batch(25)
|
||||||
|
big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=1000))
|
||||||
|
big_chat_message.content = big_chat_message.content + "\n" + "Question?"
|
||||||
|
copy_big_chat_message = big_chat_message.copy()
|
||||||
|
chat_messages.insert(0, big_chat_message)
|
||||||
|
tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
|
||||||
|
|
||||||
|
prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
|
||||||
|
tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
|
||||||
|
|
||||||
|
# The original object has been modified. Verify certain properties
|
||||||
|
assert len(chat_messages) == 1
|
||||||
|
assert prompt[0] != copy_big_chat_message
|
||||||
|
assert tokens <= self.max_prompt_size
|
||||||
|
|
||||||
|
def test_truncate_message_last_large(self):
|
||||||
|
chat_messages = ChatMessageFactory.build_batch(25)
|
||||||
|
big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=1000))
|
||||||
|
big_chat_message.content = big_chat_message.content + "\n" + "Question?"
|
||||||
|
copy_big_chat_message = big_chat_message.copy()
|
||||||
|
|
||||||
|
chat_messages.append(big_chat_message)
|
||||||
|
tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
|
||||||
|
|
||||||
|
prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
|
||||||
|
tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
|
||||||
|
|
||||||
|
# The original object has been modified. Verify certain properties
|
||||||
|
assert len(chat_messages) < 26
|
||||||
|
assert len(chat_messages) > 1
|
||||||
|
assert prompt[0] != copy_big_chat_message
|
||||||
|
assert tokens <= self.max_prompt_size
|
||||||
Reference in New Issue
Block a user