mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-07 21:29:13 +00:00
Improve Offline Chat Model Experience (#494)
- Make offline chat model user configurable. Use `filename` of any [GPT4All supported model](https://github.com/nomic-ai/gpt4all/blob/main/gpt4all-chat/metadata/models.json) like below: - Run GPT4All Chat Model on GPU, when available via [GPT4All Vulkan support](https://blog.nomic.ai/posts/gpt4all-gpu-inference-with-vulkan) - Use default Llama 2 supported by GPT4All - Make `tokenizer` and `max-prompt-size` of chat model user configurable. E.g. when using chat models not in [this pre-defined list](https://github.com/khoj-ai/khoj/blob/master/src/khoj/processor/conversation/utils.py) that support a larger context window or a different tokenizer. Closes #406, #418
This commit is contained in:
@@ -19,7 +19,7 @@ from khoj.utils.config import (
|
|||||||
)
|
)
|
||||||
from khoj.utils.helpers import resolve_absolute_path, merge_dicts
|
from khoj.utils.helpers import resolve_absolute_path, merge_dicts
|
||||||
from khoj.utils.fs_syncer import collect_files
|
from khoj.utils.fs_syncer import collect_files
|
||||||
from khoj.utils.rawconfig import FullConfig, ProcessorConfig, ConversationProcessorConfig
|
from khoj.utils.rawconfig import FullConfig, OfflineChatProcessorConfig, ProcessorConfig, ConversationProcessorConfig
|
||||||
from khoj.routers.indexer import configure_content, load_content, configure_search
|
from khoj.routers.indexer import configure_content, load_content, configure_search
|
||||||
|
|
||||||
|
|
||||||
@@ -168,9 +168,7 @@ def configure_conversation_processor(
|
|||||||
conversation_config=ConversationProcessorConfig(
|
conversation_config=ConversationProcessorConfig(
|
||||||
conversation_logfile=conversation_logfile,
|
conversation_logfile=conversation_logfile,
|
||||||
openai=(conversation_config.openai if (conversation_config is not None) else None),
|
openai=(conversation_config.openai if (conversation_config is not None) else None),
|
||||||
enable_offline_chat=(
|
offline_chat=conversation_config.offline_chat if conversation_config else OfflineChatProcessorConfig(),
|
||||||
conversation_config.enable_offline_chat if (conversation_config is not None) else False
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -236,7 +236,7 @@
|
|||||||
</h3>
|
</h3>
|
||||||
</div>
|
</div>
|
||||||
<div class="card-description-row">
|
<div class="card-description-row">
|
||||||
<p class="card-description">Setup chat using OpenAI</p>
|
<p class="card-description">Setup online chat using OpenAI</p>
|
||||||
</div>
|
</div>
|
||||||
<div class="card-action-row">
|
<div class="card-action-row">
|
||||||
<a class="card-button" href="/config/processor/conversation/openai">
|
<a class="card-button" href="/config/processor/conversation/openai">
|
||||||
@@ -261,21 +261,21 @@
|
|||||||
<img class="card-icon" src="/static/assets/icons/chat.svg" alt="Chat">
|
<img class="card-icon" src="/static/assets/icons/chat.svg" alt="Chat">
|
||||||
<h3 class="card-title">
|
<h3 class="card-title">
|
||||||
Offline Chat
|
Offline Chat
|
||||||
<img id="configured-icon-conversation-enable-offline-chat" class="configured-icon {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat and current_model_state.conversation_gpt4all %}enabled{% else %}disabled{% endif %}" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
|
<img id="configured-icon-conversation-enable-offline-chat" class="configured-icon {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.offline_chat.enable_offline_chat and current_model_state.conversation_gpt4all %}enabled{% else %}disabled{% endif %}" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
|
||||||
{% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat and not current_model_state.conversation_gpt4all %}
|
{% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.offline_chat.enable_offline_chat and not current_model_state.conversation_gpt4all %}
|
||||||
<img id="misconfigured-icon-conversation-enable-offline-chat" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="The model was not downloaded as expected.">
|
<img id="misconfigured-icon-conversation-enable-offline-chat" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="The model was not downloaded as expected.">
|
||||||
{% endif %}
|
{% endif %}
|
||||||
</h3>
|
</h3>
|
||||||
</div>
|
</div>
|
||||||
<div class="card-description-row">
|
<div class="card-description-row">
|
||||||
<p class="card-description">Setup offline chat (Llama V2)</p>
|
<p class="card-description">Setup offline chat</p>
|
||||||
</div>
|
</div>
|
||||||
<div id="clear-enable-offline-chat" class="card-action-row {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat %}enabled{% else %}disabled{% endif %}">
|
<div id="clear-enable-offline-chat" class="card-action-row {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.offline_chat.enable_offline_chat %}enabled{% else %}disabled{% endif %}">
|
||||||
<button class="card-button" onclick="toggleEnableLocalLLLM(false)">
|
<button class="card-button" onclick="toggleEnableLocalLLLM(false)">
|
||||||
Disable
|
Disable
|
||||||
</button>
|
</button>
|
||||||
</div>
|
</div>
|
||||||
<div id="set-enable-offline-chat" class="card-action-row {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat %}disabled{% else %}enabled{% endif %}">
|
<div id="set-enable-offline-chat" class="card-action-row {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.offline_chat.enable_offline_chat %}disabled{% else %}enabled{% endif %}">
|
||||||
<button class="card-button happy" onclick="toggleEnableLocalLLLM(true)">
|
<button class="card-button happy" onclick="toggleEnableLocalLLLM(true)">
|
||||||
Enable
|
Enable
|
||||||
</button>
|
</button>
|
||||||
@@ -346,7 +346,7 @@
|
|||||||
featuresHintText.classList.add("show");
|
featuresHintText.classList.add("show");
|
||||||
}
|
}
|
||||||
|
|
||||||
fetch('/api/config/data/processor/conversation/enable_offline_chat' + '?enable_offline_chat=' + enable, {
|
fetch('/api/config/data/processor/conversation/offline_chat' + '?enable_offline_chat=' + enable, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
|
|||||||
83
src/khoj/migrations/migrate_offline_chat_schema.py
Normal file
83
src/khoj/migrations/migrate_offline_chat_schema.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
"""
|
||||||
|
Current format of khoj.yml
|
||||||
|
---
|
||||||
|
app:
|
||||||
|
...
|
||||||
|
content-type:
|
||||||
|
...
|
||||||
|
processor:
|
||||||
|
conversation:
|
||||||
|
enable-offline-chat: false
|
||||||
|
conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json
|
||||||
|
openai:
|
||||||
|
...
|
||||||
|
search-type:
|
||||||
|
...
|
||||||
|
|
||||||
|
New format of khoj.yml
|
||||||
|
---
|
||||||
|
app:
|
||||||
|
...
|
||||||
|
content-type:
|
||||||
|
...
|
||||||
|
processor:
|
||||||
|
conversation:
|
||||||
|
offline-chat:
|
||||||
|
enable-offline-chat: false
|
||||||
|
chat-model: llama-2-7b-chat.ggmlv3.q4_0.bin
|
||||||
|
tokenizer: null
|
||||||
|
max-prompt-size: null
|
||||||
|
conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json
|
||||||
|
openai:
|
||||||
|
...
|
||||||
|
search-type:
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
|
from khoj.utils.yaml import load_config_from_file, save_config_to_file
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def migrate_offline_chat_schema(args):
    """Move the flat `enable-offline-chat` setting into a nested `offline-chat` section.

    Loads the config referenced by `args.config_file` and, when its `version`
    is missing or older than 0.12.3:
      - sets `version` to 0.12.3
      - sets `max-prompt-size` and `tokenizer` to null under `processor.conversation`
      - creates `processor.conversation.offline-chat` holding the previous
        `enable-offline-chat` value (False when absent) and a `chat-model`
        (an already configured one is kept, else the default Llama 2 model)
      - removes the legacy `processor.conversation.enable-offline-chat` key

    The config is written back to `args.config_file` whenever a non-null
    `processor.conversation` section exists, even if no upgrade was needed.
    Configs without such a section are left untouched.

    Args:
        args: object exposing a `config_file` attribute that is passed to
            `load_config_from_file` / `save_config_to_file`.

    Returns:
        The same `args` object, unchanged.
    """
    schema_version = "0.12.3"
    default_chat_model = "llama-2-7b-chat.ggmlv3.q4_0.bin"
    raw_config = load_config_from_file(args.config_file)
    previous_version = raw_config.get("version")

    # Nothing to migrate unless a conversation processor is configured.
    # A key that is present but null in the YAML is treated like a missing key,
    # otherwise the item assignments below would fail on None.
    processor = raw_config.get("processor")
    if processor is None or "conversation" not in processor:
        return args
    conversation = processor["conversation"]
    if conversation is None:
        return args

    if previous_version is None or version.parse(previous_version) < version.parse(schema_version):
        logger.info(
            f"Upgrading config schema to {schema_version} from {previous_version} to make (offline) chat more configuration"
        )
        raw_config["version"] = schema_version

        # Create max-prompt-size and tokenizer fields in conversation processor schema
        conversation["max-prompt-size"] = None
        conversation["tokenizer"] = None

        # Create offline chat schema based on the existing enable-offline-chat field.
        # `or {}` tolerates an `offline-chat:` key that is present but null.
        existing_offline_chat = conversation.get("offline-chat") or {}
        conversation["offline-chat"] = {
            # pop() reads the legacy flag and deletes it from the old location in one step
            "enable-offline-chat": conversation.pop("enable-offline-chat", False),
            "chat-model": existing_offline_chat.get("chat-model", default_chat_model),
        }

    save_config_to_file(raw_config, args.config_file)
    return args
|
||||||
@@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
def extract_questions_offline(
|
def extract_questions_offline(
|
||||||
text: str,
|
text: str,
|
||||||
model: str = "llama-2-7b-chat.ggmlv3.q4_K_S.bin",
|
model: str = "llama-2-7b-chat.ggmlv3.q4_0.bin",
|
||||||
loaded_model: Union[Any, None] = None,
|
loaded_model: Union[Any, None] = None,
|
||||||
conversation_log={},
|
conversation_log={},
|
||||||
use_history: bool = True,
|
use_history: bool = True,
|
||||||
@@ -113,7 +113,7 @@ def filter_questions(questions: List[str]):
|
|||||||
]
|
]
|
||||||
filtered_questions = []
|
filtered_questions = []
|
||||||
for q in questions:
|
for q in questions:
|
||||||
if not any([word in q.lower() for word in hint_words]):
|
if not any([word in q.lower() for word in hint_words]) and not is_none_or_empty(q):
|
||||||
filtered_questions.append(q)
|
filtered_questions.append(q)
|
||||||
|
|
||||||
return filtered_questions
|
return filtered_questions
|
||||||
@@ -123,10 +123,12 @@ def converse_offline(
|
|||||||
references,
|
references,
|
||||||
user_query,
|
user_query,
|
||||||
conversation_log={},
|
conversation_log={},
|
||||||
model: str = "llama-2-7b-chat.ggmlv3.q4_K_S.bin",
|
model: str = "llama-2-7b-chat.ggmlv3.q4_0.bin",
|
||||||
loaded_model: Union[Any, None] = None,
|
loaded_model: Union[Any, None] = None,
|
||||||
completion_func=None,
|
completion_func=None,
|
||||||
conversation_command=ConversationCommand.Default,
|
conversation_command=ConversationCommand.Default,
|
||||||
|
max_prompt_size=None,
|
||||||
|
tokenizer_name=None,
|
||||||
) -> Union[ThreadedGenerator, Iterator[str]]:
|
) -> Union[ThreadedGenerator, Iterator[str]]:
|
||||||
"""
|
"""
|
||||||
Converse with user using Llama
|
Converse with user using Llama
|
||||||
@@ -158,6 +160,8 @@ def converse_offline(
|
|||||||
prompts.system_prompt_message_llamav2,
|
prompts.system_prompt_message_llamav2,
|
||||||
conversation_log,
|
conversation_log,
|
||||||
model_name=model,
|
model_name=model,
|
||||||
|
max_prompt_size=max_prompt_size,
|
||||||
|
tokenizer_name=tokenizer_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
g = ThreadedGenerator(references, completion_func=completion_func)
|
g = ThreadedGenerator(references, completion_func=completion_func)
|
||||||
|
|||||||
@@ -1,3 +0,0 @@
|
|||||||
model_name_to_url = {
|
|
||||||
"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/resolve/main/llama-2-7b-chat.ggmlv3.q4_K_S.bin"
|
|
||||||
}
|
|
||||||
@@ -1,24 +1,8 @@
|
|||||||
import os
|
|
||||||
import logging
|
import logging
|
||||||
import requests
|
|
||||||
import hashlib
|
|
||||||
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from khoj.processor.conversation.gpt4all import model_metadata
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
expected_checksum = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "cfa87b15d92fb15a2d7c354b0098578b"}
|
|
||||||
|
|
||||||
|
|
||||||
def get_md5_checksum(filename: str):
|
|
||||||
hash_md5 = hashlib.md5()
|
|
||||||
with open(filename, "rb") as f:
|
|
||||||
for chunk in iter(lambda: f.read(8192), b""):
|
|
||||||
hash_md5.update(chunk)
|
|
||||||
return hash_md5.hexdigest()
|
|
||||||
|
|
||||||
|
|
||||||
def download_model(model_name: str):
|
def download_model(model_name: str):
|
||||||
try:
|
try:
|
||||||
@@ -27,57 +11,12 @@ def download_model(model_name: str):
|
|||||||
logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
|
logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
url = model_metadata.model_name_to_url.get(model_name)
|
# Use GPU for Chat Model, if available
|
||||||
model_path = os.path.expanduser(f"~/.cache/gpt4all/")
|
|
||||||
if not url:
|
|
||||||
logger.debug(f"Model {model_name} not found in model metadata. Skipping download.")
|
|
||||||
return GPT4All(model_name=model_name, model_path=model_path)
|
|
||||||
|
|
||||||
filename = os.path.expanduser(f"~/.cache/gpt4all/{model_name}")
|
|
||||||
if os.path.exists(filename):
|
|
||||||
# Check if the user is connected to the internet
|
|
||||||
try:
|
|
||||||
requests.get("https://www.google.com/", timeout=5)
|
|
||||||
except:
|
|
||||||
logger.debug("User is offline. Disabling allowed download flag")
|
|
||||||
return GPT4All(model_name=model_name, model_path=model_path, allow_download=False)
|
|
||||||
return GPT4All(model_name=model_name, model_path=model_path)
|
|
||||||
|
|
||||||
# Download the model to a tmp file. Once the download is completed, move the tmp file to the actual file
|
|
||||||
tmp_filename = filename + ".tmp"
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
os.makedirs(os.path.dirname(tmp_filename), exist_ok=True)
|
model = GPT4All(model_name=model_name, device="gpu")
|
||||||
logger.debug(f"Downloading model {model_name} from {url} to {filename}...")
|
logger.debug("Loaded chat model to GPU.")
|
||||||
with requests.get(url, stream=True) as r:
|
except ValueError:
|
||||||
r.raise_for_status()
|
model = GPT4All(model_name=model_name)
|
||||||
total_size = int(r.headers.get("content-length", 0))
|
logger.debug("Loaded chat model to CPU.")
|
||||||
with open(tmp_filename, "wb") as f, tqdm(
|
|
||||||
unit="B", # unit string to be displayed.
|
|
||||||
unit_scale=True, # let tqdm to determine the scale in kilo, mega..etc.
|
|
||||||
unit_divisor=1024, # is used when unit_scale is true
|
|
||||||
total=total_size, # the total iteration.
|
|
||||||
desc=model_name, # prefix to be displayed on progress bar.
|
|
||||||
) as progress_bar:
|
|
||||||
for chunk in r.iter_content(chunk_size=8192):
|
|
||||||
f.write(chunk)
|
|
||||||
progress_bar.update(len(chunk))
|
|
||||||
|
|
||||||
# Verify the checksum
|
return model
|
||||||
if expected_checksum.get(model_name) != get_md5_checksum(tmp_filename):
|
|
||||||
logger.error(
|
|
||||||
f"Checksum verification failed for {filename}. Removing the tmp file. Offline model will not be available."
|
|
||||||
)
|
|
||||||
os.remove(tmp_filename)
|
|
||||||
raise ValueError(f"Checksum verification failed for downloading {model_name} from {url}.")
|
|
||||||
|
|
||||||
# Move the tmp file to the actual file
|
|
||||||
os.rename(tmp_filename, filename)
|
|
||||||
logger.debug(f"Successfully downloaded model {model_name} from {url} to {filename}")
|
|
||||||
return GPT4All(model_name)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to download model {model_name} from {url} to {filename}. Error: {e}", exc_info=True)
|
|
||||||
# Remove the tmp file if it exists
|
|
||||||
if os.path.exists(tmp_filename):
|
|
||||||
os.remove(tmp_filename)
|
|
||||||
return None
|
|
||||||
|
|||||||
@@ -116,6 +116,8 @@ def converse(
|
|||||||
temperature: float = 0.2,
|
temperature: float = 0.2,
|
||||||
completion_func=None,
|
completion_func=None,
|
||||||
conversation_command=ConversationCommand.Default,
|
conversation_command=ConversationCommand.Default,
|
||||||
|
max_prompt_size=None,
|
||||||
|
tokenizer_name=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Converse with user using OpenAI's ChatGPT
|
Converse with user using OpenAI's ChatGPT
|
||||||
@@ -141,6 +143,8 @@ def converse(
|
|||||||
prompts.personality.format(),
|
prompts.personality.format(),
|
||||||
conversation_log,
|
conversation_log,
|
||||||
model,
|
model,
|
||||||
|
max_prompt_size,
|
||||||
|
tokenizer_name,
|
||||||
)
|
)
|
||||||
truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages})
|
truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages})
|
||||||
logger.debug(f"Conversation Context for GPT: {truncated_messages}")
|
logger.debug(f"Conversation Context for GPT: {truncated_messages}")
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ no_notes_found = PromptTemplate.from_template(
|
|||||||
""".strip()
|
""".strip()
|
||||||
)
|
)
|
||||||
|
|
||||||
system_prompt_message_llamav2 = f"""You are Khoj, a friendly, smart and helpful personal assistant.
|
system_prompt_message_llamav2 = f"""You are Khoj, a smart, inquisitive and helpful personal assistant.
|
||||||
Using your general knowledge and our past conversations as context, answer the following question.
|
Using your general knowledge and our past conversations as context, answer the following question.
|
||||||
If you do not know the answer, say 'I don't know.'"""
|
If you do not know the answer, say 'I don't know.'"""
|
||||||
|
|
||||||
@@ -51,13 +51,13 @@ extract_questions_system_prompt_llamav2 = PromptTemplate.from_template(
|
|||||||
|
|
||||||
general_conversation_llamav2 = PromptTemplate.from_template(
|
general_conversation_llamav2 = PromptTemplate.from_template(
|
||||||
"""
|
"""
|
||||||
<s>[INST]{query}[/INST]
|
<s>[INST] {query} [/INST]
|
||||||
""".strip()
|
""".strip()
|
||||||
)
|
)
|
||||||
|
|
||||||
chat_history_llamav2_from_user = PromptTemplate.from_template(
|
chat_history_llamav2_from_user = PromptTemplate.from_template(
|
||||||
"""
|
"""
|
||||||
<s>[INST]{message}[/INST]
|
<s>[INST] {message} [/INST]
|
||||||
""".strip()
|
""".strip()
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -69,7 +69,7 @@ chat_history_llamav2_from_assistant = PromptTemplate.from_template(
|
|||||||
|
|
||||||
conversation_llamav2 = PromptTemplate.from_template(
|
conversation_llamav2 = PromptTemplate.from_template(
|
||||||
"""
|
"""
|
||||||
<s>[INST]{query}[/INST]
|
<s>[INST] {query} [/INST]
|
||||||
""".strip()
|
""".strip()
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -91,7 +91,7 @@ Question: {query}
|
|||||||
|
|
||||||
notes_conversation_llamav2 = PromptTemplate.from_template(
|
notes_conversation_llamav2 = PromptTemplate.from_template(
|
||||||
"""
|
"""
|
||||||
Notes:
|
User's Notes:
|
||||||
{references}
|
{references}
|
||||||
Question: {query}
|
Question: {query}
|
||||||
""".strip()
|
""".strip()
|
||||||
@@ -134,19 +134,25 @@ Answer (in second person):"""
|
|||||||
|
|
||||||
extract_questions_llamav2_sample = PromptTemplate.from_template(
|
extract_questions_llamav2_sample = PromptTemplate.from_template(
|
||||||
"""
|
"""
|
||||||
<s>[INST]<<SYS>>Current Date: {current_date}<</SYS>>[/INST]</s>
|
<s>[INST] <<SYS>>Current Date: {current_date}<</SYS>> [/INST]</s>
|
||||||
<s>[INST]How was my trip to Cambodia?[/INST][]</s>
|
<s>[INST] How was my trip to Cambodia? [/INST]
|
||||||
<s>[INST]Who did I visit the temple with on that trip?[/INST]Who did I visit the temple with in Cambodia?</s>
|
How was my trip to Cambodia?</s>
|
||||||
<s>[INST]How should I take care of my plants?[/INST]What kind of plants do I have? What issues do my plants have?</s>
|
<s>[INST] Who did I visit the temple with on that trip? [/INST]
|
||||||
<s>[INST]How many tennis balls fit in the back of a 2002 Honda Civic?[/INST]What is the size of a tennis ball? What is the trunk size of a 2002 Honda Civic?</s>
|
Who did I visit the temple with in Cambodia?</s>
|
||||||
<s>[INST]What did I do for Christmas last year?[/INST]What did I do for Christmas {last_year} dt>='{last_christmas_date}' dt<'{next_christmas_date}'</s>
|
<s>[INST] How should I take care of my plants? [/INST]
|
||||||
<s>[INST]How are you feeling today?[/INST]</s>
|
What kind of plants do I have? What issues do my plants have?</s>
|
||||||
<s>[INST]Is Alice older than Bob?[/INST]When was Alice born? What is Bob's age?</s>
|
<s>[INST] How many tennis balls fit in the back of a 2002 Honda Civic? [/INST]
|
||||||
<s>[INST]<<SYS>>
|
What is the size of a tennis ball? What is the trunk size of a 2002 Honda Civic?</s>
|
||||||
|
<s>[INST] What did I do for Christmas last year? [/INST]
|
||||||
|
What did I do for Christmas {last_year} dt>='{last_christmas_date}' dt<'{next_christmas_date}'</s>
|
||||||
|
<s>[INST] How are you feeling today? [/INST]</s>
|
||||||
|
<s>[INST] Is Alice older than Bob? [/INST]
|
||||||
|
When was Alice born? What is Bob's age?</s>
|
||||||
|
<s>[INST] <<SYS>>
|
||||||
Use these notes from the user's previous conversations to provide a response:
|
Use these notes from the user's previous conversations to provide a response:
|
||||||
{chat_history}
|
{chat_history}
|
||||||
<</SYS>>[/INST]</s>
|
<</SYS>> [/INST]</s>
|
||||||
<s>[INST]{query}[/INST]
|
<s>[INST] {query} [/INST]
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -3,24 +3,27 @@ import logging
|
|||||||
from time import perf_counter
|
from time import perf_counter
|
||||||
import json
|
import json
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
import queue
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
|
||||||
# External packages
|
# External packages
|
||||||
from langchain.schema import ChatMessage
|
from langchain.schema import ChatMessage
|
||||||
from transformers import LlamaTokenizerFast
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
import queue
|
|
||||||
from khoj.utils.helpers import merge_dicts
|
from khoj.utils.helpers import merge_dicts
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
max_prompt_size = {
|
model_to_prompt_size = {
|
||||||
"gpt-3.5-turbo": 4096,
|
"gpt-3.5-turbo": 4096,
|
||||||
"gpt-4": 8192,
|
"gpt-4": 8192,
|
||||||
"llama-2-7b-chat.ggmlv3.q4_K_S.bin": 1548,
|
"llama-2-7b-chat.ggmlv3.q4_0.bin": 1548,
|
||||||
"gpt-3.5-turbo-16k": 15000,
|
"gpt-3.5-turbo-16k": 15000,
|
||||||
}
|
}
|
||||||
tokenizer = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "hf-internal-testing/llama-tokenizer"}
|
model_to_tokenizer = {
|
||||||
|
"llama-2-7b-chat.ggmlv3.q4_0.bin": "hf-internal-testing/llama-tokenizer",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class ThreadedGenerator:
|
class ThreadedGenerator:
|
||||||
@@ -82,9 +85,26 @@ def message_to_log(
|
|||||||
|
|
||||||
|
|
||||||
def generate_chatml_messages_with_context(
|
def generate_chatml_messages_with_context(
|
||||||
user_message, system_message, conversation_log={}, model_name="gpt-3.5-turbo", lookback_turns=2
|
user_message,
|
||||||
|
system_message,
|
||||||
|
conversation_log={},
|
||||||
|
model_name="gpt-3.5-turbo",
|
||||||
|
max_prompt_size=None,
|
||||||
|
tokenizer_name=None,
|
||||||
):
|
):
|
||||||
"""Generate messages for ChatGPT with context from previous conversation"""
|
"""Generate messages for ChatGPT with context from previous conversation"""
|
||||||
|
# Set max prompt size from user config, pre-configured for model or to default prompt size
|
||||||
|
try:
|
||||||
|
max_prompt_size = max_prompt_size or model_to_prompt_size[model_name]
|
||||||
|
except:
|
||||||
|
max_prompt_size = 2000
|
||||||
|
logger.warning(
|
||||||
|
f"Fallback to default prompt size: {max_prompt_size}.\nConfigure max_prompt_size for unsupported model: {model_name} in Khoj settings to longer context window."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Scale lookback turns proportional to max prompt size supported by model
|
||||||
|
lookback_turns = max_prompt_size // 750
|
||||||
|
|
||||||
# Extract Chat History for Context
|
# Extract Chat History for Context
|
||||||
chat_logs = []
|
chat_logs = []
|
||||||
for chat in conversation_log.get("chat", []):
|
for chat in conversation_log.get("chat", []):
|
||||||
@@ -105,19 +125,28 @@ def generate_chatml_messages_with_context(
|
|||||||
messages = user_chatml_message + rest_backnforths + system_chatml_message
|
messages = user_chatml_message + rest_backnforths + system_chatml_message
|
||||||
|
|
||||||
# Truncate oldest messages from conversation history until under max supported prompt size by model
|
# Truncate oldest messages from conversation history until under max supported prompt size by model
|
||||||
messages = truncate_messages(messages, max_prompt_size[model_name], model_name)
|
messages = truncate_messages(messages, max_prompt_size, model_name, tokenizer_name)
|
||||||
|
|
||||||
# Return message in chronological order
|
# Return message in chronological order
|
||||||
return messages[::-1]
|
return messages[::-1]
|
||||||
|
|
||||||
|
|
||||||
def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name) -> list[ChatMessage]:
|
def truncate_messages(
|
||||||
|
messages: list[ChatMessage], max_prompt_size, model_name: str, tokenizer_name=None
|
||||||
|
) -> list[ChatMessage]:
|
||||||
"""Truncate messages to fit within max prompt size supported by model"""
|
"""Truncate messages to fit within max prompt size supported by model"""
|
||||||
|
|
||||||
if "llama" in model_name:
|
try:
|
||||||
encoder = LlamaTokenizerFast.from_pretrained(tokenizer[model_name])
|
if model_name.startswith("gpt-"):
|
||||||
else:
|
encoder = tiktoken.encoding_for_model(model_name)
|
||||||
encoder = tiktoken.encoding_for_model(model_name)
|
else:
|
||||||
|
encoder = AutoTokenizer.from_pretrained(tokenizer_name or model_to_tokenizer[model_name])
|
||||||
|
except:
|
||||||
|
default_tokenizer = "hf-internal-testing/llama-tokenizer"
|
||||||
|
encoder = AutoTokenizer.from_pretrained(default_tokenizer)
|
||||||
|
logger.warning(
|
||||||
|
f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
|
||||||
|
)
|
||||||
|
|
||||||
system_message = messages.pop()
|
system_message = messages.pop()
|
||||||
system_message_tokens = len(encoder.encode(system_message.content))
|
system_message_tokens = len(encoder.encode(system_message.content))
|
||||||
|
|||||||
@@ -284,10 +284,11 @@ if not state.demo:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"status": "error", "message": str(e)}
|
return {"status": "error", "message": str(e)}
|
||||||
|
|
||||||
@api.post("/config/data/processor/conversation/enable_offline_chat", status_code=200)
|
@api.post("/config/data/processor/conversation/offline_chat", status_code=200)
|
||||||
async def set_processor_enable_offline_chat_config_data(
|
async def set_processor_enable_offline_chat_config_data(
|
||||||
request: Request,
|
request: Request,
|
||||||
enable_offline_chat: bool,
|
enable_offline_chat: bool,
|
||||||
|
offline_chat_model: Optional[str] = None,
|
||||||
client: Optional[str] = None,
|
client: Optional[str] = None,
|
||||||
):
|
):
|
||||||
_initialize_config()
|
_initialize_config()
|
||||||
@@ -301,7 +302,9 @@ if not state.demo:
|
|||||||
state.config.processor = ProcessorConfig(conversation=ConversationProcessorConfig(conversation_logfile=conversation_logfile)) # type: ignore
|
state.config.processor = ProcessorConfig(conversation=ConversationProcessorConfig(conversation_logfile=conversation_logfile)) # type: ignore
|
||||||
|
|
||||||
assert state.config.processor.conversation is not None
|
assert state.config.processor.conversation is not None
|
||||||
state.config.processor.conversation.enable_offline_chat = enable_offline_chat
|
state.config.processor.conversation.offline_chat.enable_offline_chat = enable_offline_chat
|
||||||
|
if offline_chat_model is not None:
|
||||||
|
state.config.processor.conversation.offline_chat.chat_model = offline_chat_model
|
||||||
state.processor_config = configure_processor(state.config.processor, state.processor_config)
|
state.processor_config = configure_processor(state.config.processor, state.processor_config)
|
||||||
|
|
||||||
update_telemetry_state(
|
update_telemetry_state(
|
||||||
@@ -713,7 +716,7 @@ async def chat(
|
|||||||
conversation_command = ConversationCommand.General
|
conversation_command = ConversationCommand.General
|
||||||
|
|
||||||
if conversation_command == ConversationCommand.Help:
|
if conversation_command == ConversationCommand.Help:
|
||||||
model_type = "offline" if state.processor_config.conversation.enable_offline_chat else "openai"
|
model_type = "offline" if state.processor_config.conversation.offline_chat.enable_offline_chat else "openai"
|
||||||
formatted_help = help_message.format(model=model_type, version=state.khoj_version)
|
formatted_help = help_message.format(model=model_type, version=state.khoj_version)
|
||||||
return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200)
|
return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200)
|
||||||
|
|
||||||
@@ -788,7 +791,7 @@ async def extract_references_and_questions(
|
|||||||
# Infer search queries from user message
|
# Infer search queries from user message
|
||||||
with timer("Extracting search queries took", logger):
|
with timer("Extracting search queries took", logger):
|
||||||
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
|
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
|
||||||
if state.processor_config.conversation.enable_offline_chat:
|
if state.processor_config.conversation.offline_chat.enable_offline_chat:
|
||||||
loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
|
loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
|
||||||
inferred_queries = extract_questions_offline(
|
inferred_queries = extract_questions_offline(
|
||||||
defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
|
defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
|
||||||
@@ -804,7 +807,7 @@ async def extract_references_and_questions(
|
|||||||
with timer("Searching knowledge base took", logger):
|
with timer("Searching knowledge base took", logger):
|
||||||
result_list = []
|
result_list = []
|
||||||
for query in inferred_queries:
|
for query in inferred_queries:
|
||||||
n_items = min(n, 3) if state.processor_config.conversation.enable_offline_chat else n
|
n_items = min(n, 3) if state.processor_config.conversation.offline_chat.enable_offline_chat else n
|
||||||
result_list.extend(
|
result_list.extend(
|
||||||
await search(
|
await search(
|
||||||
f"{query} {filters_in_query}",
|
f"{query} {filters_in_query}",
|
||||||
|
|||||||
@@ -113,7 +113,7 @@ def generate_chat_response(
|
|||||||
meta_log=meta_log,
|
meta_log=meta_log,
|
||||||
)
|
)
|
||||||
|
|
||||||
if state.processor_config.conversation.enable_offline_chat:
|
if state.processor_config.conversation.offline_chat.enable_offline_chat:
|
||||||
loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
|
loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
|
||||||
chat_response = converse_offline(
|
chat_response = converse_offline(
|
||||||
references=compiled_references,
|
references=compiled_references,
|
||||||
@@ -122,6 +122,9 @@ def generate_chat_response(
|
|||||||
conversation_log=meta_log,
|
conversation_log=meta_log,
|
||||||
completion_func=partial_completion,
|
completion_func=partial_completion,
|
||||||
conversation_command=conversation_command,
|
conversation_command=conversation_command,
|
||||||
|
model=state.processor_config.conversation.offline_chat.chat_model,
|
||||||
|
max_prompt_size=state.processor_config.conversation.max_prompt_size,
|
||||||
|
tokenizer_name=state.processor_config.conversation.tokenizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif state.processor_config.conversation.openai_model:
|
elif state.processor_config.conversation.openai_model:
|
||||||
@@ -135,6 +138,8 @@ def generate_chat_response(
|
|||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
completion_func=partial_completion,
|
completion_func=partial_completion,
|
||||||
conversation_command=conversation_command,
|
conversation_command=conversation_command,
|
||||||
|
max_prompt_size=state.processor_config.conversation.max_prompt_size,
|
||||||
|
tokenizer_name=state.processor_config.conversation.tokenizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from khoj.utils.yaml import parse_config_from_file
|
|||||||
from khoj.migrations.migrate_version import migrate_config_to_version
|
from khoj.migrations.migrate_version import migrate_config_to_version
|
||||||
from khoj.migrations.migrate_processor_config_openai import migrate_processor_conversation_schema
|
from khoj.migrations.migrate_processor_config_openai import migrate_processor_conversation_schema
|
||||||
from khoj.migrations.migrate_offline_model import migrate_offline_model
|
from khoj.migrations.migrate_offline_model import migrate_offline_model
|
||||||
|
from khoj.migrations.migrate_offline_chat_schema import migrate_offline_chat_schema
|
||||||
|
|
||||||
|
|
||||||
def cli(args=None):
|
def cli(args=None):
|
||||||
@@ -55,7 +56,12 @@ def cli(args=None):
|
|||||||
|
|
||||||
|
|
||||||
def run_migrations(args):
|
def run_migrations(args):
|
||||||
migrations = [migrate_config_to_version, migrate_processor_conversation_schema, migrate_offline_model]
|
migrations = [
|
||||||
|
migrate_config_to_version,
|
||||||
|
migrate_processor_conversation_schema,
|
||||||
|
migrate_offline_model,
|
||||||
|
migrate_offline_chat_schema,
|
||||||
|
]
|
||||||
for migration in migrations:
|
for migration in migrations:
|
||||||
args = migration(args)
|
args = migration(args)
|
||||||
return args
|
return args
|
||||||
|
|||||||
@@ -84,7 +84,6 @@ class SearchModels:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GPT4AllProcessorConfig:
|
class GPT4AllProcessorConfig:
|
||||||
chat_model: Optional[str] = "llama-2-7b-chat.ggmlv3.q4_K_S.bin"
|
|
||||||
loaded_model: Union[Any, None] = None
|
loaded_model: Union[Any, None] = None
|
||||||
|
|
||||||
|
|
||||||
@@ -95,18 +94,20 @@ class ConversationProcessorConfigModel:
|
|||||||
):
|
):
|
||||||
self.openai_model = conversation_config.openai
|
self.openai_model = conversation_config.openai
|
||||||
self.gpt4all_model = GPT4AllProcessorConfig()
|
self.gpt4all_model = GPT4AllProcessorConfig()
|
||||||
self.enable_offline_chat = conversation_config.enable_offline_chat
|
self.offline_chat = conversation_config.offline_chat
|
||||||
|
self.max_prompt_size = conversation_config.max_prompt_size
|
||||||
|
self.tokenizer = conversation_config.tokenizer
|
||||||
self.conversation_logfile = Path(conversation_config.conversation_logfile)
|
self.conversation_logfile = Path(conversation_config.conversation_logfile)
|
||||||
self.chat_session: List[str] = []
|
self.chat_session: List[str] = []
|
||||||
self.meta_log: dict = {}
|
self.meta_log: dict = {}
|
||||||
|
|
||||||
if self.enable_offline_chat:
|
if self.offline_chat.enable_offline_chat:
|
||||||
try:
|
try:
|
||||||
self.gpt4all_model.loaded_model = download_model(self.gpt4all_model.chat_model)
|
self.gpt4all_model.loaded_model = download_model(self.offline_chat.chat_model)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
|
self.offline_chat.enable_offline_chat = False
|
||||||
self.gpt4all_model.loaded_model = None
|
self.gpt4all_model.loaded_model = None
|
||||||
logger.error(f"Error while loading offline chat model: {e}", exc_info=True)
|
logger.error(f"Error while loading offline chat model: {e}", exc_info=True)
|
||||||
self.enable_offline_chat = False
|
|
||||||
else:
|
else:
|
||||||
self.gpt4all_model.loaded_model = None
|
self.gpt4all_model.loaded_model = None
|
||||||
|
|
||||||
|
|||||||
@@ -91,10 +91,17 @@ class OpenAIProcessorConfig(ConfigBase):
|
|||||||
chat_model: Optional[str] = "gpt-3.5-turbo"
|
chat_model: Optional[str] = "gpt-3.5-turbo"
|
||||||
|
|
||||||
|
|
||||||
|
class OfflineChatProcessorConfig(ConfigBase):
|
||||||
|
enable_offline_chat: Optional[bool] = False
|
||||||
|
chat_model: Optional[str] = "llama-2-7b-chat.ggmlv3.q4_0.bin"
|
||||||
|
|
||||||
|
|
||||||
class ConversationProcessorConfig(ConfigBase):
|
class ConversationProcessorConfig(ConfigBase):
|
||||||
conversation_logfile: Path
|
conversation_logfile: Path
|
||||||
openai: Optional[OpenAIProcessorConfig]
|
openai: Optional[OpenAIProcessorConfig]
|
||||||
enable_offline_chat: Optional[bool] = False
|
offline_chat: Optional[OfflineChatProcessorConfig]
|
||||||
|
max_prompt_size: Optional[int]
|
||||||
|
tokenizer: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
class ProcessorConfig(ConfigBase):
|
class ProcessorConfig(ConfigBase):
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from khoj.utils.helpers import resolve_absolute_path
|
|||||||
from khoj.utils.rawconfig import (
|
from khoj.utils.rawconfig import (
|
||||||
ContentConfig,
|
ContentConfig,
|
||||||
ConversationProcessorConfig,
|
ConversationProcessorConfig,
|
||||||
|
OfflineChatProcessorConfig,
|
||||||
OpenAIProcessorConfig,
|
OpenAIProcessorConfig,
|
||||||
ProcessorConfig,
|
ProcessorConfig,
|
||||||
TextContentConfig,
|
TextContentConfig,
|
||||||
@@ -205,8 +206,9 @@ def processor_config_offline_chat(tmp_path_factory):
|
|||||||
|
|
||||||
# Setup conversation processor
|
# Setup conversation processor
|
||||||
processor_config = ProcessorConfig()
|
processor_config = ProcessorConfig()
|
||||||
|
offline_chat = OfflineChatProcessorConfig(enable_offline_chat=True)
|
||||||
processor_config.conversation = ConversationProcessorConfig(
|
processor_config.conversation = ConversationProcessorConfig(
|
||||||
enable_offline_chat=True,
|
offline_chat=offline_chat,
|
||||||
conversation_logfile=processor_dir.joinpath("conversation_logs.json"),
|
conversation_logfile=processor_dir.joinpath("conversation_logs.json"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from khoj.processor.conversation.gpt4all.utils import download_model
|
|||||||
|
|
||||||
from khoj.processor.conversation.utils import message_to_log
|
from khoj.processor.conversation.utils import message_to_log
|
||||||
|
|
||||||
MODEL_NAME = "llama-2-7b-chat.ggmlv3.q4_K_S.bin"
|
MODEL_NAME = "llama-2-7b-chat.ggmlv3.q4_0.bin"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
@@ -128,15 +128,15 @@ def test_extract_multiple_explicit_questions_from_message(loaded_model):
|
|||||||
@pytest.mark.chatquality
|
@pytest.mark.chatquality
|
||||||
def test_extract_multiple_implicit_questions_from_message(loaded_model):
|
def test_extract_multiple_implicit_questions_from_message(loaded_model):
|
||||||
# Act
|
# Act
|
||||||
response = extract_questions_offline("Is Morpheus taller than Neo?", loaded_model=loaded_model)
|
response = extract_questions_offline("Is Carl taller than Ross?", loaded_model=loaded_model)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
expected_responses = ["height", "taller", "shorter", "heights"]
|
expected_responses = ["height", "taller", "shorter", "heights", "who"]
|
||||||
assert len(response) <= 3
|
assert len(response) <= 3
|
||||||
|
|
||||||
for question in response:
|
for question in response:
|
||||||
assert any([expected_response in question.lower() for expected_response in expected_responses]), (
|
assert any([expected_response in question.lower() for expected_response in expected_responses]), (
|
||||||
"Expected chat actor to ask follow-up questions about Morpheus and Neo, but got: " + question
|
"Expected chat actor to ask follow-up questions about Carl and Ross, but got: " + question
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -145,7 +145,7 @@ def test_extract_multiple_implicit_questions_from_message(loaded_model):
|
|||||||
def test_generate_search_query_using_question_from_chat_history(loaded_model):
|
def test_generate_search_query_using_question_from_chat_history(loaded_model):
|
||||||
# Arrange
|
# Arrange
|
||||||
message_list = [
|
message_list = [
|
||||||
("What is the name of Mr. Vader's daughter?", "Princess Leia", []),
|
("What is the name of Mr. Anderson's daughter?", "Miss Barbara", []),
|
||||||
]
|
]
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
@@ -156,17 +156,22 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model):
|
|||||||
use_history=True,
|
use_history=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
expected_responses = [
|
all_expected_in_response = [
|
||||||
"Vader",
|
"Anderson",
|
||||||
"sons",
|
]
|
||||||
|
|
||||||
|
any_expected_in_response = [
|
||||||
"son",
|
"son",
|
||||||
"Darth",
|
"sons",
|
||||||
"children",
|
"children",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert len(response) >= 1
|
assert len(response) >= 1
|
||||||
assert any([expected_response in response[0] for expected_response in expected_responses]), (
|
assert all([expected_response in response[0] for expected_response in all_expected_in_response]), (
|
||||||
|
"Expected chat actor to ask for clarification in response, but got: " + response[0]
|
||||||
|
)
|
||||||
|
assert any([expected_response in response[0] for expected_response in any_expected_in_response]), (
|
||||||
"Expected chat actor to ask for clarification in response, but got: " + response[0]
|
"Expected chat actor to ask for clarification in response, but got: " + response[0]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -176,20 +181,20 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model):
|
|||||||
def test_generate_search_query_using_answer_from_chat_history(loaded_model):
|
def test_generate_search_query_using_answer_from_chat_history(loaded_model):
|
||||||
# Arrange
|
# Arrange
|
||||||
message_list = [
|
message_list = [
|
||||||
("What is the name of Mr. Vader's daughter?", "Princess Leia", []),
|
("What is the name of Mr. Anderson's daughter?", "Miss Barbara", []),
|
||||||
]
|
]
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = extract_questions_offline(
|
response = extract_questions_offline(
|
||||||
"Is she a Jedi?",
|
"Is she a Doctor?",
|
||||||
conversation_log=populate_chat_history(message_list),
|
conversation_log=populate_chat_history(message_list),
|
||||||
loaded_model=loaded_model,
|
loaded_model=loaded_model,
|
||||||
use_history=True,
|
use_history=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
expected_responses = [
|
expected_responses = [
|
||||||
"Leia",
|
"Barbara",
|
||||||
"Vader",
|
"Robert",
|
||||||
"daughter",
|
"daughter",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user