Use max_prompt_size, tokenizer from config for chat model context stuffing

This commit is contained in:
Debanjum Singh Solanky
2023-10-15 16:33:26 -07:00
parent 116595b351
commit df1d74a879
5 changed files with 48 additions and 11 deletions

View File

@@ -127,6 +127,8 @@ def converse_offline(
loaded_model: Union[Any, None] = None, loaded_model: Union[Any, None] = None,
completion_func=None, completion_func=None,
conversation_command=ConversationCommand.Default, conversation_command=ConversationCommand.Default,
max_prompt_size=None,
tokenizer_name=None,
) -> Union[ThreadedGenerator, Iterator[str]]: ) -> Union[ThreadedGenerator, Iterator[str]]:
""" """
Converse with user using Llama Converse with user using Llama
@@ -158,6 +160,8 @@ def converse_offline(
prompts.system_prompt_message_llamav2, prompts.system_prompt_message_llamav2,
conversation_log, conversation_log,
model_name=model, model_name=model,
max_prompt_size=max_prompt_size,
tokenizer_name=tokenizer_name,
) )
g = ThreadedGenerator(references, completion_func=completion_func) g = ThreadedGenerator(references, completion_func=completion_func)

View File

@@ -116,6 +116,8 @@ def converse(
temperature: float = 0.2, temperature: float = 0.2,
completion_func=None, completion_func=None,
conversation_command=ConversationCommand.Default, conversation_command=ConversationCommand.Default,
max_prompt_size=None,
tokenizer_name=None,
): ):
""" """
Converse with user using OpenAI's ChatGPT Converse with user using OpenAI's ChatGPT
@@ -141,6 +143,8 @@ def converse(
prompts.personality.format(), prompts.personality.format(),
conversation_log, conversation_log,
model, model,
max_prompt_size,
tokenizer_name,
) )
truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages}) truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages})
logger.debug(f"Conversation Context for GPT: {truncated_messages}") logger.debug(f"Conversation Context for GPT: {truncated_messages}")

View File

@@ -13,17 +13,16 @@ from transformers import AutoTokenizer
import queue import queue
from khoj.utils.helpers import merge_dicts from khoj.utils.helpers import merge_dicts
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
max_prompt_size = { model_to_prompt_size = {
"gpt-3.5-turbo": 4096, "gpt-3.5-turbo": 4096,
"gpt-4": 8192, "gpt-4": 8192,
"llama-2-7b-chat.ggmlv3.q4_0.bin": 1548, "llama-2-7b-chat.ggmlv3.q4_0.bin": 1548,
"gpt-3.5-turbo-16k": 15000, "gpt-3.5-turbo-16k": 15000,
"default": 1600,
} }
tokenizer = { model_to_tokenizer = {
"llama-2-7b-chat.ggmlv3.q4_0.bin": "hf-internal-testing/llama-tokenizer", "llama-2-7b-chat.ggmlv3.q4_0.bin": "hf-internal-testing/llama-tokenizer",
"default": "hf-internal-testing/llama-tokenizer",
} }
@@ -86,7 +85,13 @@ def message_to_log(
def generate_chatml_messages_with_context( def generate_chatml_messages_with_context(
user_message, system_message, conversation_log={}, model_name="gpt-3.5-turbo", lookback_turns=2 user_message,
system_message,
conversation_log={},
model_name="gpt-3.5-turbo",
lookback_turns=2,
max_prompt_size=None,
tokenizer_name=None,
): ):
"""Generate messages for ChatGPT with context from previous conversation""" """Generate messages for ChatGPT with context from previous conversation"""
# Extract Chat History for Context # Extract Chat History for Context
@@ -108,20 +113,38 @@ def generate_chatml_messages_with_context(
messages = user_chatml_message + rest_backnforths + system_chatml_message messages = user_chatml_message + rest_backnforths + system_chatml_message
# Set max prompt size from user config if provided, else the size pre-configured for the model, else a default prompt size
try:
max_prompt_size = max_prompt_size or model_to_prompt_size[model_name]
except:
max_prompt_size = 2000
logger.warning(
f"Fallback to default prompt size: {max_prompt_size}.\nConfigure max_prompt_size for unsupported model: {model_name} in Khoj settings to longer context window."
)
# Truncate oldest messages from conversation history until under max supported prompt size by model # Truncate oldest messages from conversation history until under max supported prompt size by model
messages = truncate_messages(messages, max_prompt_size.get(model_name, max_prompt_size["default"]), model_name) messages = truncate_messages(messages, max_prompt_size, model_name, tokenizer_name)
# Return message in chronological order # Return message in chronological order
return messages[::-1] return messages[::-1]
def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name: str) -> list[ChatMessage]: def truncate_messages(
messages: list[ChatMessage], max_prompt_size, model_name: str, tokenizer_name=None
) -> list[ChatMessage]:
"""Truncate messages to fit within max prompt size supported by model""" """Truncate messages to fit within max prompt size supported by model"""
try:
if model_name.startswith("gpt-"): if model_name.startswith("gpt-"):
encoder = tiktoken.encoding_for_model(model_name) encoder = tiktoken.encoding_for_model(model_name)
else: else:
encoder = AutoTokenizer.from_pretrained(tokenizer.get(model_name, tokenizer["default"])) encoder = AutoTokenizer.from_pretrained(tokenizer_name or model_to_tokenizer[model_name])
except:
default_tokenizer = "hf-internal-testing/llama-tokenizer"
encoder = AutoTokenizer.from_pretrained(default_tokenizer)
logger.warning(
f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
)
system_message = messages.pop() system_message = messages.pop()
system_message_tokens = len(encoder.encode(system_message.content)) system_message_tokens = len(encoder.encode(system_message.content))

View File

@@ -123,6 +123,8 @@ def generate_chat_response(
completion_func=partial_completion, completion_func=partial_completion,
conversation_command=conversation_command, conversation_command=conversation_command,
model=state.processor_config.conversation.offline_chat.chat_model, model=state.processor_config.conversation.offline_chat.chat_model,
max_prompt_size=state.processor_config.conversation.max_prompt_size,
tokenizer_name=state.processor_config.conversation.tokenizer,
) )
elif state.processor_config.conversation.openai_model: elif state.processor_config.conversation.openai_model:
@@ -136,6 +138,8 @@ def generate_chat_response(
api_key=api_key, api_key=api_key,
completion_func=partial_completion, completion_func=partial_completion,
conversation_command=conversation_command, conversation_command=conversation_command,
max_prompt_size=state.processor_config.conversation.max_prompt_size,
tokenizer_name=state.processor_config.conversation.tokenizer,
) )
except Exception as e: except Exception as e:

View File

@@ -95,6 +95,8 @@ class ConversationProcessorConfigModel:
self.openai_model = conversation_config.openai self.openai_model = conversation_config.openai
self.gpt4all_model = GPT4AllProcessorConfig() self.gpt4all_model = GPT4AllProcessorConfig()
self.offline_chat = conversation_config.offline_chat self.offline_chat = conversation_config.offline_chat
self.max_prompt_size = conversation_config.max_prompt_size
self.tokenizer = conversation_config.tokenizer
self.conversation_logfile = Path(conversation_config.conversation_logfile) self.conversation_logfile = Path(conversation_config.conversation_logfile)
self.chat_session: List[str] = [] self.chat_session: List[str] = []
self.meta_log: dict = {} self.meta_log: dict = {}