From f0efe0177e0c652e2c6b5dfd1aab90a2d779e859 Mon Sep 17 00:00:00 2001
From: Saba
Date: Sun, 4 Jun 2023 19:33:46 -0700
Subject: [PATCH 1/9] Pass truncated message as string in ChatMessage when
 exceeding max prompt size

---
 src/khoj/processor/conversation/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 5ea5817c..2df1f24f 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -110,7 +110,7 @@ def generate_chatml_messages_with_context(
         logger.debug(
             f"Truncate last message to fit within max prompt size of {max_prompt_size[model_name]} supported by {model_name} model:\n {truncated_message}"
         )
-        messages = [ChatMessage(content=[truncated_message], role=last_message.role)]
+        messages = [ChatMessage(content=truncated_message, role=last_message.role)]
 
     # Return message in chronological order
     return messages[::-1]

From 0e63a9037750d16c0bb14bbab16bfadfc0aeeb27 Mon Sep 17 00:00:00 2001
From: Saba
Date: Sun, 4 Jun 2023 20:25:37 -0700
Subject: [PATCH 2/9] Fix the mechanism to retrieve the message content

---
 src/khoj/processor/conversation/utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 2df1f24f..5bccdf79 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -98,10 +98,10 @@ def generate_chatml_messages_with_context(
 
     # Truncate oldest messages from conversation history until under max supported prompt size by model
     encoder = tiktoken.encoding_for_model(model_name)
-    tokens = sum([len(encoder.encode(content)) for message in messages for content in message.content])
+    tokens = sum([len(encoder.encode(message.content)) for message in messages])
     while tokens > max_prompt_size[model_name] and len(messages) > 1:
         messages.pop()
-        tokens = sum([len(encoder.encode(content)) for message in messages for content in message.content])
+        tokens = sum([len(encoder.encode(message.content)) for message in messages])
 
     # Truncate last message if still over max supported prompt size by model
     if tokens > max_prompt_size[model_name]:

From 5f4223efb486d5c404acee4af5da0bd8e011e0f4 Mon Sep 17 00:00:00 2001
From: Saba
Date: Sun, 4 Jun 2023 20:49:47 -0700
Subject: [PATCH 3/9] Increase timeout to OpenAI call

---
 src/khoj/processor/conversation/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 5bccdf79..0161c3e1 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -45,7 +45,7 @@ def completion_with_backoff(**kwargs):
         kwargs["openai_api_key"] = kwargs.get("api_key")
     else:
         kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY")
-    llm = OpenAI(**kwargs, request_timeout=10, max_retries=1)
+    llm = OpenAI(**kwargs, request_timeout=20, max_retries=1)
 
     return llm(prompt)
 

From f65ff9815def67ccd63f58b2a0bf72f74a58a8a3 Mon Sep 17 00:00:00 2001
From: Saba
Date: Mon, 5 Jun 2023 18:58:29 -0700
Subject: [PATCH 4/9] Move message truncation logic into a separate function.
 Add unit tests with factory boy.

---
 pyproject.toml                           |  2 +
 src/khoj/processor/conversation/utils.py | 26 +++++---
 tests/test_conversation_utils.py         | 78 ++++++++++++++++++++++++
 3 files changed, 98 insertions(+), 8 deletions(-)
 create mode 100644 tests/test_conversation_utils.py

diff --git a/pyproject.toml b/pyproject.toml
index 1750268d..08921ab5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -56,6 +56,8 @@ dependencies = [
     "aiohttp == 3.8.4",
     "langchain >= 0.0.187",
     "pypdf >= 3.9.0",
+    "factory-boy==3.2.1",
+    "Faker==18.10.1"
 ]
 
 dynamic = ["version"]

diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 0161c3e1..d81ec648 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -97,23 +97,33 @@ def generate_chatml_messages_with_context(
     messages = user_chatml_message + rest_backnforths + system_chatml_message
 
     # Truncate oldest messages from conversation history until under max supported prompt size by model
+    messages = truncate_message(messages, max_prompt_size[model_name], model_name)
+
+    # Return message in chronological order
+    return messages[::-1]
+
+def truncate_message(messages, max_prompt_size, model_name):
+    """Truncate messages to fit within max prompt size supported by model"""
     encoder = tiktoken.encoding_for_model(model_name)
     tokens = sum([len(encoder.encode(message.content)) for message in messages])
+    logger.info(f"num tokens: {tokens}")
-    while tokens > max_prompt_size[model_name] and len(messages) > 1:
+    while tokens > max_prompt_size and len(messages) > 1:
         messages.pop()
         tokens = sum([len(encoder.encode(message.content)) for message in messages])
 
     # Truncate last message if still over max supported prompt size by model
-    if tokens > max_prompt_size[model_name]:
-        last_message = messages[-1]
-        truncated_message = encoder.decode(encoder.encode(last_message.content))
+    if tokens > max_prompt_size:
+        last_message = '\n'.join(messages[-1].content.split("\n")[:-1])
+        original_question = '\n'.join(messages[-1].content.split("\n")[-1:])
+        original_question_tokens = len(encoder.encode(original_question))
+        remaining_tokens = max_prompt_size - original_question_tokens
+        truncated_message = encoder.decode(encoder.encode(last_message)[:remaining_tokens]).strip()
         logger.debug(
-            f"Truncate last message to fit within max prompt size of {max_prompt_size[model_name]} supported by {model_name} model:\n {truncated_message}"
+            f"Truncate last message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
         )
-        messages = [ChatMessage(content=truncated_message, role=last_message.role)]
+        messages = [ChatMessage(content=truncated_message + original_question, role=messages[-1].role)]
 
-    # Return message in chronological order
-    return messages[::-1]
+    return messages
 
 
 def reciprocal_conversation_to_chatml(message_pair):
diff --git a/tests/test_conversation_utils.py b/tests/test_conversation_utils.py
new file mode 100644
index 00000000..ed68a40d
--- /dev/null
+++ b/tests/test_conversation_utils.py
@@ -0,0 +1,78 @@
+from khoj.processor.conversation import utils
+from langchain.schema import ChatMessage
+import factory
+import logging
+import tiktoken
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+class ChatMessageFactory(factory.Factory):
+    class Meta:
+        model = ChatMessage
+
+    content = factory.Faker('paragraph')
+    role = factory.Faker('name')
+
+class TestTruncateMessage:
+    max_prompt_size = 4096
+    model_name = 'gpt-3.5-turbo'
+    encoder = tiktoken.encoding_for_model(model_name)
+
+    def test_truncate_message_all_small(self):
+        chat_messages = ChatMessageFactory.build_batch(500)
+        assert len(chat_messages) == 500
+        tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
+        assert tokens > self.max_prompt_size
+
+        prompt = utils.truncate_message(chat_messages, self.max_prompt_size, self.model_name)
+
+        # The original object has been modified. Verify certain properties
+        assert len(chat_messages) < 500
+        assert len(chat_messages) > 1
+        assert prompt == chat_messages
+
+        tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
+        assert tokens <= self.max_prompt_size
+
+    def test_truncate_message_first_large(self):
+        chat_messages = ChatMessageFactory.build_batch(25)
+        big_chat_message = ChatMessageFactory.build(content=factory.Faker('paragraph', nb_sentences=1000))
+        big_chat_message.content = big_chat_message.content + "\n" + "Question?"
+        copy_big_chat_message = big_chat_message.copy()
+        chat_messages.insert(0, big_chat_message)
+        assert len(chat_messages) == 26
+        tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
+        assert tokens > self.max_prompt_size
+
+        prompt = utils.truncate_message(chat_messages, self.max_prompt_size, self.model_name)
+
+        # The original object has been modified. Verify certain properties
+        assert len(chat_messages) < 26
+        assert len(chat_messages) == 1
+        assert prompt[0] != copy_big_chat_message
+
+        tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
+        assert tokens <= self.max_prompt_size
+
+    def test_truncate_message_last_large(self):
+        chat_messages = ChatMessageFactory.build_batch(25)
+        big_chat_message = ChatMessageFactory.build(content=factory.Faker('paragraph', nb_sentences=1000))
+        big_chat_message.content = big_chat_message.content + "\n" + "Question?"
+        copy_big_chat_message = big_chat_message.copy()
+
+        chat_messages.append(big_chat_message)
+        assert len(chat_messages) == 26
+        tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
+        assert tokens > self.max_prompt_size
+
+        prompt = utils.truncate_message(chat_messages, self.max_prompt_size, self.model_name)
+
+        # The original object has been modified. Verify certain properties
+        assert len(chat_messages) < 26
+        assert len(chat_messages) > 1
+        assert prompt[0] != copy_big_chat_message
+
+        tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
+        assert tokens < self.max_prompt_size
+

From 6212d7c2e8803300215e53f7e85296aa0ecb2747 Mon Sep 17 00:00:00 2001
From: Saba
Date: Mon, 5 Jun 2023 19:00:25 -0700
Subject: [PATCH 5/9] Remove debug line

---
 src/khoj/processor/conversation/utils.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index d81ec648..9cf9952f 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -106,7 +106,6 @@ def truncate_message(messages, max_prompt_size, model_name):
     """Truncate messages to fit within max prompt size supported by model"""
     encoder = tiktoken.encoding_for_model(model_name)
     tokens = sum([len(encoder.encode(message.content)) for message in messages])
-    logger.info(f"num tokens: {tokens}")
     while tokens > max_prompt_size and len(messages) > 1:
         messages.pop()
         tokens = sum([len(encoder.encode(message.content)) for message in messages])

From 948ba6ddca56952f9080e1b8f16443357b794930 Mon Sep 17 00:00:00 2001
From: Saba
Date: Mon, 5 Jun 2023 19:01:03 -0700
Subject: [PATCH 6/9] Remove unused logger

---
 tests/test_conversation_utils.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/tests/test_conversation_utils.py b/tests/test_conversation_utils.py
index ed68a40d..24e97938 100644
--- a/tests/test_conversation_utils.py
+++ b/tests/test_conversation_utils.py
@@ -1,12 +1,8 @@
 from khoj.processor.conversation import utils
 from langchain.schema import ChatMessage
 import factory
-import logging
 import tiktoken
 
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.DEBUG)
-
 class ChatMessageFactory(factory.Factory):
     class Meta:
         model = ChatMessage

From 7119ed08498dc54547871d08e283b7dde1853989 Mon Sep 17 00:00:00 2001
From: Saba
Date: Mon, 5 Jun 2023 19:29:23 -0700
Subject: [PATCH 7/9] Run pre-commit script

---
 src/khoj/processor/conversation/utils.py |  5 +++--
 tests/test_conversation_utils.py         | 15 ++++++++------
 2 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 9cf9952f..a63a09c0 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -102,6 +102,7 @@ def generate_chatml_messages_with_context(
     # Return message in chronological order
     return messages[::-1]
 
+
 def truncate_message(messages, max_prompt_size, model_name):
     """Truncate messages to fit within max prompt size supported by model"""
     encoder = tiktoken.encoding_for_model(model_name)
@@ -112,8 +113,8 @@ def truncate_message(messages, max_prompt_size, model_name):
 
     # Truncate last message if still over max supported prompt size by model
     if tokens > max_prompt_size:
-        last_message = '\n'.join(messages[-1].content.split("\n")[:-1])
-        original_question = '\n'.join(messages[-1].content.split("\n")[-1:])
+        last_message = "\n".join(messages[-1].content.split("\n")[:-1])
+        original_question = "\n".join(messages[-1].content.split("\n")[-1:])
         original_question_tokens = len(encoder.encode(original_question))
         remaining_tokens = max_prompt_size - original_question_tokens
         truncated_message = encoder.decode(encoder.encode(last_message)[:remaining_tokens]).strip()
diff --git a/tests/test_conversation_utils.py b/tests/test_conversation_utils.py
index 24e97938..43f68884 100644
--- a/tests/test_conversation_utils.py
+++ b/tests/test_conversation_utils.py
@@ -3,16 +3,18 @@ from langchain.schema import ChatMessage
 import factory
 import tiktoken
 
+
 class ChatMessageFactory(factory.Factory):
     class Meta:
         model = ChatMessage
 
-    content = factory.Faker('paragraph')
-    role = factory.Faker('name')
+    content = factory.Faker("paragraph")
+    role = factory.Faker("name")
+
 
 class TestTruncateMessage:
     max_prompt_size = 4096
-    model_name = 'gpt-3.5-turbo'
+    model_name = "gpt-3.5-turbo"
     encoder = tiktoken.encoding_for_model(model_name)
 
     def test_truncate_message_all_small(self):
@@ -33,7 +35,7 @@ class TestTruncateMessage:
 
     def test_truncate_message_first_large(self):
         chat_messages = ChatMessageFactory.build_batch(25)
-        big_chat_message = ChatMessageFactory.build(content=factory.Faker('paragraph', nb_sentences=1000))
+        big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=1000))
         big_chat_message.content = big_chat_message.content + "\n" + "Question?"
         copy_big_chat_message = big_chat_message.copy()
         chat_messages.insert(0, big_chat_message)
@@ -53,10 +55,10 @@ class TestTruncateMessage:
 
     def test_truncate_message_last_large(self):
         chat_messages = ChatMessageFactory.build_batch(25)
-        big_chat_message = ChatMessageFactory.build(content=factory.Faker('paragraph', nb_sentences=1000))
+        big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=1000))
         big_chat_message.content = big_chat_message.content + "\n" + "Question?"
         copy_big_chat_message = big_chat_message.copy()
-        
+
         chat_messages.append(big_chat_message)
         assert len(chat_messages) == 26
         tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
@@ -71,4 +73,3 @@ class TestTruncateMessage:
 
         tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
         assert tokens < self.max_prompt_size
-

From 5d5ebcbf7ca9377fced10d902c5072203a2004e7 Mon Sep 17 00:00:00 2001
From: Saba
Date: Tue, 6 Jun 2023 23:25:43 -0700
Subject: [PATCH 8/9] Rename truncate messages method and update unit tests to
 simplify assertion logic

---
 src/khoj/processor/conversation/utils.py |  4 ++--
 tests/test_conversation_utils.py         | 24 +++++++-----------------
 2 files changed, 9 insertions(+), 19 deletions(-)

diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index a63a09c0..a3901d02 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -97,13 +97,13 @@ def generate_chatml_messages_with_context(
     messages = user_chatml_message + rest_backnforths + system_chatml_message
 
     # Truncate oldest messages from conversation history until under max supported prompt size by model
-    messages = truncate_message(messages, max_prompt_size[model_name], model_name)
+    messages = truncate_messages(messages, max_prompt_size[model_name], model_name)
 
     # Return message in chronological order
     return messages[::-1]
 
 
-def truncate_message(messages, max_prompt_size, model_name):
+def truncate_messages(messages, max_prompt_size, model_name):
     """Truncate messages to fit within max prompt size supported by model"""
     encoder = tiktoken.encoding_for_model(model_name)
     tokens = sum([len(encoder.encode(message.content)) for message in messages])
diff --git a/tests/test_conversation_utils.py b/tests/test_conversation_utils.py
index 43f68884..06a507c5 100644
--- a/tests/test_conversation_utils.py
+++ b/tests/test_conversation_utils.py
@@ -19,18 +19,15 @@ class TestTruncateMessage:
 
     def test_truncate_message_all_small(self):
         chat_messages = ChatMessageFactory.build_batch(500)
-        assert len(chat_messages) == 500
         tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
-        assert tokens > self.max_prompt_size
 
-        prompt = utils.truncate_message(chat_messages, self.max_prompt_size, self.model_name)
+        prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
+        tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
 
         # The original object has been modified. Verify certain properties
         assert len(chat_messages) < 500
         assert len(chat_messages) > 1
         assert prompt == chat_messages
-
-        tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
         assert tokens <= self.max_prompt_size
 
     def test_truncate_message_first_large(self):
@@ -39,18 +36,14 @@ class TestTruncateMessage:
         big_chat_message.content = big_chat_message.content + "\n" + "Question?"
         copy_big_chat_message = big_chat_message.copy()
         chat_messages.insert(0, big_chat_message)
-        assert len(chat_messages) == 26
         tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
-        assert tokens > self.max_prompt_size
 
-        prompt = utils.truncate_message(chat_messages, self.max_prompt_size, self.model_name)
+        prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
+        tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
 
         # The original object has been modified. Verify certain properties
-        assert len(chat_messages) < 26
         assert len(chat_messages) == 1
         assert prompt[0] != copy_big_chat_message
-
-        tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
         assert tokens <= self.max_prompt_size
 
     def test_truncate_message_last_large(self):
@@ -60,16 +53,13 @@ class TestTruncateMessage:
         copy_big_chat_message = big_chat_message.copy()
 
         chat_messages.append(big_chat_message)
-        assert len(chat_messages) == 26
         tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
-        assert tokens > self.max_prompt_size
 
-        prompt = utils.truncate_message(chat_messages, self.max_prompt_size, self.model_name)
+        prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
+        tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
 
         # The original object has been modified. Verify certain properties
         assert len(chat_messages) < 26
         assert len(chat_messages) > 1
         assert prompt[0] != copy_big_chat_message
-
-        tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
-        assert tokens < self.max_prompt_size
+        assert tokens <= self.max_prompt_size

From c5666e04047d550e7b5b5a62861b5642b14bca26 Mon Sep 17 00:00:00 2001
From: Saba
Date: Tue, 6 Jun 2023 23:26:24 -0700
Subject: [PATCH 9/9] Move factory dependencies to optional settings

---
 pyproject.toml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 08921ab5..cf77ea79 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -56,8 +56,6 @@ dependencies = [
     "aiohttp == 3.8.4",
     "langchain >= 0.0.187",
     "pypdf >= 3.9.0",
-    "factory-boy==3.2.1",
-    "Faker==18.10.1"
 ]
 
 dynamic = ["version"]
@@ -80,6 +78,8 @@ dev = [
     "black >= 23.1.0",
     "pre-commit >= 3.0.4",
     "freezegun >= 1.2.0",
+    "factory-boy==3.2.1",
+    "Faker==18.10.1",
 ]
 
 [tool.hatch.version]
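
For reviewers, a minimal sketch of how the truncate_messages helper introduced by this series is expected to be called, matching its final signature after patch 8/9. The message contents and the 4096-token budget are illustrative assumptions, not part of the patches; it assumes tiktoken and langchain are installed.

# Sketch (illustrative, not part of the patch series): exercising
# truncate_messages from src/khoj/processor/conversation/utils.py.
import tiktoken
from langchain.schema import ChatMessage

from khoj.processor.conversation.utils import truncate_messages

max_prompt_size = 4096  # assumed token budget for gpt-3.5-turbo
model_name = "gpt-3.5-turbo"

# Callers pass messages newest-first; truncate_messages pops from the tail,
# so the oldest conversation turns are dropped first.
messages = [
    ChatMessage(content="<long retrieved notes>\nWhat did I note about X?", role="user"),
    ChatMessage(content="An earlier back-and-forth message.", role="assistant"),
    ChatMessage(content="You are Khoj, a personal assistant.", role="system"),
]

truncated = truncate_messages(messages, max_prompt_size, model_name)

# If a single oversized message remains, everything except its final line
# (treated as the user's question) is token-truncated to fit the budget.
encoder = tiktoken.encoding_for_model(model_name)
tokens = sum(len(encoder.encode(m.content)) for m in truncated)
assert tokens <= max_prompt_size

Note that the helper mutates the list it is given via messages.pop(), which is why the unit tests in this series assert on chat_messages itself after the call ("The original object has been modified").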