diff --git a/pyproject.toml b/pyproject.toml index 1750268d..cf77ea79 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,6 +78,8 @@ dev = [ "black >= 23.1.0", "pre-commit >= 3.0.4", "freezegun >= 1.2.0", + "factory-boy==3.2.1", + "Faker==18.10.1", ] [tool.hatch.version] diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 5ea5817c..a3901d02 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -45,7 +45,7 @@ def completion_with_backoff(**kwargs): kwargs["openai_api_key"] = kwargs.get("api_key") else: kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY") - llm = OpenAI(**kwargs, request_timeout=10, max_retries=1) + llm = OpenAI(**kwargs, request_timeout=20, max_retries=1) return llm(prompt) @@ -97,25 +97,35 @@ def generate_chatml_messages_with_context( messages = user_chatml_message + rest_backnforths + system_chatml_message # Truncate oldest messages from conversation history until under max supported prompt size by model - encoder = tiktoken.encoding_for_model(model_name) - tokens = sum([len(encoder.encode(content)) for message in messages for content in message.content]) - while tokens > max_prompt_size[model_name] and len(messages) > 1: - messages.pop() - tokens = sum([len(encoder.encode(content)) for message in messages for content in message.content]) - - # Truncate last message if still over max supported prompt size by model - if tokens > max_prompt_size[model_name]: - last_message = messages[-1] - truncated_message = encoder.decode(encoder.encode(last_message.content)) - logger.debug( - f"Truncate last message to fit within max prompt size of {max_prompt_size[model_name]} supported by {model_name} model:\n {truncated_message}" - ) - messages = [ChatMessage(content=[truncated_message], role=last_message.role)] + messages = truncate_messages(messages, max_prompt_size[model_name], model_name) # Return message in chronological order return messages[::-1] +def truncate_messages(messages, max_prompt_size, model_name): + """Truncate messages to fit within max prompt size supported by model""" + encoder = tiktoken.encoding_for_model(model_name) + tokens = sum([len(encoder.encode(message.content)) for message in messages]) + while tokens > max_prompt_size and len(messages) > 1: + messages.pop() + tokens = sum([len(encoder.encode(message.content)) for message in messages]) + + # Truncate last message if still over max supported prompt size by model + if tokens > max_prompt_size: + last_message = "\n".join(messages[-1].content.split("\n")[:-1]) + original_question = "\n".join(messages[-1].content.split("\n")[-1:]) + original_question_tokens = len(encoder.encode(original_question)) + remaining_tokens = max_prompt_size - original_question_tokens + truncated_message = encoder.decode(encoder.encode(last_message)[:remaining_tokens]).strip() + logger.debug( + f"Truncate last message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}" + ) + messages = [ChatMessage(content=truncated_message + original_question, role=messages[-1].role)] + + return messages + + def reciprocal_conversation_to_chatml(message_pair): """Convert a single back and forth between user and assistant to chatml format""" return [ChatMessage(content=message, role=role) for message, role in zip(message_pair, ["user", "assistant"])] diff --git a/tests/test_conversation_utils.py b/tests/test_conversation_utils.py new file mode 100644 index 00000000..06a507c5 --- /dev/null +++ b/tests/test_conversation_utils.py @@ -0,0 +1,65 @@ +from khoj.processor.conversation import utils +from langchain.schema import ChatMessage +import factory +import tiktoken + + +class ChatMessageFactory(factory.Factory): + class Meta: + model = ChatMessage + + content = factory.Faker("paragraph") + role = factory.Faker("name") + + +class TestTruncateMessage: + max_prompt_size = 4096 + model_name = "gpt-3.5-turbo" + encoder = tiktoken.encoding_for_model(model_name) + + def test_truncate_message_all_small(self): + chat_messages = ChatMessageFactory.build_batch(500) + tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages]) + + prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name) + tokens = sum([len(self.encoder.encode(message.content)) for message in prompt]) + + # The original object has been modified. Verify certain properties + assert len(chat_messages) < 500 + assert len(chat_messages) > 1 + assert prompt == chat_messages + assert tokens <= self.max_prompt_size + + def test_truncate_message_first_large(self): + chat_messages = ChatMessageFactory.build_batch(25) + big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=1000)) + big_chat_message.content = big_chat_message.content + "\n" + "Question?" + copy_big_chat_message = big_chat_message.copy() + chat_messages.insert(0, big_chat_message) + tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages]) + + prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name) + tokens = sum([len(self.encoder.encode(message.content)) for message in prompt]) + + # The original object has been modified. Verify certain properties + assert len(chat_messages) == 1 + assert prompt[0] != copy_big_chat_message + assert tokens <= self.max_prompt_size + + def test_truncate_message_last_large(self): + chat_messages = ChatMessageFactory.build_batch(25) + big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=1000)) + big_chat_message.content = big_chat_message.content + "\n" + "Question?" + copy_big_chat_message = big_chat_message.copy() + + chat_messages.append(big_chat_message) + tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages]) + + prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name) + tokens = sum([len(self.encoder.encode(message.content)) for message in prompt]) + + # The original object has been modified. Verify certain properties + assert len(chat_messages) < 26 + assert len(chat_messages) > 1 + assert prompt[0] != copy_big_chat_message + assert tokens <= self.max_prompt_size