Use llama.cpp for offline chat models
- Benefits of moving from gpt4all to llama-cpp-python:
  - Support for all GGUF format chat models
  - Support for AMD, Nvidia, Mac and Vulkan GPU machines (instead of just Vulkan and Mac)
  - Support for models with more capabilities like tools, schema
    enforcement, speculative decoding, image gen etc.
- Upgrade default chat model, prompt size and tokenizer for the newly
  supported chat models
- Load offline chat model from disk when present, without requiring internet
- Load model onto GPU if not disabled and the device has a GPU
- Fall back to loading the model onto CPU if loading it onto GPU fails
  (see the sketch below)
- Create helper function to check for and load the model from disk, when a
  file matching the model glob is present on disk.
  `Llama.from_pretrained` needs internet to get repo info from
  HuggingFace. This isn't required if the model is already downloaded.
  Didn't find any existing HuggingFace or llama.cpp method that looks for a
  model glob on disk without internet
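
A minimal sketch of that GPU-first, CPU-fallback load with llama-cpp-python, assuming llama-cpp-python and huggingface-hub are installed; the `load_chat_model` helper name is illustrative, not part of this commit:

```python
from llama_cpp import Llama


def load_chat_model(repo_id="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", filename="*Q4_K_M.gguf"):
    # n_gpu_layers=-1 asks llama.cpp to offload all model layers to the GPU
    kwargs = {"n_ctx": 4096, "verbose": False, "n_gpu_layers": -1}
    try:
        # Downloads the GGUF file from HuggingFace on first use, then reuses the local cache
        return Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
    except Exception:
        # Fall back to CPU when GPU offload fails (no GPU, or not enough VRAM)
        kwargs["n_gpu_layers"] = 0
        return Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
```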
@@ -160,7 +160,7 @@ class ChatModelOptions(BaseModel):
     max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
     tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
-    chat_model = models.CharField(max_length=200, default="mistral-7b-instruct-v0.1.Q4_0.gguf")
+    chat_model = models.CharField(max_length=200, default="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF")
     model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
@@ -1,12 +1,13 @@
 import logging
 from collections import deque
 from datetime import datetime
 from threading import Thread
 from typing import Any, Iterator, List, Union
 
 from langchain.schema import ChatMessage
+from llama_cpp import Llama
 
 from khoj.processor.conversation import prompts
+from khoj.processor.conversation.offline.utils import download_model
 from khoj.processor.conversation.utils import (
     ThreadedGenerator,
     generate_chatml_messages_with_context,
@@ -21,7 +22,7 @@ logger = logging.getLogger(__name__)
 
 def extract_questions_offline(
     text: str,
-    model: str = "mistral-7b-instruct-v0.1.Q4_0.gguf",
+    model: str = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF",
     loaded_model: Union[Any, None] = None,
     conversation_log={},
     use_history: bool = True,
@@ -31,22 +32,14 @@ def extract_questions_offline(
     """
     Infer search queries to retrieve relevant notes to answer user query
     """
-    try:
-        from gpt4all import GPT4All
-    except ModuleNotFoundError as e:
-        logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
-        raise e
-
-    # Assert that loaded_model is either None or of type GPT4All
-    assert loaded_model is None or isinstance(loaded_model, GPT4All), "loaded_model must be of type GPT4All or None"
-
     all_questions = text.split("? ")
     all_questions = [q + "?" for q in all_questions[:-1]] + [all_questions[-1]]
 
     if not should_extract_questions:
         return all_questions
 
-    gpt4all_model = loaded_model or GPT4All(model)
+    assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
+    offline_chat_model = loaded_model or download_model(model)
 
     location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
@@ -75,10 +68,12 @@ def extract_questions_offline(
         next_christmas_date=next_christmas_date,
         location=location,
     )
-    message = system_prompt + example_questions
+    messages = generate_chatml_messages_with_context(example_questions, system_message=system_prompt, model_name=model)
 
     state.chat_lock.acquire()
     try:
-        response = gpt4all_model.generate(message, max_tokens=200, top_k=2, temp=0, n_batch=512)
+        response = offline_chat_model.create_chat_completion(messages, max_tokens=200, top_k=2, temp=0)
+        response = response[0]["choices"][0]["message"]["content"]
     finally:
         state.chat_lock.release()
@@ -133,7 +128,7 @@ def converse_offline(
     references=[],
     online_results=[],
     conversation_log={},
-    model: str = "mistral-7b-instruct-v0.1.Q4_0.gguf",
+    model: str = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF",
     loaded_model: Union[Any, None] = None,
     completion_func=None,
     conversation_commands=[ConversationCommand.Default],
@@ -145,15 +140,9 @@ def converse_offline(
     """
     Converse with user using Llama
     """
-    try:
-        from gpt4all import GPT4All
-    except ModuleNotFoundError as e:
-        logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
-        raise e
-
-    assert loaded_model is None or isinstance(loaded_model, GPT4All), "loaded_model must be of type GPT4All or None"
-    gpt4all_model = loaded_model or GPT4All(model)
     # Initialize Variables
+    assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
+    offline_chat_model = loaded_model or download_model(model)
     compiled_references_message = "\n\n".join({f"{item}" for item in references})
 
     conversation_primer = prompts.query_prompt.format(query=user_query)
@@ -191,72 +180,44 @@ def converse_offline(
         prompts.system_prompt_message_gpt4all.format(current_date=current_date),
         conversation_log,
         model_name=model,
+        loaded_model=offline_chat_model,
         max_prompt_size=max_prompt_size,
         tokenizer_name=tokenizer_name,
     )
 
     g = ThreadedGenerator(references, online_results, completion_func=completion_func)
-    t = Thread(target=llm_thread, args=(g, messages, gpt4all_model))
+    t = Thread(target=llm_thread, args=(g, messages, offline_chat_model))
     t.start()
     return g
 
 
 def llm_thread(g, messages: List[ChatMessage], model: Any):
-    user_message = messages[-1]
-    system_message = messages[0]
-    conversation_history = messages[1:-1]
-
-    formatted_messages = [
-        prompts.khoj_message_gpt4all.format(message=message.content)
-        if message.role == "assistant"
-        else prompts.user_message_gpt4all.format(message=message.content)
-        for message in conversation_history
-    ]
-
     stop_phrases = ["<s>", "INST]", "Notes:"]
-    chat_history = "".join(formatted_messages)
-    templated_system_message = prompts.system_prompt_gpt4all.format(message=system_message.content)
-    templated_user_message = prompts.user_message_gpt4all.format(message=user_message.content)
-    prompted_message = templated_system_message + chat_history + templated_user_message
-    response_queue: deque[str] = deque(maxlen=3)  # Create a response queue with a maximum length of 3
-    hit_stop_phrase = False
 
     state.chat_lock.acquire()
-    response_iterator = send_message_to_model_offline(prompted_message, loaded_model=model, streaming=True)
     try:
+        response_iterator = send_message_to_model_offline(
+            messages, loaded_model=model, stop=stop_phrases, streaming=True
+        )
         for response in response_iterator:
-            response_queue.append(response)
-            hit_stop_phrase = any(stop_phrase in "".join(response_queue) for stop_phrase in stop_phrases)
-            if hit_stop_phrase:
-                logger.debug(f"Stop response as hit stop phrase: {''.join(response_queue)}")
-                break
-            # Start streaming the response at a lag once the queue is full
-            # This allows stop word testing before sending the response
-            if len(response_queue) == response_queue.maxlen:
-                g.send(response_queue[0])
+            g.send(response["choices"][0]["delta"].get("content", ""))
     finally:
-        if not hit_stop_phrase:
-            if len(response_queue) == response_queue.maxlen:
-                # remove already sent reponse chunk
-                response_queue.popleft()
-            # send the remaining response
-            g.send("".join(response_queue))
         state.chat_lock.release()
         g.close()
 
 
 def send_message_to_model_offline(
-    message, loaded_model=None, model="mistral-7b-instruct-v0.1.Q4_0.gguf", streaming=False, system_message=""
-) -> str:
-    try:
-        from gpt4all import GPT4All
-    except ModuleNotFoundError as e:
-        logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
-        raise e
-
-    assert loaded_model is None or isinstance(loaded_model, GPT4All), "loaded_model must be of type GPT4All or None"
-    gpt4all_model = loaded_model or GPT4All(model)
-
-    return gpt4all_model.generate(
-        system_message + message, max_tokens=200, top_k=2, temp=0, n_batch=512, streaming=streaming
-    )
+    messages: List[ChatMessage],
+    loaded_model=None,
+    model="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF",
+    streaming=False,
+    stop=[],
+):
+    assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
+    offline_chat_model = loaded_model or download_model(model)
+    messages_dict = [{"role": message.role, "content": message.content} for message in messages]
+    response = offline_chat_model.create_chat_completion(messages_dict, stop=stop, stream=streaming)
+    if streaming:
+        return response
+    else:
+        return response["choices"][0]["message"].get("content", "")
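
For reference, llama-cpp-python's chat API emits OpenAI-style delta chunks, which is what `llm_thread` now consumes; and since llama.cpp enforces stop phrases natively, the old lag-queue for stop-word testing could be deleted outright. A minimal sketch of the call shape, with the message contents illustrative:

```python
from llama_cpp import Llama

# Load the commit's default chat model (cached locally after the first download)
llm = Llama.from_pretrained(
    repo_id="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", filename="*Q4_K_M.gguf", verbose=False
)
messages = [
    {"role": "system", "content": "You are Khoj, a helpful personal assistant."},  # illustrative
    {"role": "user", "content": "What is a GGUF file?"},
]
# llama.cpp applies the stop phrases itself; each chunk carries an incremental delta
for chunk in llm.create_chat_completion(messages, stop=["<s>", "INST]", "Notes:"], stream=True):
    print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)
```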
@@ -1,43 +1,53 @@
+import glob
 import logging
+import os
 
+from huggingface_hub.constants import HF_HUB_CACHE
 
 from khoj.utils import state
 
 logger = logging.getLogger(__name__)
 
 
-def download_model(model_name: str):
-    try:
-        import gpt4all
-    except ModuleNotFoundError as e:
-        logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
-        raise e
+def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf"):
+    from llama_cpp.llama import Llama
 
-    # Download the chat model and its config
-    chat_model_config = None
+    # Initialize Model Parameters
+    kwargs = {"n_threads": 4, "n_ctx": 4096, "verbose": False}
+
+    # Decide whether to load model to GPU or CPU
+    device = "gpu" if state.chat_on_gpu and state.device != "cpu" else "cpu"
+    kwargs["n_gpu_layers"] = -1 if device == "gpu" else 0
+
+    # Check if the model is already downloaded
+    model_path = load_model_from_cache(repo_id, filename)
     try:
-        chat_model_config = gpt4all.GPT4All.retrieve_model(model_name=model_name, allow_download=True)
-
-        # Try load chat model to GPU if:
-        # 1. Loading chat model to GPU isn't disabled via CLI and
-        # 2. Machine has GPU
-        # 3. GPU has enough free memory to load the chat model with max context length of 4096
-        device = (
-            "gpu"
-            if state.chat_on_gpu and gpt4all.pyllmodel.LLModel().list_gpu(chat_model_config["path"], 4096)
-            else "cpu"
-        )
-    except ValueError:
-        device = "cpu"
-    except Exception as e:
-        if chat_model_config is None:
-            device = "cpu"  # Fallback to CPU as can't determine if GPU has enough memory
-            logger.debug(f"Unable to download model config from gpt4all website: {e}")
-        else:
-            raise e
+        if model_path:
+            chat_model = Llama(model_path, **kwargs)
+        else:
+            chat_model = Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
+    except:
+        # Load model on CPU if GPU is not available
+        kwargs["n_gpu_layers"], device = 0, "cpu"
+        if model_path:
+            chat_model = Llama(model_path, **kwargs)
+        else:
+            chat_model = Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
 
-    # Now load the downloaded chat model onto appropriate device
-    chat_model = gpt4all.GPT4All(model_name=model_name, n_ctx=4096, device=device, allow_download=False)
-    logger.debug(f"Loaded chat model to {device.upper()}.")
+    logger.debug(f"{'Loaded' if model_path else 'Downloaded'} chat model to {device.upper()}")
 
     return chat_model
+
+
+def load_model_from_cache(repo_id: str, filename: str, repo_type="models"):
+    # Construct the path to the model file in the cache directory
+    repo_org, repo_name = repo_id.split("/")
+    object_id = "--".join([repo_type, repo_org, repo_name])
+    model_path = os.path.sep.join([HF_HUB_CACHE, object_id, "snapshots", "**", filename])
+
+    # Check if the model file exists
+    paths = glob.glob(model_path)
+    if paths:
+        return paths[0]
+    else:
+        return None
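
The glob in `load_model_from_cache` mirrors huggingface_hub's on-disk cache layout (`models--<org>--<repo>/snapshots/<revision>/<file>`), which is what lets a previously downloaded model load without internet. A hedged usage sketch; the resolved path in the comment is illustrative:

```python
from llama_cpp import Llama

# With the model already on disk, no network round-trip is needed
model_path = load_model_from_cache("NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", "*Q4_K_M.gguf")
# e.g. ~/.cache/huggingface/hub/models--NousResearch--Hermes-2-Pro-Mistral-7B-GGUF/
#      snapshots/<revision>/<model>.Q4_K_M.gguf
if model_path:
    chat_model = Llama(model_path, n_gpu_layers=-1, n_ctx=4096, verbose=False)
```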
@@ -3,29 +3,28 @@ import logging
 import queue
 from datetime import datetime
 from time import perf_counter
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
 import tiktoken
 from langchain.schema import ChatMessage
+from llama_cpp.llama import Llama
 from transformers import AutoTokenizer
 
 from khoj.database.adapters import ConversationAdapters
 from khoj.database.models import ClientApplication, KhojUser
+from khoj.processor.conversation.offline.utils import download_model
 from khoj.utils.helpers import is_none_or_empty, merge_dicts
 
 logger = logging.getLogger(__name__)
 model_to_prompt_size = {
-    "gpt-3.5-turbo": 3000,
-    "gpt-3.5-turbo-0125": 3000,
-    "gpt-4-0125-preview": 7000,
-    "gpt-4-turbo-preview": 7000,
-    "llama-2-7b-chat.ggmlv3.q4_0.bin": 1548,
-    "mistral-7b-instruct-v0.1.Q4_0.gguf": 1548,
-}
-model_to_tokenizer = {
-    "llama-2-7b-chat.ggmlv3.q4_0.bin": "hf-internal-testing/llama-tokenizer",
-    "mistral-7b-instruct-v0.1.Q4_0.gguf": "mistralai/Mistral-7B-Instruct-v0.1",
+    "gpt-3.5-turbo": 12000,
+    "gpt-3.5-turbo-0125": 12000,
+    "gpt-4-0125-preview": 20000,
+    "gpt-4-turbo-preview": 20000,
+    "TheBloke/Mistral-7B-Instruct-v0.2-GGUF": 3500,
+    "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF": 3500,
 }
+model_to_tokenizer: Dict[str, str] = {}
 
 
 class ThreadedGenerator:
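
These larger prompt-size caps are what `truncate_messages` enforces below; a hedged sketch of resolving the cap for the configured model, where the 2000-token fallback is hypothetical rather than from this commit:

```python
# Look up the prompt-size cap for the configured chat model; fall back to a
# conservative (hypothetical) default for models missing from the table
model_name = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"
max_prompt_size = model_to_prompt_size.get(model_name, 2000)
```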
@@ -134,9 +133,10 @@ Khoj: "{inferred_queries if ("text-to-image" in intent_type) else chat_response}
 
 def generate_chatml_messages_with_context(
     user_message,
-    system_message,
+    system_message=None,
     conversation_log={},
     model_name="gpt-3.5-turbo",
+    loaded_model: Optional[Llama] = None,
     max_prompt_size=None,
     tokenizer_name=None,
 ):
@@ -159,7 +159,7 @@ def generate_chatml_messages_with_context(
         chat_notes = f'\n\n Notes:\n{chat.get("context")}' if chat.get("context") else "\n"
         chat_logs += [chat["message"] + chat_notes]
 
-    rest_backnforths = []
+    rest_backnforths: List[ChatMessage] = []
     # Extract in reverse chronological order
     for user_msg, assistant_msg in zip(chat_logs[-2::-2], chat_logs[::-2]):
         if len(rest_backnforths) >= 2 * lookback_turns:
@@ -176,22 +176,31 @@ def generate_chatml_messages_with_context(
         messages.append(ChatMessage(content=system_message, role="system"))
 
     # Truncate oldest messages from conversation history until under max supported prompt size by model
-    messages = truncate_messages(messages, max_prompt_size, model_name, tokenizer_name)
+    messages = truncate_messages(messages, max_prompt_size, model_name, loaded_model, tokenizer_name)
 
     # Return message in chronological order
     return messages[::-1]
 
 
 def truncate_messages(
-    messages: list[ChatMessage], max_prompt_size, model_name: str, tokenizer_name=None
+    messages: list[ChatMessage],
+    max_prompt_size,
+    model_name: str,
+    loaded_model: Optional[Llama] = None,
+    tokenizer_name=None,
 ) -> list[ChatMessage]:
     """Truncate messages to fit within max prompt size supported by model"""
 
     try:
-        if model_name.startswith("gpt-"):
+        if loaded_model:
+            encoder = loaded_model.tokenizer()
+        elif model_name.startswith("gpt-"):
             encoder = tiktoken.encoding_for_model(model_name)
         else:
-            encoder = AutoTokenizer.from_pretrained(tokenizer_name or model_to_tokenizer[model_name])
+            try:
+                encoder = download_model(model_name).tokenizer()
+            except:
+                encoder = AutoTokenizer.from_pretrained(tokenizer_name or model_to_tokenizer[model_name])
     except:
         default_tokenizer = "hf-internal-testing/llama-tokenizer"
         encoder = AutoTokenizer.from_pretrained(default_tokenizer)
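
With this cascade, token counts for offline models come from the loaded llama.cpp model's own tokenizer instead of a per-model lookup table, so truncation matches what the model actually sees. A minimal sketch of the counting step; `count_tokens` is a hypothetical helper, not from this commit:

```python
from typing import Optional

import tiktoken
from llama_cpp import Llama
from transformers import AutoTokenizer


def count_tokens(text: str, model_name: str, loaded_model: Optional[Llama] = None) -> int:
    # Prefer the tokenizer of the already-loaded llama.cpp model
    if loaded_model:
        return len(loaded_model.tokenizer().encode(text))
    # OpenAI models are counted with tiktoken
    if model_name.startswith("gpt-"):
        return len(tiktoken.encoding_for_model(model_name).encode(text))
    # Otherwise fall back to the matching HuggingFace tokenizer
    return len(AutoTokenizer.from_pretrained(model_name).encode(text))
```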
@@ -370,27 +370,31 @@ async def send_message_to_model_wrapper(
     if conversation_config is None:
         raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
 
-    truncated_messages = generate_chatml_messages_with_context(
-        user_message=message, system_message=system_message, model_name=conversation_config.chat_model
-    )
+    chat_model = conversation_config.chat_model
 
     if conversation_config.model_type == "offline":
         if state.gpt4all_processor_config is None or state.gpt4all_processor_config.loaded_model is None:
-            state.gpt4all_processor_config = GPT4AllProcessorModel(conversation_config.chat_model)
+            state.gpt4all_processor_config = GPT4AllProcessorModel(chat_model)
 
         loaded_model = state.gpt4all_processor_config.loaded_model
+        truncated_messages = generate_chatml_messages_with_context(
+            user_message=message, system_message=system_message, model_name=chat_model, loaded_model=loaded_model
+        )
 
         return send_message_to_model_offline(
-            message=truncated_messages[-1].content,
+            messages=truncated_messages,
             loaded_model=loaded_model,
-            model=conversation_config.chat_model,
+            model=chat_model,
             streaming=False,
-            system_message=truncated_messages[0].content,
         )
 
     elif conversation_config.model_type == "openai":
         openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
         api_key = openai_chat_config.api_key
         chat_model = conversation_config.chat_model
+        truncated_messages = generate_chatml_messages_with_context(
+            user_message=message, system_message=system_message, model_name=chat_model
+        )
 
         openai_response = send_message_to_model(
             messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
         )
@@ -75,10 +75,7 @@ class GPT4AllProcessorConfig:
 
 
 class GPT4AllProcessorModel:
-    def __init__(
-        self,
-        chat_model: str = "mistral-7b-instruct-v0.1.Q4_0.gguf",
-    ):
+    def __init__(self, chat_model: str = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"):
         self.chat_model = chat_model
         self.loaded_model = None
         try:
@@ -6,7 +6,7 @@ empty_escape_sequences = "\n|\r|\t| "
 app_env_filepath = "~/.khoj/env"
 telemetry_server = "https://khoj.beta.haletic.com/v1/telemetry"
 content_directory = "~/.khoj/content/"
-default_offline_chat_model = "mistral-7b-instruct-v0.1.Q4_0.gguf"
+default_offline_chat_model = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"
 default_online_chat_model = "gpt-4-turbo-preview"
 
 empty_config = {
@@ -91,7 +91,7 @@ class OpenAIProcessorConfig(ConfigBase):
 
 class OfflineChatProcessorConfig(ConfigBase):
     enable_offline_chat: Optional[bool] = False
-    chat_model: Optional[str] = "mistral-7b-instruct-v0.1.Q4_0.gguf"
+    chat_model: Optional[str] = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"
 
 
 class ConversationProcessorConfig(ConfigBase):