khoj/src/khoj/processor/conversation/gpt4all/chat_model.py
sabaimran fee99779bf Add subqueries for internet-connected search results and update client-side code accordingly
- Add a wrapper method to help make direct queries to the LLM and determine any intermediate responses needed for handling the request
2023-11-20 15:19:15 -08:00


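"""Offline chat model for Khoj, built on GPT4All.

Provides search query extraction, streamed chat responses and a wrapper for
direct queries against a locally loaded Llama-family model.
"""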

import logging
from datetime import datetime
from threading import Thread
from typing import Any, Iterator, List, Union

from langchain.schema import ChatMessage

from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import ThreadedGenerator, generate_chatml_messages_with_context
from khoj.utils import state
from khoj.utils.constants import empty_escape_sequences
from khoj.utils.helpers import ConversationCommand, is_none_or_empty

logger = logging.getLogger(__name__)


def extract_questions_offline(
    text: str,
    model: str = "mistral-7b-instruct-v0.1.Q4_0.gguf",
    loaded_model: Union[Any, None] = None,
    conversation_log={},
    use_history: bool = True,
    should_extract_questions: bool = True,
) -> List[str]:
    """
    Infer search queries to retrieve relevant notes to answer the user's query
    """
    try:
        from gpt4all import GPT4All
    except ModuleNotFoundError as e:
        logger.error("There was an error importing GPT4All. Install it with `pip install gpt4all`.")
        raise e

    # Assert that loaded_model is either None or of type GPT4All
    assert loaded_model is None or isinstance(loaded_model, GPT4All), "loaded_model must be of type GPT4All or None"

    # Naively split the user message on "? " as a fallback set of search queries
    all_questions = text.split("? ")
    all_questions = [q + "?" for q in all_questions[:-1]] + [all_questions[-1]]
    if not should_extract_questions:
        return all_questions

    gpt4all_model = loaded_model or GPT4All(model)

    # Extract Past User Message and Inferred Questions from Conversation Log
    chat_history = ""
    if use_history:
        for chat in conversation_log.get("chat", [])[-4:]:
            if chat["by"] == "khoj":
                chat_history += f"Q: {chat['intent']['query']}\n"
                chat_history += f"A: {chat['message']}\n"

    current_date = datetime.now().strftime("%Y-%m-%d")
    last_year = datetime.now().year - 1
    last_christmas_date = f"{last_year}-12-25"
    next_christmas_date = f"{datetime.now().year}-12-25"
    system_prompt = prompts.system_prompt_extract_questions_gpt4all.format(
        message=prompts.system_prompt_message_extract_questions_gpt4all
    )
    example_questions = prompts.extract_questions_gpt4all_sample.format(
        query=text,
        chat_history=chat_history,
        current_date=current_date,
        last_year=last_year,
        last_christmas_date=last_christmas_date,
        next_christmas_date=next_christmas_date,
    )
    message = system_prompt + example_questions

    with state.chat_lock:
        response = gpt4all_model.generate(message, max_tokens=200, top_k=2, temp=0, n_batch=512)

    # Extract and Clean Questions from the Model Response
    try:
        # The response is expected to be a stringified list of questions
        questions = (
            str(response)
            .strip(empty_escape_sequences)
            .replace("['", '["')
            .replace("<s>", "")
            .replace("</s>", "")
            .replace("']", '"]')
            .replace("', '", '", "')
            .replace('["', "")
            .replace('"]', "")
            .split("? ")
        )
        questions = [q + "?" for q in questions[:-1]] + [questions[-1]]
        questions = filter_questions(questions)
    except Exception:
        logger.warning(f"Llama returned invalid JSON. Falling back to using user message as search query.\n{response}")
        return all_questions

    logger.debug(f"Extracted Questions by Llama: {questions}")
    questions.extend(all_questions)
    return questions
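

# Example usage (a sketch; with should_extract_questions=False no model is
# loaded or queried, so the call below is deterministic):
#
#   extract_questions_offline(
#       "Where did I travel last summer? What did I eat there?",
#       should_extract_questions=False,
#   )
#   # -> ["Where did I travel last summer?", "What did I eat there?"]
#
# With should_extract_questions=True (the default), this naive "? " split of
# the user message is appended after any subqueries inferred by the model.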


def filter_questions(questions: List[str]) -> List[str]:
    # Skip questions that seem to be apologizing for not being able to answer the question
    hint_words = [
        "sorry",
        "apologize",
        "unable",
        "can't",
        "cannot",
        "don't know",
        "don't understand",
        "do not know",
        "do not understand",
    ]
    filtered_questions = []
    for q in questions:
        if not any(word in q.lower() for word in hint_words) and not is_none_or_empty(q):
            filtered_questions.append(q)
    return filtered_questions
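

# Example (pure function, runnable as-is):
#
#   filter_questions(["What is my dog's name?", "Sorry, I cannot answer that."])
#   # -> ["What is my dog's name?"]
#
# The second entry is dropped because it contains the hint words "sorry" and
# "cannot"; empty strings are dropped by the is_none_or_empty check.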


def converse_offline(
    references,
    online_results,
    user_query,
    conversation_log={},
    model: str = "mistral-7b-instruct-v0.1.Q4_0.gguf",
    loaded_model: Union[Any, None] = None,
    completion_func=None,
    conversation_command=ConversationCommand.Default,
    max_prompt_size=None,
    tokenizer_name=None,
) -> Union[ThreadedGenerator, Iterator[str]]:
    """
    Converse with the user using Llama
    """
    try:
        from gpt4all import GPT4All
    except ModuleNotFoundError as e:
        logger.error("There was an error importing GPT4All. Install it with `pip install gpt4all`.")
        raise e

    assert loaded_model is None or isinstance(loaded_model, GPT4All), "loaded_model must be of type GPT4All or None"
    gpt4all_model = loaded_model or GPT4All(model)

    # Initialize Variables
    compiled_references_message = "\n\n".join({f"{item}" for item in references})

    # Get Conversation Primer appropriate to Conversation Type
    if conversation_command == ConversationCommand.Notes and is_none_or_empty(compiled_references_message):
        return iter([prompts.no_notes_found.format()])
    elif conversation_command == ConversationCommand.Online and is_none_or_empty(online_results):
        if completion_func:
            completion_func(chat_response=prompts.no_online_results_found.format())
        return iter([prompts.no_online_results_found.format()])
    elif conversation_command == ConversationCommand.Online:
        conversation_primer = prompts.online_search_conversation.format(
            query=user_query, online_results=str(online_results)
        )
    elif conversation_command == ConversationCommand.General or is_none_or_empty(compiled_references_message):
        conversation_primer = user_query
    else:
        conversation_primer = prompts.notes_conversation_gpt4all.format(
            query=user_query, references=compiled_references_message
        )

    # Setup Prompt with Primer or Conversation History
    messages = generate_chatml_messages_with_context(
        conversation_primer,
        prompts.system_prompt_message_gpt4all,
        conversation_log,
        model_name=model,
        max_prompt_size=max_prompt_size,
        tokenizer_name=tokenizer_name,
    )

    # Stream the response from a background thread via a ThreadedGenerator
    g = ThreadedGenerator(references, online_results, completion_func=completion_func)
    t = Thread(target=llm_thread, args=(g, messages, gpt4all_model))
    t.start()
    return g
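

# Example usage (a sketch; assumes state.chat_lock has been initialized and
# `gpt4all_model` is a GPT4All instance pre-loaded at application startup):
#
#   response_generator = converse_offline(
#       references=["I adopted Max, a golden retriever, in 2021."],
#       online_results={},
#       user_query="When did I adopt my dog?",
#       loaded_model=gpt4all_model,
#   )
#   for chunk in response_generator:
#       print(chunk, end="")  # partial responses stream in as they are generated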


def llm_thread(g, messages: List[ChatMessage], model: Any):
    user_message = messages[-1]
    system_message = messages[0]
    conversation_history = messages[1:-1]

    formatted_messages = [
        prompts.khoj_message_gpt4all.format(message=message.content)
        if message.role == "assistant"
        else prompts.user_message_gpt4all.format(message=message.content)
        for message in conversation_history
    ]

    stop_words = ["<s>"]
    chat_history = "".join(formatted_messages)
    templated_system_message = prompts.system_prompt_gpt4all.format(message=system_message.content)
    templated_user_message = prompts.user_message_gpt4all.format(message=user_message.content)
    prompted_message = templated_system_message + chat_history + templated_user_message

    try:
        with state.chat_lock:
            response_iterator = send_message_to_model_offline(prompted_message, loaded_model=model, streaming=True)
            for response in response_iterator:
                if any(stop_word in response.strip() for stop_word in stop_words):
                    logger.debug(f"Stop streaming response as hit stop word in {response}")
                    break
                g.send(response)
    finally:
        # Always close the generator, else consumers of the ThreadedGenerator hang
        g.close()


def send_message_to_model_offline(
    message, loaded_model=None, model="mistral-7b-instruct-v0.1.Q4_0.gguf", streaming=False
):
    try:
        from gpt4all import GPT4All
    except ModuleNotFoundError as e:
        logger.error("There was an error importing GPT4All. Install it with `pip install gpt4all`.")
        raise e

    assert loaded_model is None or isinstance(loaded_model, GPT4All), "loaded_model must be of type GPT4All or None"
    gpt4all_model = loaded_model or GPT4All(model)

    return gpt4all_model.generate(message, max_tokens=200, top_k=2, temp=0, n_batch=512, streaming=streaming)
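

# Minimal smoke test for the direct-query wrapper above (a sketch: on first run
# GPT4All downloads the default mistral model, a multi-GB file, if not cached).
if __name__ == "__main__":
    reply = send_message_to_model_offline("Reply with OK if you can read this.")
    print(reply)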