mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Run pre-commit script
This commit is contained in:
@@ -3,16 +3,18 @@ from langchain.schema import ChatMessage
|
||||
import factory
|
||||
import tiktoken
|
||||
|
||||
|
||||
class ChatMessageFactory(factory.Factory):
|
||||
class Meta:
|
||||
model = ChatMessage
|
||||
|
||||
content = factory.Faker('paragraph')
|
||||
role = factory.Faker('name')
|
||||
content = factory.Faker("paragraph")
|
||||
role = factory.Faker("name")
|
||||
|
||||
|
||||
class TestTruncateMessage:
|
||||
max_prompt_size = 4096
|
||||
model_name = 'gpt-3.5-turbo'
|
||||
model_name = "gpt-3.5-turbo"
|
||||
encoder = tiktoken.encoding_for_model(model_name)
|
||||
|
||||
def test_truncate_message_all_small(self):
|
||||
@@ -33,7 +35,7 @@ class TestTruncateMessage:
|
||||
|
||||
def test_truncate_message_first_large(self):
|
||||
chat_messages = ChatMessageFactory.build_batch(25)
|
||||
big_chat_message = ChatMessageFactory.build(content=factory.Faker('paragraph', nb_sentences=1000))
|
||||
big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=1000))
|
||||
big_chat_message.content = big_chat_message.content + "\n" + "Question?"
|
||||
copy_big_chat_message = big_chat_message.copy()
|
||||
chat_messages.insert(0, big_chat_message)
|
||||
@@ -53,10 +55,10 @@ class TestTruncateMessage:
|
||||
|
||||
def test_truncate_message_last_large(self):
|
||||
chat_messages = ChatMessageFactory.build_batch(25)
|
||||
big_chat_message = ChatMessageFactory.build(content=factory.Faker('paragraph', nb_sentences=1000))
|
||||
big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=1000))
|
||||
big_chat_message.content = big_chat_message.content + "\n" + "Question?"
|
||||
copy_big_chat_message = big_chat_message.copy()
|
||||
|
||||
|
||||
chat_messages.append(big_chat_message)
|
||||
assert len(chat_messages) == 26
|
||||
tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
|
||||
@@ -71,4 +73,3 @@ class TestTruncateMessage:
|
||||
|
||||
tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
|
||||
assert tokens < self.max_prompt_size
|
||||
|
||||
|
||||
Reference in New Issue
Block a user