Replace Falcon 🦅 model with Llama V2 🦙 for offline chat (#352)

* Working example with LlamaV2 running locally on my machine

- Download from huggingface
- Plug in to GPT4All
- Update prompts to fit the llama format

* Add prompts, in the Llama format, for extracting questions from a query

* Rename Falcon to Llama and make some improvements to the extract_questions flow

* Further tune the extract-questions prompts and unit tests

* Disable extracting questions dynamically from Llama, as results are still unreliable
This commit is contained in:
sabaimran
2023-07-28 03:51:20 +00:00
committed by GitHub
parent 55965eea7d
commit 124d97c26d
11 changed files with 248 additions and 141 deletions

View File

@@ -229,7 +229,7 @@
</h3>
</div>
<div class="card-description-row">
<p class="card-description">Setup offline chat (Falcon 7B)</p>
<p class="card-description">Setup offline chat (Llama V2)</p>
</div>
<div id="clear-enable-offline-chat" class="card-action-row {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat %}enabled{% else %}disabled{% endif %}">
<button class="card-button" onclick="toggleEnableLocalLLLM(false)">

View File

@@ -1,6 +1,5 @@
from typing import Union, List
from datetime import datetime
import sys
import logging
from threading import Thread
@@ -8,7 +7,6 @@ from langchain.schema import ChatMessage
from gpt4all import GPT4All
from khoj.processor.conversation.utils import ThreadedGenerator, generate_chatml_messages_with_context
from khoj.processor.conversation import prompts
from khoj.utils.constants import empty_escape_sequences
@@ -16,20 +14,21 @@ from khoj.utils.constants import empty_escape_sequences
logger = logging.getLogger(__name__)
def extract_questions_falcon(
def extract_questions_offline(
text: str,
model: str = "ggml-model-gpt4all-falcon-q4_0.bin",
model: str = "llama-2-7b-chat.ggmlv3.q4_K_S.bin",
loaded_model: Union[GPT4All, None] = None,
conversation_log={},
use_history: bool = False,
run_extraction: bool = False,
use_history: bool = True,
should_extract_questions: bool = True,
):
"""
Infer search queries to retrieve relevant notes to answer user query
"""
all_questions = text.split("? ")
all_questions = [q + "?" for q in all_questions[:-1]] + [all_questions[-1]]
if not run_extraction:
if not should_extract_questions:
return all_questions
gpt4all_model = loaded_model or GPT4All(model)
@@ -38,51 +37,85 @@ def extract_questions_falcon(
chat_history = ""
if use_history:
chat_history = "".join(
[
f'Q: {chat["intent"]["query"]}\n\n{chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}\n\nA: {chat["message"]}\n\n'
for chat in conversation_log.get("chat", [])[-4:]
if chat["by"] == "khoj"
]
)
for chat in conversation_log.get("chat", [])[-4:]:
if chat["by"] == "khoj":
chat_history += f"Q: {chat['intent']['query']}\n"
chat_history += f"A: {chat['message']}\n"
prompt = prompts.extract_questions_falcon.format(
chat_history=chat_history,
text=text,
current_date = datetime.now().strftime("%Y-%m-%d")
last_year = datetime.now().year - 1
last_christmas_date = f"{last_year}-12-25"
next_christmas_date = f"{datetime.now().year}-12-25"
system_prompt = prompts.extract_questions_system_prompt_llamav2.format(
message=(prompts.system_prompt_message_extract_questions_llamav2)
)
message = prompts.general_conversation_falcon.format(query=prompt)
response = gpt4all_model.generate(message, max_tokens=200, top_k=2)
example_questions = prompts.extract_questions_llamav2_sample.format(
query=text,
chat_history=chat_history,
current_date=current_date,
last_year=last_year,
last_christmas_date=last_christmas_date,
next_christmas_date=next_christmas_date,
)
message = system_prompt + example_questions
response = gpt4all_model.generate(message, max_tokens=200, top_k=2, temp=0)
# Extract, Clean Message from GPT's Response
try:
# This will expect to be a list with a single string with a list of questions
questions = (
str(response)
.strip(empty_escape_sequences)
.replace("['", '["')
.replace("<s>", "")
.replace("</s>", "")
.replace("']", '"]')
.replace("', '", '", "')
.replace('["', "")
.replace('"]', "")
.split('", "')
.split("? ")
)
questions = [q + "?" for q in questions[:-1]] + [questions[-1]]
questions = filter_questions(questions)
except:
logger.warning(f"Falcon returned invalid JSON. Falling back to using user message as search query.\n{response}")
logger.warning(f"Llama returned invalid JSON. Falling back to using user message as search query.\n{response}")
return all_questions
logger.debug(f"Extracted Questions by Falcon: {questions}")
logger.debug(f"Extracted Questions by Llama: {questions}")
questions.extend(all_questions)
return questions
def converse_falcon(
def filter_questions(questions: List[str]):
    """
    Drop questions that look like the model apologizing or declining to answer.

    A question is discarded when its lowercased text contains any of the
    marker phrases below as a substring. The order of the remaining
    questions is preserved.
    """
    refusal_markers = (
        "sorry",
        "apologize",
        "unable",
        "can't",
        "cannot",
        "don't know",
        "don't understand",
        "do not know",
        "do not understand",
    )
    return [
        question
        for question in questions
        if not any(marker in question.lower() for marker in refusal_markers)
    ]
def converse_offline(
references,
user_query,
conversation_log={},
model: str = "ggml-model-gpt4all-falcon-q4_0.bin",
model: str = "llama-2-7b-chat.ggmlv3.q4_K_S.bin",
loaded_model: Union[GPT4All, None] = None,
completion_func=None,
) -> ThreadedGenerator:
"""
Converse with user using Falcon
Converse with user using Llama
"""
gpt4all_model = loaded_model or GPT4All(model)
# Initialize Variables
@@ -92,18 +125,18 @@ def converse_falcon(
# Get Conversation Primer appropriate to Conversation Type
# TODO If compiled_references_message is too long, we need to truncate it.
if compiled_references_message == "":
conversation_primer = prompts.conversation_falcon.format(query=user_query)
conversation_primer = prompts.conversation_llamav2.format(query=user_query)
else:
conversation_primer = prompts.notes_conversation.format(
current_date=current_date, query=user_query, references=compiled_references_message
conversation_primer = prompts.notes_conversation_llamav2.format(
query=user_query, references=compiled_references_message
)
# Setup Prompt with Primer or Conversation History
messages = generate_chatml_messages_with_context(
conversation_primer,
prompts.personality.format(),
prompts.system_prompt_message_llamav2,
conversation_log,
model_name="text-davinci-001", # This isn't actually the model, but this helps us get an approximate encoding to run message truncation.
model_name=model,
)
g = ThreadedGenerator(references, completion_func=completion_func)
@@ -113,24 +146,22 @@ def converse_falcon(
def llm_thread(g, messages: List[ChatMessage], model: GPT4All):
user_message = messages[0]
system_message = messages[-1]
user_message = messages[-1]
system_message = messages[0]
conversation_history = messages[1:-1]
formatted_messages = [
prompts.chat_history_falcon_from_assistant.format(message=system_message)
prompts.chat_history_llamav2_from_assistant.format(message=message.content)
if message.role == "assistant"
else prompts.chat_history_falcon_from_user.format(message=message.content)
else prompts.chat_history_llamav2_from_user.format(message=message.content)
for message in conversation_history
]
chat_history = "".join(formatted_messages)
full_message = system_message.content + chat_history + user_message.content
prompted_message = prompts.general_conversation_falcon.format(query=full_message)
response_iterator = model.generate(
prompted_message, streaming=True, max_tokens=256, top_k=1, temp=0, repeat_penalty=2.0
)
templated_system_message = prompts.system_prompt_llamav2.format(message=system_message.content)
templated_user_message = prompts.general_conversation_llamav2.format(query=user_message.content)
prompted_message = templated_system_message + chat_history + templated_user_message
response_iterator = model.generate(prompted_message, streaming=True, max_tokens=2000)
for response in response_iterator:
logger.info(response)
g.send(response)

View File

@@ -0,0 +1,3 @@
# Maps a chat model's file name to the URL it is downloaded from.
# The remote file must be the same quantization as the file name used as the key,
# otherwise the locally stored file would be labelled as a model it is not.
model_name_to_url = {
    "llama-2-7b-chat.ggmlv3.q4_K_S.bin": "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/resolve/main/llama-2-7b-chat.ggmlv3.q4_K_S.bin"
}

View File

@@ -0,0 +1,33 @@
import os
import logging
import requests
from gpt4all import GPT4All
import tqdm
from khoj.processor.conversation.gpt4all import model_metadata
logger = logging.getLogger(__name__)
def download_model(model_name):
    """
    Return a loaded GPT4All model, downloading its weights first when needed.

    If the model name has no entry in the model metadata, or the model file
    already exists in the GPT4All cache directory, the model is loaded
    directly. Otherwise the file is streamed from its configured URL into the
    cache directory and then loaded.

    Returns the GPT4All instance, or None when the download or the subsequent
    load fails (the error is logged).
    """
    url = model_metadata.model_name_to_url.get(model_name)
    if not url:
        logger.debug(f"Model {model_name} not found in model metadata. Skipping download.")
        return GPT4All(model_name)

    filename = os.path.expanduser(f"~/.cache/gpt4all/{model_name}")
    if os.path.exists(filename):
        return GPT4All(model_name)

    # Stream into a temporary sibling file and move it into place only once
    # complete. Otherwise an interrupted download would leave a truncated file
    # at the final path that the existence check above accepts on the next run.
    partial_filename = f"{filename}.tmp"
    try:
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        logger.debug(f"Downloading model {model_name} from {url} to {filename}...")
        with requests.get(url, stream=True) as r:
            r.raise_for_status()
            with open(partial_filename, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
        os.replace(partial_filename, filename)
        return GPT4All(model_name)
    except Exception as e:
        logger.error(f"Failed to download model {model_name} from {url} to {filename}. Error: {e}")
        # Best-effort removal of the incomplete download
        try:
            if os.path.exists(partial_filename):
                os.remove(partial_filename)
        except OSError:
            pass
        return None

View File

@@ -18,34 +18,54 @@ Question: {query}
""".strip()
)
general_conversation_falcon = PromptTemplate.from_template(
"""
Using your general knowledge and our past conversations as context, answer the following question.
### Instruct:
{query}
### Response:
""".strip()
)
system_prompt_message_llamav2 = f"""You are Khoj, a friendly, smart and helpful personal assistant.
Using your general knowledge and our past conversations as context, answer the following question."""
chat_history_falcon_from_user = PromptTemplate.from_template(
system_prompt_message_extract_questions_llamav2 = f"""You are Khoj, a kind and intelligent personal assistant.
- When the user asks you a question, you ask follow-up questions to clarify the necessary information you need in order to answer from the user's perspective.
- Try to be as specific as possible. For example, rather than use "they" or "it", use the name of the person or thing you are referring to.
- Write the question as if you can search for the answer on the user's personal notes.
- Add as much context from the previous questions and notes as required into your search queries.
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
- Provide search queries as a list of questions
What follow-up questions, if any, will you need to ask to answer the user's question?
"""
system_prompt_llamav2 = PromptTemplate.from_template(
"""
### Human:
<s>[INST] <<SYS>>
{message}
""".strip()
<</SYS>>Hi there! [/INST] Hello! How can I help you today? </s>"""
)
chat_history_falcon_from_assistant = PromptTemplate.from_template(
extract_questions_system_prompt_llamav2 = PromptTemplate.from_template(
"""
### Assistant:
<s>[INST] <<SYS>>
{message}
<</SYS>>[/INST]</s>"""
)
general_conversation_llamav2 = PromptTemplate.from_template(
"""
<s>[INST]{query}[/INST]
""".strip()
)
conversation_falcon = PromptTemplate.from_template(
chat_history_llamav2_from_user = PromptTemplate.from_template(
"""
Using our past conversations as context, answer the following question.
<s>[INST]{message}[/INST]
""".strip()
)
Question: {query}
chat_history_llamav2_from_assistant = PromptTemplate.from_template(
"""
{message}</s>
""".strip()
)
conversation_llamav2 = PromptTemplate.from_template(
"""
<s>[INST]{query}[/INST]
""".strip()
)
@@ -63,13 +83,10 @@ Question: {query}
""".strip()
)
notes_conversation_falcon = PromptTemplate.from_template(
notes_conversation_llamav2 = PromptTemplate.from_template(
"""
Using the notes and our past conversations as context, answer the following question. If the answer is not contained within the notes, say "I don't know."
Notes:
{references}
Question: {query}
""".strip()
)
@@ -109,37 +126,22 @@ Question: {user_query}
Answer (in second person):"""
)
extract_questions_falcon = PromptTemplate.from_template(
extract_questions_llamav2_sample = PromptTemplate.from_template(
"""
You are Khoj, an extremely smart and helpful search assistant with the ability to retrieve information from the user's notes.
- The user will provide their questions and answers to you for context.
- Add as much context from the previous questions and answers as required into your search queries.
- Break messages into multiple search queries when required to retrieve the relevant information.
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
What searches, if any, will you need to perform to answer the users question?
Q: How was my trip to Cambodia?
["How was my trip to Cambodia?"]
A: The trip was amazing. I went to the Angkor Wat temple and it was beautiful.
Q: Who did i visit that temple with?
["Who did I visit the Angkor Wat Temple in Cambodia with?"]
A: You visited the Angkor Wat Temple in Cambodia with Pablo, Namita and Xi.
Q: How many tennis balls fit in the back of a 2002 Honda Civic?
["What is the size of a tennis ball?", "What is the trunk size of a 2002 Honda Civic?"]
A: 1085 tennis balls will fit in the trunk of a Honda Civic
<s>[INST]<<SYS>>Current Date: {current_date}<</SYS>>[/INST]</s>
<s>[INST]<<SYS>>
Use these notes from the user's previous conversations to provide a response:
{chat_history}
Q: {text}
<</SYS>>[/INST]</s>
<s>[INST]How was my trip to Cambodia?[/INST][]</s>
<s>[INST]Who did I visit the temple with on that trip?[/INST]Who did I visit the temple with in Cambodia?</s>
<s>[INST]How should I take care of my plants?[/INST]What kind of plants do I have? What issues do my plants have?</s>
<s>[INST]How many tennis balls fit in the back of a 2002 Honda Civic?[/INST]What is the size of a tennis ball? What is the trunk size of a 2002 Honda Civic?</s>
<s>[INST]What did I do for Christmas last year?[/INST]What did I do for Christmas {last_year} dt>='{last_christmas_date}' dt<'{next_christmas_date}'</s>
<s>[INST]How are you feeling today?[/INST]</s>
<s>[INST]Is Alice older than Bob?[/INST]When was Alice born? What is Bob's age?</s>
<s>[INST]{query}[/INST]
"""
)

View File

@@ -13,7 +13,7 @@ import queue
from khoj.utils.helpers import merge_dicts
logger = logging.getLogger(__name__)
max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192, "text-davinci-001": 910}
max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192, "llama-2-7b-chat.ggmlv3.q4_K_S.bin": 850}
class ThreadedGenerator:
@@ -102,7 +102,10 @@ def generate_chatml_messages_with_context(
def truncate_messages(messages, max_prompt_size, model_name):
"""Truncate messages to fit within max prompt size supported by model"""
encoder = tiktoken.encoding_for_model(model_name)
try:
encoder = tiktoken.encoding_for_model(model_name)
except KeyError:
encoder = tiktoken.encoding_for_model("text-davinci-001")
tokens = sum([len(encoder.encode(message.content)) for message in messages])
while tokens > max_prompt_size and len(messages) > 1:
messages.pop()

View File

@@ -38,7 +38,7 @@ from khoj.utils.yaml import save_config_to_file_updated_state
from fastapi.responses import StreamingResponse, Response
from khoj.routers.helpers import perform_chat_checks, generate_chat_response, update_telemetry_state
from khoj.processor.conversation.openai.gpt import extract_questions
from khoj.processor.conversation.gpt4all.chat_model import extract_questions_falcon, converse_falcon
from khoj.processor.conversation.gpt4all.chat_model import extract_questions_offline
from fastapi.requests import Request
@@ -715,7 +715,9 @@ async def extract_references_and_questions(
inferred_queries = extract_questions(q, model=chat_model, api_key=api_key, conversation_log=meta_log)
else:
loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
inferred_queries = extract_questions_falcon(q, loaded_model=loaded_model, conversation_log=meta_log)
inferred_queries = extract_questions_offline(
q, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
)
# Collate search results as context for GPT
with timer("Searching knowledge base took", logger):

View File

@@ -8,7 +8,7 @@ from fastapi import HTTPException, Request
from khoj.utils import state
from khoj.utils.helpers import timer, log_telemetry
from khoj.processor.conversation.openai.gpt import converse
from khoj.processor.conversation.gpt4all.chat_model import converse_falcon
from khoj.processor.conversation.gpt4all.chat_model import converse_offline
from khoj.processor.conversation.utils import reciprocal_conversation_to_chatml, message_to_log, ThreadedGenerator
logger = logging.getLogger(__name__)
@@ -111,7 +111,7 @@ def generate_chat_response(
)
else:
loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
chat_response = converse_falcon(
chat_response = converse_offline(
references=compiled_references,
user_query=q,
loaded_model=loaded_model,

View File

@@ -5,8 +5,7 @@ from enum import Enum
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Union, Any
from gpt4all import GPT4All
from khoj.processor.conversation.gpt4all.utils import download_model
# External Packages
import torch
@@ -79,7 +78,7 @@ class SearchModels:
@dataclass
class GPT4AllProcessorConfig:
chat_model: Optional[str] = "ggml-model-gpt4all-falcon-q4_0.bin"
chat_model: Optional[str] = "llama-2-7b-chat.ggmlv3.q4_K_S.bin"
loaded_model: Union[Any, None] = None
@@ -96,7 +95,7 @@ class ConversationProcessorConfigModel:
self.meta_log: dict = {}
if not self.openai_model and self.enable_offline_chat:
self.gpt4all_model.loaded_model = GPT4All(self.gpt4all_model.chat_model) # type: ignore
self.gpt4all_model.loaded_model = download_model(self.gpt4all_model.chat_model)
else:
self.gpt4all_model.loaded_model = None