diff --git a/src/khoj/processor/conversation/gpt4all/chat_model.py b/src/khoj/processor/conversation/gpt4all/chat_model.py
index e9beaa80..7e92d002 100644
--- a/src/khoj/processor/conversation/gpt4all/chat_model.py
+++ b/src/khoj/processor/conversation/gpt4all/chat_model.py
@@ -127,6 +127,8 @@ def converse_offline(
     loaded_model: Union[Any, None] = None,
     completion_func=None,
     conversation_command=ConversationCommand.Default,
+    max_prompt_size=None,
+    tokenizer_name=None,
 ) -> Union[ThreadedGenerator, Iterator[str]]:
     """
     Converse with user using Llama
@@ -158,6 +160,8 @@ def converse_offline(
         prompts.system_prompt_message_llamav2,
         conversation_log,
         model_name=model,
+        max_prompt_size=max_prompt_size,
+        tokenizer_name=tokenizer_name,
     )
 
     g = ThreadedGenerator(references, completion_func=completion_func)
diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py
index 96510586..73b4f176 100644
--- a/src/khoj/processor/conversation/openai/gpt.py
+++ b/src/khoj/processor/conversation/openai/gpt.py
@@ -116,6 +116,8 @@ def converse(
     temperature: float = 0.2,
     completion_func=None,
     conversation_command=ConversationCommand.Default,
+    max_prompt_size=None,
+    tokenizer_name=None,
 ):
     """
     Converse with user using OpenAI's ChatGPT
@@ -141,6 +143,8 @@ def converse(
         prompts.personality.format(),
         conversation_log,
         model,
+        max_prompt_size,
+        tokenizer_name,
     )
     truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages})
     logger.debug(f"Conversation Context for GPT: {truncated_messages}")
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 7bb86887..5f219b83 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -13,17 +13,16 @@ from transformers import AutoTokenizer
 import queue
 from khoj.utils.helpers import merge_dicts
 
+
 logger = logging.getLogger(__name__)
-max_prompt_size = {
+model_to_prompt_size = {
     "gpt-3.5-turbo": 4096,
     "gpt-4": 8192,
     "llama-2-7b-chat.ggmlv3.q4_0.bin": 1548,
     "gpt-3.5-turbo-16k": 15000,
-    "default": 1600,
 }
-tokenizer = {
+model_to_tokenizer = {
     "llama-2-7b-chat.ggmlv3.q4_0.bin": "hf-internal-testing/llama-tokenizer",
-    "default": "hf-internal-testing/llama-tokenizer",
 }
 
 
@@ -86,7 +85,13 @@ def message_to_log(
 
 
 def generate_chatml_messages_with_context(
-    user_message, system_message, conversation_log={}, model_name="gpt-3.5-turbo", lookback_turns=2
+    user_message,
+    system_message,
+    conversation_log={},
+    model_name="gpt-3.5-turbo",
+    lookback_turns=2,
+    max_prompt_size=None,
+    tokenizer_name=None,
 ):
     """Generate messages for ChatGPT with context from previous conversation"""
     # Extract Chat History for Context
@@ -108,20 +113,38 @@
 
     messages = user_chatml_message + rest_backnforths + system_chatml_message
 
+    # Set max prompt size from user config, pre-configured for model or to default prompt size
+    try:
+        max_prompt_size = max_prompt_size or model_to_prompt_size[model_name]
+    except:
+        max_prompt_size = 2000
+        logger.warning(
+            f"Fallback to default prompt size: {max_prompt_size}.\nConfigure max_prompt_size for unsupported model: {model_name} in Khoj settings to use a longer context window."
+        )
+
     # Truncate oldest messages from conversation history until under max supported prompt size by model
-    messages = truncate_messages(messages, max_prompt_size.get(model_name, max_prompt_size["default"]), model_name)
+    messages = truncate_messages(messages, max_prompt_size, model_name, tokenizer_name)
 
     # Return message in chronological order
     return messages[::-1]
 
 
-def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name: str) -> list[ChatMessage]:
+def truncate_messages(
+    messages: list[ChatMessage], max_prompt_size, model_name: str, tokenizer_name=None
+) -> list[ChatMessage]:
     """Truncate messages to fit within max prompt size supported by model"""
-    if model_name.startswith("gpt-"):
-        encoder = tiktoken.encoding_for_model(model_name)
-    else:
-        encoder = AutoTokenizer.from_pretrained(tokenizer.get(model_name, tokenizer["default"]))
+    try:
+        if model_name.startswith("gpt-"):
+            encoder = tiktoken.encoding_for_model(model_name)
+        else:
+            encoder = AutoTokenizer.from_pretrained(tokenizer_name or model_to_tokenizer[model_name])
+    except:
+        default_tokenizer = "hf-internal-testing/llama-tokenizer"
+        encoder = AutoTokenizer.from_pretrained(default_tokenizer)
+        logger.warning(
+            f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
+        )
 
     system_message = messages.pop()
     system_message_tokens = len(encoder.encode(system_message.content))
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index d8b0aa8b..6b42f29c 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -123,6 +123,8 @@ def generate_chat_response(
                 completion_func=partial_completion,
                 conversation_command=conversation_command,
                 model=state.processor_config.conversation.offline_chat.chat_model,
+                max_prompt_size=state.processor_config.conversation.max_prompt_size,
+                tokenizer_name=state.processor_config.conversation.tokenizer,
             )
 
         elif state.processor_config.conversation.openai_model:
@@ -136,6 +138,8 @@
                 api_key=api_key,
                 completion_func=partial_completion,
                 conversation_command=conversation_command,
+                max_prompt_size=state.processor_config.conversation.max_prompt_size,
+                tokenizer_name=state.processor_config.conversation.tokenizer,
             )
 
     except Exception as e:
diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py
index daae1982..3930ec98 100644
--- a/src/khoj/utils/config.py
+++ b/src/khoj/utils/config.py
@@ -95,6 +95,8 @@ class ConversationProcessorConfigModel:
         self.openai_model = conversation_config.openai
         self.gpt4all_model = GPT4AllProcessorConfig()
         self.offline_chat = conversation_config.offline_chat
+        self.max_prompt_size = conversation_config.max_prompt_size
+        self.tokenizer = conversation_config.tokenizer
         self.conversation_logfile = Path(conversation_config.conversation_logfile)
         self.chat_session: List[str] = []
         self.meta_log: dict = {}
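Note (not part of the patch): a minimal usage sketch of the new keyword arguments threaded through generate_chatml_messages_with_context above, assuming khoj is installed and importable. The sample messages and override values are illustrative only, and conversation_log is left empty for brevity.

from khoj.processor.conversation.utils import generate_chatml_messages_with_context

# Explicit overrides win; otherwise the model's entry in model_to_prompt_size /
# model_to_tokenizer is used, and finally the hard-coded fallbacks
# (2000 tokens, hf-internal-testing/llama-tokenizer) with a warning in the logs.
messages = generate_chatml_messages_with_context(
    user_message="What did I write about my trip to Berlin?",
    system_message="You are Khoj, a personal assistant.",
    conversation_log={},
    model_name="llama-2-7b-chat.ggmlv3.q4_0.bin",
    max_prompt_size=1548,
    tokenizer_name="hf-internal-testing/llama-tokenizer",
)
# Returns the list of ChatMessage objects in chronological order,
# truncated by truncate_messages to fit within max_prompt_size.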