Improve Offline Chat Model Experience (#494)

- Make the offline chat model user configurable. Use the `filename` of any [GPT4All supported model](https://github.com/nomic-ai/gpt4all/blob/main/gpt4all-chat/metadata/models.json) like below
- Run the GPT4All chat model on GPU, when available, via [GPT4All Vulkan support](https://blog.nomic.ai/posts/gpt4all-gpu-inference-with-vulkan)
- Use the Llama 2 model supported by GPT4All as the default
- Make the `tokenizer` and `max-prompt-size` of the chat model user configurable, e.g. when using chat models not in [this pre-defined list](https://github.com/khoj-ai/khoj/blob/master/src/khoj/processor/conversation/utils.py) that support a larger context window or use a different tokenizer
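A minimal sketch of the relevant `processor.conversation` section of `khoj.yml` under the new schema (derived from the migration docstring and config classes in the diffs below; the model filename shown is the default, and any other GPT4All model filename should work the same way):

```yaml
processor:
  conversation:
    offline-chat:
      enable-offline-chat: true
      chat-model: llama-2-7b-chat.ggmlv3.q4_0.bin
    tokenizer: null        # e.g. a Hugging Face tokenizer name for unsupported models
    max-prompt-size: null  # e.g. 4096 for models with a larger context window
```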

Closes #406, #418
Debanjum committed 2023-10-16 17:44:49 -07:00 (committed by GitHub)
16 changed files with 230 additions and 141 deletions

View File

@@ -19,7 +19,7 @@ from khoj.utils.config import (
 )
 from khoj.utils.helpers import resolve_absolute_path, merge_dicts
 from khoj.utils.fs_syncer import collect_files
-from khoj.utils.rawconfig import FullConfig, ProcessorConfig, ConversationProcessorConfig
+from khoj.utils.rawconfig import FullConfig, OfflineChatProcessorConfig, ProcessorConfig, ConversationProcessorConfig
 from khoj.routers.indexer import configure_content, load_content, configure_search
@@ -168,9 +168,7 @@ def configure_conversation_processor(
             conversation_config=ConversationProcessorConfig(
                 conversation_logfile=conversation_logfile,
                 openai=(conversation_config.openai if (conversation_config is not None) else None),
-                enable_offline_chat=(
-                    conversation_config.enable_offline_chat if (conversation_config is not None) else False
-                ),
+                offline_chat=conversation_config.offline_chat if conversation_config else OfflineChatProcessorConfig(),
             )
         )
     else:

View File

@@ -236,7 +236,7 @@
         </h3>
     </div>
     <div class="card-description-row">
-        <p class="card-description">Setup chat using OpenAI</p>
+        <p class="card-description">Setup online chat using OpenAI</p>
     </div>
     <div class="card-action-row">
         <a class="card-button" href="/config/processor/conversation/openai">
@@ -261,21 +261,21 @@
     <img class="card-icon" src="/static/assets/icons/chat.svg" alt="Chat">
     <h3 class="card-title">
         Offline Chat
-        <img id="configured-icon-conversation-enable-offline-chat" class="configured-icon {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat and current_model_state.conversation_gpt4all %}enabled{% else %}disabled{% endif %}" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
-        {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat and not current_model_state.conversation_gpt4all %}
+        <img id="configured-icon-conversation-enable-offline-chat" class="configured-icon {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.offline_chat.enable_offline_chat and current_model_state.conversation_gpt4all %}enabled{% else %}disabled{% endif %}" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
+        {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.offline_chat.enable_offline_chat and not current_model_state.conversation_gpt4all %}
        <img id="misconfigured-icon-conversation-enable-offline-chat" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="The model was not downloaded as expected.">
         {% endif %}
     </h3>
 </div>
 <div class="card-description-row">
-    <p class="card-description">Setup offline chat (Llama V2)</p>
+    <p class="card-description">Setup offline chat</p>
 </div>
-<div id="clear-enable-offline-chat" class="card-action-row {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat %}enabled{% else %}disabled{% endif %}">
+<div id="clear-enable-offline-chat" class="card-action-row {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.offline_chat.enable_offline_chat %}enabled{% else %}disabled{% endif %}">
     <button class="card-button" onclick="toggleEnableLocalLLLM(false)">
         Disable
     </button>
 </div>
-<div id="set-enable-offline-chat" class="card-action-row {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat %}disabled{% else %}enabled{% endif %}">
+<div id="set-enable-offline-chat" class="card-action-row {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.offline_chat.enable_offline_chat %}disabled{% else %}enabled{% endif %}">
     <button class="card-button happy" onclick="toggleEnableLocalLLLM(true)">
         Enable
     </button>
@@ -346,7 +346,7 @@
         featuresHintText.classList.add("show");
     }
-    fetch('/api/config/data/processor/conversation/enable_offline_chat' + '?enable_offline_chat=' + enable, {
+    fetch('/api/config/data/processor/conversation/offline_chat' + '?enable_offline_chat=' + enable, {
         method: 'POST',
         headers: {
             'Content-Type': 'application/json',

View File

@@ -0,0 +1,83 @@
"""
Current format of khoj.yml
---
app:
...
content-type:
...
processor:
conversation:
enable-offline-chat: false
conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json
openai:
...
search-type:
...
New format of khoj.yml
---
app:
...
content-type:
...
processor:
conversation:
offline-chat:
enable-offline-chat: false
chat-model: llama-2-7b-chat.ggmlv3.q4_0.bin
tokenizer: null
max_prompt_size: null
conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json
openai:
...
search-type:
...
"""
import logging
from packaging import version
from khoj.utils.yaml import load_config_from_file, save_config_to_file
logger = logging.getLogger(__name__)
def migrate_offline_chat_schema(args):
schema_version = "0.12.3"
raw_config = load_config_from_file(args.config_file)
previous_version = raw_config.get("version")
if "processor" not in raw_config:
return args
if raw_config["processor"] is None:
return args
if "conversation" not in raw_config["processor"]:
return args
if previous_version is None or version.parse(previous_version) < version.parse("0.12.3"):
logger.info(
f"Upgrading config schema to {schema_version} from {previous_version} to make (offline) chat more configuration"
)
raw_config["version"] = schema_version
# Create max-prompt-size field in conversation processor schema
raw_config["processor"]["conversation"]["max-prompt-size"] = None
raw_config["processor"]["conversation"]["tokenizer"] = None
# Create offline chat schema based on existing enable_offline_chat field in khoj config schema
offline_chat_model = (
raw_config["processor"]["conversation"]
.get("offline-chat", {})
.get("chat-model", "llama-2-7b-chat.ggmlv3.q4_0.bin")
)
raw_config["processor"]["conversation"]["offline-chat"] = {
"enable-offline-chat": raw_config["processor"]["conversation"].get("enable-offline-chat", False),
"chat-model": offline_chat_model,
}
# Delete old enable-offline-chat field from conversation processor schema
if "enable-offline-chat" in raw_config["processor"]["conversation"]:
del raw_config["processor"]["conversation"]["enable-offline-chat"]
save_config_to_file(raw_config, args.config_file)
return args
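This migration is wired into `run_migrations` in cli.py (see the last diff below), so it runs automatically on startup. A hypothetical standalone invocation, for illustration only — the migration reads nothing from `args` beyond a `config_file` attribute, and the default config path is an assumption:

```python
from pathlib import Path
from types import SimpleNamespace

from khoj.migrations.migrate_offline_chat_schema import migrate_offline_chat_schema

# args only needs a `config_file` attribute pointing at the YAML config
args = SimpleNamespace(config_file=Path.home() / ".khoj" / "khoj.yml")
migrate_offline_chat_schema(args)
```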

View File

@@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)
 def extract_questions_offline(
     text: str,
-    model: str = "llama-2-7b-chat.ggmlv3.q4_K_S.bin",
+    model: str = "llama-2-7b-chat.ggmlv3.q4_0.bin",
     loaded_model: Union[Any, None] = None,
     conversation_log={},
     use_history: bool = True,
@@ -113,7 +113,7 @@ def filter_questions(questions: List[str]):
     ]
     filtered_questions = []
     for q in questions:
-        if not any([word in q.lower() for word in hint_words]):
+        if not any([word in q.lower() for word in hint_words]) and not is_none_or_empty(q):
             filtered_questions.append(q)
     return filtered_questions
@@ -123,10 +123,12 @@ def converse_offline(
     references,
     user_query,
     conversation_log={},
-    model: str = "llama-2-7b-chat.ggmlv3.q4_K_S.bin",
+    model: str = "llama-2-7b-chat.ggmlv3.q4_0.bin",
     loaded_model: Union[Any, None] = None,
     completion_func=None,
     conversation_command=ConversationCommand.Default,
+    max_prompt_size=None,
+    tokenizer_name=None,
 ) -> Union[ThreadedGenerator, Iterator[str]]:
     """
     Converse with user using Llama
@@ -158,6 +160,8 @@ def converse_offline(
         prompts.system_prompt_message_llamav2,
         conversation_log,
         model_name=model,
+        max_prompt_size=max_prompt_size,
+        tokenizer_name=tokenizer_name,
     )
     g = ThreadedGenerator(references, completion_func=completion_func)

View File

@@ -1,3 +0,0 @@
model_name_to_url = {
    "llama-2-7b-chat.ggmlv3.q4_K_S.bin": "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/resolve/main/llama-2-7b-chat.ggmlv3.q4_K_S.bin"
}

View File

@@ -1,24 +1,8 @@
-import os
 import logging
-import requests
-import hashlib
-from tqdm import tqdm
-
-from khoj.processor.conversation.gpt4all import model_metadata

 logger = logging.getLogger(__name__)

-expected_checksum = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "cfa87b15d92fb15a2d7c354b0098578b"}
-
-
-def get_md5_checksum(filename: str):
-    hash_md5 = hashlib.md5()
-    with open(filename, "rb") as f:
-        for chunk in iter(lambda: f.read(8192), b""):
-            hash_md5.update(chunk)
-    return hash_md5.hexdigest()
-

 def download_model(model_name: str):
     try:
@@ -27,57 +11,12 @@ def download_model(model_name: str):
         logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
         raise e

-    url = model_metadata.model_name_to_url.get(model_name)
-    model_path = os.path.expanduser(f"~/.cache/gpt4all/")
-    if not url:
-        logger.debug(f"Model {model_name} not found in model metadata. Skipping download.")
-        return GPT4All(model_name=model_name, model_path=model_path)
-
-    filename = os.path.expanduser(f"~/.cache/gpt4all/{model_name}")
-    if os.path.exists(filename):
-        # Check if the user is connected to the internet
-        try:
-            requests.get("https://www.google.com/", timeout=5)
-        except:
-            logger.debug("User is offline. Disabling allowed download flag")
-            return GPT4All(model_name=model_name, model_path=model_path, allow_download=False)
-        return GPT4All(model_name=model_name, model_path=model_path)
-
-    # Download the model to a tmp file. Once the download is completed, move the tmp file to the actual file
-    tmp_filename = filename + ".tmp"
-    try:
-        os.makedirs(os.path.dirname(tmp_filename), exist_ok=True)
-        logger.debug(f"Downloading model {model_name} from {url} to {filename}...")
-        with requests.get(url, stream=True) as r:
-            r.raise_for_status()
-            total_size = int(r.headers.get("content-length", 0))
-            with open(tmp_filename, "wb") as f, tqdm(
-                unit="B",  # unit string to be displayed.
-                unit_scale=True,  # let tqdm to determine the scale in kilo, mega..etc.
-                unit_divisor=1024,  # is used when unit_scale is true
-                total=total_size,  # the total iteration.
-                desc=model_name,  # prefix to be displayed on progress bar.
-            ) as progress_bar:
-                for chunk in r.iter_content(chunk_size=8192):
-                    f.write(chunk)
-                    progress_bar.update(len(chunk))
-
-        # Verify the checksum
-        if expected_checksum.get(model_name) != get_md5_checksum(tmp_filename):
-            logger.error(
-                f"Checksum verification failed for {filename}. Removing the tmp file. Offline model will not be available."
-            )
-            os.remove(tmp_filename)
-            raise ValueError(f"Checksum verification failed for downloading {model_name} from {url}.")
-
-        # Move the tmp file to the actual file
-        os.rename(tmp_filename, filename)
-        logger.debug(f"Successfully downloaded model {model_name} from {url} to {filename}")
-        return GPT4All(model_name)
-    except Exception as e:
-        logger.error(f"Failed to download model {model_name} from {url} to {filename}. Error: {e}", exc_info=True)
-        # Remove the tmp file if it exists
-        if os.path.exists(tmp_filename):
-            os.remove(tmp_filename)
-        return None
+    # Use GPU for Chat Model, if available
+    try:
+        model = GPT4All(model_name=model_name, device="gpu")
+        logger.debug("Loaded chat model to GPU.")
+    except ValueError:
+        model = GPT4All(model_name=model_name)
+        logger.debug("Loaded chat model to CPU.")
+
+    return model
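The new loader delegates download and caching to the GPT4All library and only handles device selection: per the diff above, `device="gpu"` raises a `ValueError` when no usable GPU is found, which is caught to fall back to CPU. A minimal standalone sketch of the same pattern, assuming the `gpt4all` package is installed and the model file is cached or downloadable:

```python
from gpt4all import GPT4All

model_name = "llama-2-7b-chat.ggmlv3.q4_0.bin"
try:
    # Vulkan-backed GPU inference, when a supported GPU is available
    model = GPT4All(model_name=model_name, device="gpu")
except ValueError:
    # No usable GPU found; fall back to CPU inference
    model = GPT4All(model_name=model_name)

print(model.generate("What is Khoj?", max_tokens=64))
```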

View File

@@ -116,6 +116,8 @@ def converse(
     temperature: float = 0.2,
     completion_func=None,
     conversation_command=ConversationCommand.Default,
+    max_prompt_size=None,
+    tokenizer_name=None,
 ):
     """
     Converse with user using OpenAI's ChatGPT
@@ -141,6 +143,8 @@ def converse(
         prompts.personality.format(),
         conversation_log,
         model,
+        max_prompt_size,
+        tokenizer_name,
     )
     truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages})
     logger.debug(f"Conversation Context for GPT: {truncated_messages}")

View File

@@ -23,7 +23,7 @@ no_notes_found = PromptTemplate.from_template(
     """.strip()
 )

-system_prompt_message_llamav2 = f"""You are Khoj, a friendly, smart and helpful personal assistant.
+system_prompt_message_llamav2 = f"""You are Khoj, a smart, inquisitive and helpful personal assistant.
 Using your general knowledge and our past conversations as context, answer the following question.
 If you do not know the answer, say 'I don't know.'"""
@@ -51,13 +51,13 @@ extract_questions_system_prompt_llamav2 = PromptTemplate.from_template(
 general_conversation_llamav2 = PromptTemplate.from_template(
     """
-<s>[INST]{query}[/INST]
+<s>[INST] {query} [/INST]
 """.strip()
 )

 chat_history_llamav2_from_user = PromptTemplate.from_template(
     """
-<s>[INST]{message}[/INST]
+<s>[INST] {message} [/INST]
 """.strip()
 )
@@ -69,7 +69,7 @@ chat_history_llamav2_from_assistant = PromptTemplate.from_template(
 conversation_llamav2 = PromptTemplate.from_template(
     """
-<s>[INST]{query}[/INST]
+<s>[INST] {query} [/INST]
 """.strip()
 )
@@ -91,7 +91,7 @@ Question: {query}
 notes_conversation_llamav2 = PromptTemplate.from_template(
     """
-Notes:
+User's Notes:
 {references}

 Question: {query}
 """.strip()
@@ -134,19 +134,25 @@ Answer (in second person):"""
 extract_questions_llamav2_sample = PromptTemplate.from_template(
     """
-<s>[INST]<<SYS>>Current Date: {current_date}<</SYS>>[/INST]</s>
-<s>[INST]How was my trip to Cambodia?[/INST][]</s>
-<s>[INST]Who did I visit the temple with on that trip?[/INST]Who did I visit the temple with in Cambodia?</s>
-<s>[INST]How should I take care of my plants?[/INST]What kind of plants do I have? What issues do my plants have?</s>
-<s>[INST]How many tennis balls fit in the back of a 2002 Honda Civic?[/INST]What is the size of a tennis ball? What is the trunk size of a 2002 Honda Civic?</s>
-<s>[INST]What did I do for Christmas last year?[/INST]What did I do for Christmas {last_year} dt>='{last_christmas_date}' dt<'{next_christmas_date}'</s>
-<s>[INST]How are you feeling today?[/INST]</s>
-<s>[INST]Is Alice older than Bob?[/INST]When was Alice born? What is Bob's age?</s>
-<s>[INST]<<SYS>>
+<s>[INST] <<SYS>>Current Date: {current_date}<</SYS>> [/INST]</s>
+<s>[INST] How was my trip to Cambodia? [/INST]
+How was my trip to Cambodia?</s>
+<s>[INST] Who did I visit the temple with on that trip? [/INST]
+Who did I visit the temple with in Cambodia?</s>
+<s>[INST] How should I take care of my plants? [/INST]
+What kind of plants do I have? What issues do my plants have?</s>
+<s>[INST] How many tennis balls fit in the back of a 2002 Honda Civic? [/INST]
+What is the size of a tennis ball? What is the trunk size of a 2002 Honda Civic?</s>
+<s>[INST] What did I do for Christmas last year? [/INST]
+What did I do for Christmas {last_year} dt>='{last_christmas_date}' dt<'{next_christmas_date}'</s>
+<s>[INST] How are you feeling today? [/INST]</s>
+<s>[INST] Is Alice older than Bob? [/INST]
+When was Alice born? What is Bob's age?</s>
+<s>[INST] <<SYS>>
 Use these notes from the user's previous conversations to provide a response:
 {chat_history}
-<</SYS>>[/INST]</s>
-<s>[INST]{query}[/INST]
+<</SYS>> [/INST]</s>
+<s>[INST] {query} [/INST]
 """
 )

View File

@@ -3,24 +3,27 @@ import logging
 from time import perf_counter
 import json
 from datetime import datetime
+import queue
 import tiktoken

 # External packages
 from langchain.schema import ChatMessage
-from transformers import LlamaTokenizerFast
+from transformers import AutoTokenizer

 # Internal Packages
-import queue
 from khoj.utils.helpers import merge_dicts

 logger = logging.getLogger(__name__)

-max_prompt_size = {
+model_to_prompt_size = {
     "gpt-3.5-turbo": 4096,
     "gpt-4": 8192,
-    "llama-2-7b-chat.ggmlv3.q4_K_S.bin": 1548,
+    "llama-2-7b-chat.ggmlv3.q4_0.bin": 1548,
     "gpt-3.5-turbo-16k": 15000,
 }
-tokenizer = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "hf-internal-testing/llama-tokenizer"}
+model_to_tokenizer = {
+    "llama-2-7b-chat.ggmlv3.q4_0.bin": "hf-internal-testing/llama-tokenizer",
+}


 class ThreadedGenerator:
@@ -82,9 +85,26 @@ def message_to_log(
 def generate_chatml_messages_with_context(
-    user_message, system_message, conversation_log={}, model_name="gpt-3.5-turbo", lookback_turns=2
+    user_message,
+    system_message,
+    conversation_log={},
+    model_name="gpt-3.5-turbo",
+    max_prompt_size=None,
+    tokenizer_name=None,
 ):
     """Generate messages for ChatGPT with context from previous conversation"""
+    # Set max prompt size from user config, pre-configured for model or to default prompt size
+    try:
+        max_prompt_size = max_prompt_size or model_to_prompt_size[model_name]
+    except:
+        max_prompt_size = 2000
+        logger.warning(
+            f"Fallback to default prompt size: {max_prompt_size}.\nConfigure max_prompt_size for unsupported model: {model_name} in Khoj settings to use a longer context window."
+        )
+
+    # Scale lookback turns proportional to max prompt size supported by model
+    lookback_turns = max_prompt_size // 750
+
     # Extract Chat History for Context
     chat_logs = []
     for chat in conversation_log.get("chat", []):
@@ -105,19 +125,28 @@ def generate_chatml_messages_with_context(
     messages = user_chatml_message + rest_backnforths + system_chatml_message

     # Truncate oldest messages from conversation history until under max supported prompt size by model
-    messages = truncate_messages(messages, max_prompt_size[model_name], model_name)
+    messages = truncate_messages(messages, max_prompt_size, model_name, tokenizer_name)

     # Return message in chronological order
     return messages[::-1]


-def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name) -> list[ChatMessage]:
+def truncate_messages(
+    messages: list[ChatMessage], max_prompt_size, model_name: str, tokenizer_name=None
+) -> list[ChatMessage]:
     """Truncate messages to fit within max prompt size supported by model"""
-    if "llama" in model_name:
-        encoder = LlamaTokenizerFast.from_pretrained(tokenizer[model_name])
-    else:
-        encoder = tiktoken.encoding_for_model(model_name)
+    try:
+        if model_name.startswith("gpt-"):
+            encoder = tiktoken.encoding_for_model(model_name)
+        else:
+            encoder = AutoTokenizer.from_pretrained(tokenizer_name or model_to_tokenizer[model_name])
+    except:
+        default_tokenizer = "hf-internal-testing/llama-tokenizer"
+        encoder = AutoTokenizer.from_pretrained(default_tokenizer)
+        logger.warning(
+            f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
+        )

     system_message = messages.pop()
     system_message_tokens = len(encoder.encode(system_message.content))
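The tokenizer selection above boils down to: tiktoken for OpenAI `gpt-*` models, a Hugging Face tokenizer otherwise, with a hard-coded Llama tokenizer as the last resort. A standalone sketch of that decision, assuming `tiktoken` and `transformers` are installed (the helper name `count_tokens` is illustrative, not part of the diff):

```python
import tiktoken
from transformers import AutoTokenizer


def count_tokens(text: str, model_name: str, tokenizer_name: str = None) -> int:
    # OpenAI models ship their own BPE encodings via tiktoken
    if model_name.startswith("gpt-"):
        encoder = tiktoken.encoding_for_model(model_name)
    else:
        # Offline models use a Hugging Face tokenizer; user config wins over the built-in map
        encoder = AutoTokenizer.from_pretrained(tokenizer_name or "hf-internal-testing/llama-tokenizer")
    return len(encoder.encode(text))


print(count_tokens("Hello from Khoj", "gpt-3.5-turbo"))
print(count_tokens("Hello from Khoj", "llama-2-7b-chat.ggmlv3.q4_0.bin"))
```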

View File

@@ -284,10 +284,11 @@ if not state.demo:
         except Exception as e:
             return {"status": "error", "message": str(e)}

-    @api.post("/config/data/processor/conversation/enable_offline_chat", status_code=200)
+    @api.post("/config/data/processor/conversation/offline_chat", status_code=200)
     async def set_processor_enable_offline_chat_config_data(
         request: Request,
         enable_offline_chat: bool,
+        offline_chat_model: Optional[str] = None,
         client: Optional[str] = None,
     ):
         _initialize_config()
@@ -301,7 +302,9 @@ if not state.demo:
             state.config.processor = ProcessorConfig(conversation=ConversationProcessorConfig(conversation_logfile=conversation_logfile))  # type: ignore

         assert state.config.processor.conversation is not None
-        state.config.processor.conversation.enable_offline_chat = enable_offline_chat
+        state.config.processor.conversation.offline_chat.enable_offline_chat = enable_offline_chat
+        if offline_chat_model is not None:
+            state.config.processor.conversation.offline_chat.chat_model = offline_chat_model
         state.processor_config = configure_processor(state.config.processor, state.processor_config)

         update_telemetry_state(
@@ -713,7 +716,7 @@ async def chat(
         conversation_command = ConversationCommand.General

     if conversation_command == ConversationCommand.Help:
-        model_type = "offline" if state.processor_config.conversation.enable_offline_chat else "openai"
+        model_type = "offline" if state.processor_config.conversation.offline_chat.enable_offline_chat else "openai"
         formatted_help = help_message.format(model=model_type, version=state.khoj_version)
         return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200)
@@ -788,7 +791,7 @@ async def extract_references_and_questions(
     # Infer search queries from user message
     with timer("Extracting search queries took", logger):
         # If we've reached here, either the user has enabled offline chat or the openai model is enabled.
-        if state.processor_config.conversation.enable_offline_chat:
+        if state.processor_config.conversation.offline_chat.enable_offline_chat:
             loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
             inferred_queries = extract_questions_offline(
                 defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
@@ -804,7 +807,7 @@ async def extract_references_and_questions(
     with timer("Searching knowledge base took", logger):
         result_list = []
         for query in inferred_queries:
-            n_items = min(n, 3) if state.processor_config.conversation.enable_offline_chat else n
+            n_items = min(n, 3) if state.processor_config.conversation.offline_chat.enable_offline_chat else n
             result_list.extend(
                 await search(
                     f"{query} {filters_in_query}",

View File

@@ -113,7 +113,7 @@ def generate_chat_response(
             meta_log=meta_log,
         )

-        if state.processor_config.conversation.enable_offline_chat:
+        if state.processor_config.conversation.offline_chat.enable_offline_chat:
             loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
             chat_response = converse_offline(
                 references=compiled_references,
@@ -122,6 +122,9 @@ def generate_chat_response(
                 conversation_log=meta_log,
                 completion_func=partial_completion,
                 conversation_command=conversation_command,
+                model=state.processor_config.conversation.offline_chat.chat_model,
+                max_prompt_size=state.processor_config.conversation.max_prompt_size,
+                tokenizer_name=state.processor_config.conversation.tokenizer,
             )

         elif state.processor_config.conversation.openai_model:
@@ -135,6 +138,8 @@ def generate_chat_response(
                 api_key=api_key,
                 completion_func=partial_completion,
                 conversation_command=conversation_command,
+                max_prompt_size=state.processor_config.conversation.max_prompt_size,
+                tokenizer_name=state.processor_config.conversation.tokenizer,
             )

     except Exception as e:
except Exception as e: except Exception as e:

View File

@@ -9,6 +9,7 @@ from khoj.utils.yaml import parse_config_from_file
 from khoj.migrations.migrate_version import migrate_config_to_version
 from khoj.migrations.migrate_processor_config_openai import migrate_processor_conversation_schema
 from khoj.migrations.migrate_offline_model import migrate_offline_model
+from khoj.migrations.migrate_offline_chat_schema import migrate_offline_chat_schema


 def cli(args=None):
@@ -55,7 +56,12 @@ def cli(args=None):
 def run_migrations(args):
-    migrations = [migrate_config_to_version, migrate_processor_conversation_schema, migrate_offline_model]
+    migrations = [
+        migrate_config_to_version,
+        migrate_processor_conversation_schema,
+        migrate_offline_model,
+        migrate_offline_chat_schema,
+    ]
     for migration in migrations:
         args = migration(args)
     return args

View File

@@ -84,7 +84,6 @@ class SearchModels:
 @dataclass
 class GPT4AllProcessorConfig:
-    chat_model: Optional[str] = "llama-2-7b-chat.ggmlv3.q4_K_S.bin"
     loaded_model: Union[Any, None] = None
@@ -95,18 +94,20 @@ class ConversationProcessorConfigModel:
     ):
         self.openai_model = conversation_config.openai
         self.gpt4all_model = GPT4AllProcessorConfig()
-        self.enable_offline_chat = conversation_config.enable_offline_chat
+        self.offline_chat = conversation_config.offline_chat
+        self.max_prompt_size = conversation_config.max_prompt_size
+        self.tokenizer = conversation_config.tokenizer
         self.conversation_logfile = Path(conversation_config.conversation_logfile)
         self.chat_session: List[str] = []
         self.meta_log: dict = {}

-        if self.enable_offline_chat:
+        if self.offline_chat.enable_offline_chat:
             try:
-                self.gpt4all_model.loaded_model = download_model(self.gpt4all_model.chat_model)
+                self.gpt4all_model.loaded_model = download_model(self.offline_chat.chat_model)
             except ValueError as e:
+                self.offline_chat.enable_offline_chat = False
                 self.gpt4all_model.loaded_model = None
                 logger.error(f"Error while loading offline chat model: {e}", exc_info=True)
-                self.enable_offline_chat = False
         else:
             self.gpt4all_model.loaded_model = None

View File

@@ -91,10 +91,17 @@ class OpenAIProcessorConfig(ConfigBase):
     chat_model: Optional[str] = "gpt-3.5-turbo"


+class OfflineChatProcessorConfig(ConfigBase):
+    enable_offline_chat: Optional[bool] = False
+    chat_model: Optional[str] = "llama-2-7b-chat.ggmlv3.q4_0.bin"
+
+
 class ConversationProcessorConfig(ConfigBase):
     conversation_logfile: Path
     openai: Optional[OpenAIProcessorConfig]
-    enable_offline_chat: Optional[bool] = False
+    offline_chat: Optional[OfflineChatProcessorConfig]
+    max_prompt_size: Optional[int]
+    tokenizer: Optional[str]


 class ProcessorConfig(ConfigBase):

View File

@@ -16,6 +16,7 @@ from khoj.utils.helpers import resolve_absolute_path
 from khoj.utils.rawconfig import (
     ContentConfig,
     ConversationProcessorConfig,
+    OfflineChatProcessorConfig,
     OpenAIProcessorConfig,
     ProcessorConfig,
     TextContentConfig,
@@ -205,8 +206,9 @@ def processor_config_offline_chat(tmp_path_factory):
     # Setup conversation processor
     processor_config = ProcessorConfig()
+    offline_chat = OfflineChatProcessorConfig(enable_offline_chat=True)
     processor_config.conversation = ConversationProcessorConfig(
-        enable_offline_chat=True,
+        offline_chat=offline_chat,
         conversation_logfile=processor_dir.joinpath("conversation_logs.json"),
     )

View File

@@ -24,7 +24,7 @@ from khoj.processor.conversation.gpt4all.utils import download_model
 from khoj.processor.conversation.utils import message_to_log

-MODEL_NAME = "llama-2-7b-chat.ggmlv3.q4_K_S.bin"
+MODEL_NAME = "llama-2-7b-chat.ggmlv3.q4_0.bin"


 @pytest.fixture(scope="session")
@@ -128,15 +128,15 @@ def test_extract_multiple_explicit_questions_from_message(loaded_model):
 @pytest.mark.chatquality
 def test_extract_multiple_implicit_questions_from_message(loaded_model):
     # Act
-    response = extract_questions_offline("Is Morpheus taller than Neo?", loaded_model=loaded_model)
+    response = extract_questions_offline("Is Carl taller than Ross?", loaded_model=loaded_model)

     # Assert
-    expected_responses = ["height", "taller", "shorter", "heights"]
+    expected_responses = ["height", "taller", "shorter", "heights", "who"]
     assert len(response) <= 3
     for question in response:
         assert any([expected_response in question.lower() for expected_response in expected_responses]), (
-            "Expected chat actor to ask follow-up questions about Morpheus and Neo, but got: " + question
+            "Expected chat actor to ask follow-up questions about Carl and Ross, but got: " + question
         )
@@ -145,7 +145,7 @@ def test_extract_multiple_implicit_questions_from_message(loaded_model):
 def test_generate_search_query_using_question_from_chat_history(loaded_model):
     # Arrange
     message_list = [
-        ("What is the name of Mr. Vader's daughter?", "Princess Leia", []),
+        ("What is the name of Mr. Anderson's daughter?", "Miss Barbara", []),
     ]

     # Act
@@ -156,17 +156,22 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model):
         use_history=True,
     )

-    expected_responses = [
-        "Vader",
-        "sons",
+    all_expected_in_response = [
+        "Anderson",
+    ]
+    any_expected_in_response = [
         "son",
-        "Darth",
+        "sons",
         "children",
     ]

     # Assert
     assert len(response) >= 1
-    assert any([expected_response in response[0] for expected_response in expected_responses]), (
+    assert all([expected_response in response[0] for expected_response in all_expected_in_response]), (
+        "Expected chat actor to ask for clarification in response, but got: " + response[0]
+    )
+    assert any([expected_response in response[0] for expected_response in any_expected_in_response]), (
         "Expected chat actor to ask for clarification in response, but got: " + response[0]
     )
@@ -176,20 +181,20 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model):
 def test_generate_search_query_using_answer_from_chat_history(loaded_model):
     # Arrange
     message_list = [
-        ("What is the name of Mr. Vader's daughter?", "Princess Leia", []),
+        ("What is the name of Mr. Anderson's daughter?", "Miss Barbara", []),
     ]

     # Act
     response = extract_questions_offline(
-        "Is she a Jedi?",
+        "Is she a Doctor?",
         conversation_log=populate_chat_history(message_list),
         loaded_model=loaded_model,
         use_history=True,
     )

     expected_responses = [
-        "Leia",
-        "Vader",
+        "Barbara",
+        "Robert",
         "daughter",
     ]