Use max_prompt_size, tokenizer from config for chat model context stuffing
@@ -127,6 +127,8 @@ def converse_offline(
     loaded_model: Union[Any, None] = None,
     completion_func=None,
     conversation_command=ConversationCommand.Default,
+    max_prompt_size=None,
+    tokenizer_name=None,
 ) -> Union[ThreadedGenerator, Iterator[str]]:
     """
     Converse with user using Llama
@@ -158,6 +160,8 @@ def converse_offline(
         prompts.system_prompt_message_llamav2,
         conversation_log,
         model_name=model,
+        max_prompt_size=max_prompt_size,
+        tokenizer_name=tokenizer_name,
     )

     g = ThreadedGenerator(references, completion_func=completion_func)

@@ -116,6 +116,8 @@ def converse(
     temperature: float = 0.2,
     completion_func=None,
     conversation_command=ConversationCommand.Default,
+    max_prompt_size=None,
+    tokenizer_name=None,
 ):
     """
     Converse with user using OpenAI's ChatGPT
@@ -141,6 +143,8 @@ def converse(
         prompts.personality.format(),
         conversation_log,
         model,
+        max_prompt_size,
+        tokenizer_name,
     )
     truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages})
     logger.debug(f"Conversation Context for GPT: {truncated_messages}")

@@ -13,17 +13,16 @@ from transformers import AutoTokenizer
 import queue
 from khoj.utils.helpers import merge_dicts


 logger = logging.getLogger(__name__)
-max_prompt_size = {
+model_to_prompt_size = {
     "gpt-3.5-turbo": 4096,
     "gpt-4": 8192,
     "llama-2-7b-chat.ggmlv3.q4_0.bin": 1548,
     "gpt-3.5-turbo-16k": 15000,
-    "default": 1600,
 }
-tokenizer = {
+model_to_tokenizer = {
     "llama-2-7b-chat.ggmlv3.q4_0.bin": "hf-internal-testing/llama-tokenizer",
-    "default": "hf-internal-testing/llama-tokenizer",
 }

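A quick way to sanity-check what these two tables encode is to count tokens with the mapped Hugging Face tokenizer and compare against the per-model prompt-size budget. This is only an illustrative snippet, assuming transformers and the llama tokenizer dependencies are installed; the model and tokenizer names are the ones from the tables above.

from transformers import AutoTokenizer

# Tokenizer mapped to the offline llama-2 chat model in model_to_tokenizer
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

prompt = "Summarize my notes on sleep."
token_count = len(tokenizer.encode(prompt))

# The whole stuffed prompt must stay under the 1548-token budget that
# model_to_prompt_size assigns to llama-2-7b-chat.ggmlv3.q4_0.bin
print(token_count, token_count < 1548)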
@@ -86,7 +85,13 @@ def message_to_log(


 def generate_chatml_messages_with_context(
-    user_message, system_message, conversation_log={}, model_name="gpt-3.5-turbo", lookback_turns=2
+    user_message,
+    system_message,
+    conversation_log={},
+    model_name="gpt-3.5-turbo",
+    lookback_turns=2,
+    max_prompt_size=None,
+    tokenizer_name=None,
 ):
     """Generate messages for ChatGPT with context from previous conversation"""
     # Extract Chat History for Context
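For reference, a call to the updated helper might look like the sketch below. The keyword names match the new signature; the message strings and values are made up for illustration.

messages = generate_chatml_messages_with_context(
    user_message="What did I write about sleep last week?",
    system_message="You are Khoj, a personal assistant.",
    conversation_log={},
    model_name="gpt-4",
    max_prompt_size=4096,   # user override; otherwise model_to_prompt_size["gpt-4"] (8192) applies
    tokenizer_name=None,    # gpt-* models are tokenized with tiktoken, so no HF tokenizer is needed
)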
@@ -108,20 +113,38 @@ def generate_chatml_messages_with_context(

     messages = user_chatml_message + rest_backnforths + system_chatml_message

+    # Set max prompt size from user config, pre-configured for model or to default prompt size
+    try:
+        max_prompt_size = max_prompt_size or model_to_prompt_size[model_name]
+    except:
+        max_prompt_size = 2000
+        logger.warning(
+            f"Fallback to default prompt size: {max_prompt_size}.\nConfigure max_prompt_size for unsupported model: {model_name} in Khoj settings to longer context window."
+        )
+
     # Truncate oldest messages from conversation history until under max supported prompt size by model
-    messages = truncate_messages(messages, max_prompt_size.get(model_name, max_prompt_size["default"]), model_name)
+    messages = truncate_messages(messages, max_prompt_size, model_name, tokenizer_name)

     # Return message in chronological order
     return messages[::-1]


-def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name: str) -> list[ChatMessage]:
+def truncate_messages(
+    messages: list[ChatMessage], max_prompt_size, model_name: str, tokenizer_name=None
+) -> list[ChatMessage]:
     """Truncate messages to fit within max prompt size supported by model"""

-    if model_name.startswith("gpt-"):
-        encoder = tiktoken.encoding_for_model(model_name)
-    else:
-        encoder = AutoTokenizer.from_pretrained(tokenizer.get(model_name, tokenizer["default"]))
+    try:
+        if model_name.startswith("gpt-"):
+            encoder = tiktoken.encoding_for_model(model_name)
+        else:
+            encoder = AutoTokenizer.from_pretrained(tokenizer_name or model_to_tokenizer[model_name])
+    except:
+        default_tokenizer = "hf-internal-testing/llama-tokenizer"
+        encoder = AutoTokenizer.from_pretrained(default_tokenizer)
+        logger.warning(
+            f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
+        )

     system_message = messages.pop()
     system_message_tokens = len(encoder.encode(system_message.content))

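The truncation strategy itself is not changed here: messages arrive newest-first with the system prompt last, and the oldest turns are dropped until the token count fits the budget. A minimal, standalone sketch of that strategy, using plain strings instead of ChatMessage and an illustrative helper name, and assuming only that tiktoken is installed:

import tiktoken

def truncate_oldest_first(messages: list[str], max_prompt_size: int, model_name: str = "gpt-3.5-turbo") -> list[str]:
    """Drop oldest messages until the prompt fits within max_prompt_size tokens."""
    encoder = tiktoken.encoding_for_model(model_name)
    system_message = messages.pop()  # system prompt is always kept
    budget = max_prompt_size - len(encoder.encode(system_message))
    kept: list[str] = []
    for message in messages:  # newest to oldest
        cost = len(encoder.encode(message))
        if cost > budget:
            break  # this turn and everything older is dropped
        kept.append(message)
        budget -= cost
    return kept + [system_message]

# Example: a 30-token budget keeps only the most recent turn plus the system prompt
history = ["newest user turn", "older assistant turn " * 20, "you are a helpful assistant"]
print(truncate_oldest_first(history, max_prompt_size=30))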
@@ -123,6 +123,8 @@ def generate_chat_response(
                 completion_func=partial_completion,
                 conversation_command=conversation_command,
                 model=state.processor_config.conversation.offline_chat.chat_model,
+                max_prompt_size=state.processor_config.conversation.max_prompt_size,
+                tokenizer_name=state.processor_config.conversation.tokenizer,
             )

         elif state.processor_config.conversation.openai_model:
@@ -136,6 +138,8 @@ def generate_chat_response(
                 api_key=api_key,
                 completion_func=partial_completion,
                 conversation_command=conversation_command,
+                max_prompt_size=state.processor_config.conversation.max_prompt_size,
+                tokenizer_name=state.processor_config.conversation.tokenizer,
             )

     except Exception as e:

@@ -95,6 +95,8 @@ class ConversationProcessorConfigModel:
         self.openai_model = conversation_config.openai
         self.gpt4all_model = GPT4AllProcessorConfig()
         self.offline_chat = conversation_config.offline_chat
+        self.max_prompt_size = conversation_config.max_prompt_size
+        self.tokenizer = conversation_config.tokenizer
         self.conversation_logfile = Path(conversation_config.conversation_logfile)
         self.chat_session: List[str] = []
         self.meta_log: dict = {}
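The two new attributes are read straight off the parsed conversation section of the user's configuration. A minimal stand-in for what that object needs to expose is sketched below; the real conversation_config in khoj is a parsed config model rather than a SimpleNamespace, and the values here are purely illustrative.

from types import SimpleNamespace

conversation_config = SimpleNamespace(
    openai=None,                                                        # no OpenAI config, offline chat only
    offline_chat=SimpleNamespace(chat_model="llama-2-7b-chat.ggmlv3.q4_0.bin"),
    conversation_logfile="conversation_logs.json",
    max_prompt_size=1548,                                               # new: token budget for context stuffing
    tokenizer="hf-internal-testing/llama-tokenizer",                    # new: HF tokenizer used to count tokens
)

# These are the attribute reads the two added lines perform
print(conversation_config.max_prompt_size, conversation_config.tokenizer)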