mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-04 21:29:12 +00:00
Improve Quality and Reliability of Offline Chat (#393)
# Incoming ## Major ### Fix Prompt Size Exceeded Issue - Fix issues related to prompt size, Closes #386. Use the correct tokenizer to calculate whether the input needs to be truncated or not. ### Improve Llama 2 Model Download - Use the correct download link for LlamaV2 -- should have been using the small model, but was using the medium - Add better downloading logic to retry download if it failed, Closes #379 ### Fix Segmentation Fault due to Race - Add a lock around generating chat responses from the offline model to avoid segmentation faults. Closes #367. - Add a loading symbol to the web chat UI when the model is thinking. Closes #392 ### Improve Chat Response Latency - Improve performance of offline chat by increasing batch size (via `n_batch`) to automatically engage more cores/GPU, using smaller model and fixing prompt vs response token generation numbers. Closes #363 ### Fix Fake Dialogue Continuation - Fix formatting of user query with offline chat, this was contributing to #398 - Stop Llama 2 from Creating Fake Dialogue Continuations. Closes #398 ## Minor - Improve default message for Chat window on web when it's not configured. Include hint to use offline chat. - Add null check in `perform_chat_checks` method - Add offline chat director unit tests ## Performance Analysis (Time to First Token) | | v0.10.0 | this branch | |-|-|-| | Query 1 | 52s | 28s | | Query 2 | 33s| 42s | | Query 3 | 67s| 38s|
This commit is contained in:
@@ -61,6 +61,7 @@
|
||||
// Add message by user to chat body
|
||||
renderMessage(query, "you");
|
||||
document.getElementById("chat-input").value = "";
|
||||
document.getElementById("chat-input").setAttribute("disabled", "disabled");
|
||||
|
||||
// Generate backend API URL to execute query
|
||||
let url = `/api/chat?q=${encodeURIComponent(query)}&n=${results_count}&client=web&stream=true`;
|
||||
@@ -76,7 +77,9 @@
|
||||
new_response.appendChild(new_response_text);
|
||||
|
||||
// Temporary status message to indicate that Khoj is thinking
|
||||
new_response_text.innerHTML = "🤔";
|
||||
let loadingSpinner = document.createElement("div");
|
||||
loadingSpinner.classList.add("spinner");
|
||||
new_response_text.appendChild(loadingSpinner);
|
||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||
|
||||
// Call specified Khoj API which returns a streamed response of type text/plain
|
||||
@@ -107,10 +110,10 @@
|
||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||
} else {
|
||||
// Display response from Khoj
|
||||
if (new_response_text.innerHTML === "🤔") {
|
||||
// Clear temporary status message
|
||||
new_response_text.innerHTML = "";
|
||||
if (new_response_text.getElementsByClassName("spinner").length > 0) {
|
||||
new_response_text.removeChild(loadingSpinner);
|
||||
}
|
||||
|
||||
new_response_text.innerHTML += chunk;
|
||||
readStream();
|
||||
}
|
||||
@@ -120,6 +123,7 @@
|
||||
});
|
||||
}
|
||||
readStream();
|
||||
document.getElementById("chat-input").removeAttribute("disabled");
|
||||
});
|
||||
}
|
||||
|
||||
@@ -136,7 +140,7 @@
|
||||
.then(data => {
|
||||
if (data.detail) {
|
||||
// If the server returns a 500 error with detail, render a setup hint.
|
||||
renderMessage("Hi 👋🏾, to get started <br/>1. Get your <a class='inline-chat-link' href='https://platform.openai.com/account/api-keys'>OpenAI API key</a><br/>2. Save it in the Khoj <a class='inline-chat-link' href='/config/processor/conversation/openai'>chat settings</a> <br/>3. Click Configure on the Khoj <a class='inline-chat-link' href='/config'>settings page</a>", "khoj");
|
||||
renderMessage("Hi 👋🏾, to get started you have two options:<ol><li><b>Use OpenAI</b>: <ol><li>Get your <a class='inline-chat-link' href='https://platform.openai.com/account/api-keys'>OpenAI API key</a></li><li>Save it in the Khoj <a class='inline-chat-link' href='/config/processor/conversation/openai'>chat settings</a></li><li>Click Configure on the Khoj <a class='inline-chat-link' href='/config'>settings page</a></li></ol></li><li><b>Enable offline chat</b>: <ol><li>Go to the Khoj <a class='inline-chat-link' href='/config'>settings page</a> and enable offline chat</li></ol></li></ol>", "khoj");
|
||||
|
||||
// Disable chat input field and update placeholder text
|
||||
document.getElementById("chat-input").setAttribute("disabled", "disabled");
|
||||
@@ -269,6 +273,21 @@
|
||||
margin-left: auto;
|
||||
white-space: pre-line;
|
||||
}
|
||||
/* Spinner symbol when the chat message is loading */
|
||||
.spinner {
|
||||
border: 4px solid #f3f3f3;
|
||||
border-top: 4px solid var(--primary-inverse);
|
||||
border-radius: 50%;
|
||||
width: 12px;
|
||||
height: 12px;
|
||||
animation: spin 2s linear infinite;
|
||||
margin: 0px 0px 0px 10px;
|
||||
display: inline-block;
|
||||
}
|
||||
@keyframes spin {
|
||||
0% { transform: rotate(0deg); }
|
||||
100% { transform: rotate(360deg); }
|
||||
}
|
||||
/* add left protrusion to khoj chat bubble */
|
||||
.chat-message-text.khoj:after {
|
||||
content: '';
|
||||
|
||||
28
src/khoj/migrations/migrate_offline_model.py
Normal file
28
src/khoj/migrations/migrate_offline_model.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import os
|
||||
import logging
|
||||
from packaging import version
|
||||
|
||||
from khoj.utils.yaml import load_config_from_file, save_config_to_file
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def migrate_offline_model(args):
|
||||
schema_version = "0.10.1"
|
||||
raw_config = load_config_from_file(args.config_file)
|
||||
previous_version = raw_config.get("version")
|
||||
|
||||
if previous_version is None or version.parse(previous_version) < version.parse("0.10.1"):
|
||||
logger.info(
|
||||
f"Migrating offline model used for version {previous_version} to latest version for {args.version_no}"
|
||||
)
|
||||
raw_config["version"] = schema_version
|
||||
|
||||
# If the user has downloaded the offline model, remove it from the cache.
|
||||
offline_model_path = os.path.expanduser("~/.cache/gpt4all/llama-2-7b-chat.ggmlv3.q4_K_S.bin")
|
||||
if os.path.exists(offline_model_path):
|
||||
os.remove(offline_model_path)
|
||||
|
||||
save_config_to_file(raw_config, args.config_file)
|
||||
|
||||
return args
|
||||
@@ -34,10 +34,9 @@ from khoj.utils.yaml import load_config_from_file, save_config_to_file
|
||||
|
||||
|
||||
def migrate_processor_conversation_schema(args):
|
||||
schema_version = "0.10.0"
|
||||
raw_config = load_config_from_file(args.config_file)
|
||||
|
||||
raw_config["version"] = args.version_no
|
||||
|
||||
if "processor" not in raw_config:
|
||||
return args
|
||||
if raw_config["processor"] is None:
|
||||
@@ -45,22 +44,24 @@ def migrate_processor_conversation_schema(args):
|
||||
if "conversation" not in raw_config["processor"]:
|
||||
return args
|
||||
|
||||
# Add enable_offline_chat to khoj config schema
|
||||
if "enable-offline-chat" not in raw_config["processor"]["conversation"]:
|
||||
raw_config["processor"]["conversation"]["enable-offline-chat"] = False
|
||||
save_config_to_file(raw_config, args.config_file)
|
||||
|
||||
current_openai_api_key = raw_config["processor"]["conversation"].get("openai-api-key", None)
|
||||
current_chat_model = raw_config["processor"]["conversation"].get("chat-model", None)
|
||||
if current_openai_api_key is None and current_chat_model is None:
|
||||
return args
|
||||
|
||||
conversation_logfile = raw_config["processor"]["conversation"].get("conversation-logfile", None)
|
||||
raw_config["version"] = schema_version
|
||||
|
||||
# Add enable_offline_chat to khoj config schema
|
||||
if "enable-offline-chat" not in raw_config["processor"]["conversation"]:
|
||||
raw_config["processor"]["conversation"]["enable-offline-chat"] = False
|
||||
|
||||
# Update conversation processor schema
|
||||
conversation_logfile = raw_config["processor"]["conversation"].get("conversation-logfile", None)
|
||||
raw_config["processor"]["conversation"] = {
|
||||
"openai": {"chat-model": current_chat_model, "api-key": current_openai_api_key},
|
||||
"conversation-logfile": conversation_logfile,
|
||||
"enable-offline-chat": False,
|
||||
}
|
||||
|
||||
save_config_to_file(raw_config, args.config_file)
|
||||
return args
|
||||
|
||||
@@ -2,11 +2,12 @@ from khoj.utils.yaml import load_config_from_file, save_config_to_file
|
||||
|
||||
|
||||
def migrate_config_to_version(args):
|
||||
schema_version = "0.9.0"
|
||||
raw_config = load_config_from_file(args.config_file)
|
||||
|
||||
# Add version to khoj config schema
|
||||
if "version" not in raw_config:
|
||||
raw_config["version"] = args.version_no
|
||||
raw_config["version"] = schema_version
|
||||
save_config_to_file(raw_config, args.config_file)
|
||||
|
||||
# regenerate khoj index on first start of this version
|
||||
|
||||
@@ -10,6 +10,7 @@ from gpt4all import GPT4All
|
||||
from khoj.processor.conversation.utils import ThreadedGenerator, generate_chatml_messages_with_context
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.utils.constants import empty_escape_sequences
|
||||
from khoj.utils import state
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -58,7 +59,11 @@ def extract_questions_offline(
|
||||
next_christmas_date=next_christmas_date,
|
||||
)
|
||||
message = system_prompt + example_questions
|
||||
response = gpt4all_model.generate(message, max_tokens=200, top_k=2, temp=0)
|
||||
state.chat_lock.acquire()
|
||||
try:
|
||||
response = gpt4all_model.generate(message, max_tokens=200, top_k=2, temp=0, n_batch=256)
|
||||
finally:
|
||||
state.chat_lock.release()
|
||||
|
||||
# Extract, Clean Message from GPT's Response
|
||||
try:
|
||||
@@ -119,13 +124,11 @@ def converse_offline(
|
||||
"""
|
||||
gpt4all_model = loaded_model or GPT4All(model)
|
||||
# Initialize Variables
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
compiled_references_message = "\n\n".join({f"{item}" for item in references})
|
||||
|
||||
# Get Conversation Primer appropriate to Conversation Type
|
||||
# TODO If compiled_references_message is too long, we need to truncate it.
|
||||
if compiled_references_message == "":
|
||||
conversation_primer = prompts.conversation_llamav2.format(query=user_query)
|
||||
conversation_primer = user_query
|
||||
else:
|
||||
conversation_primer = prompts.notes_conversation_llamav2.format(
|
||||
query=user_query, references=compiled_references_message
|
||||
@@ -157,11 +160,20 @@ def llm_thread(g, messages: List[ChatMessage], model: GPT4All):
|
||||
for message in conversation_history
|
||||
]
|
||||
|
||||
stop_words = ["<s>"]
|
||||
chat_history = "".join(formatted_messages)
|
||||
templated_system_message = prompts.system_prompt_llamav2.format(message=system_message.content)
|
||||
templated_user_message = prompts.general_conversation_llamav2.format(query=user_message.content)
|
||||
prompted_message = templated_system_message + chat_history + templated_user_message
|
||||
response_iterator = model.generate(prompted_message, streaming=True, max_tokens=2000)
|
||||
for response in response_iterator:
|
||||
g.send(response)
|
||||
|
||||
state.chat_lock.acquire()
|
||||
response_iterator = model.generate(prompted_message, streaming=True, max_tokens=500, n_batch=256)
|
||||
try:
|
||||
for response in response_iterator:
|
||||
if any(stop_word in response.strip() for stop_word in stop_words):
|
||||
logger.debug(f"Stop response as hit stop word in {response}")
|
||||
break
|
||||
g.send(response)
|
||||
finally:
|
||||
state.chat_lock.release()
|
||||
g.close()
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
model_name_to_url = {
|
||||
"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/resolve/main/llama-2-7b-chat.ggmlv3.q3_K_M.bin"
|
||||
"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/resolve/main/llama-2-7b-chat.ggmlv3.q4_K_S.bin"
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ from khoj.processor.conversation.gpt4all import model_metadata
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def download_model(model_name):
|
||||
def download_model(model_name: str):
|
||||
url = model_metadata.model_name_to_url.get(model_name)
|
||||
if not url:
|
||||
logger.debug(f"Model {model_name} not found in model metadata. Skipping download.")
|
||||
@@ -19,13 +19,16 @@ def download_model(model_name):
|
||||
if os.path.exists(filename):
|
||||
return GPT4All(model_name)
|
||||
|
||||
# Download the model to a tmp file. Once the download is completed, move the tmp file to the actual file
|
||||
tmp_filename = filename + ".tmp"
|
||||
|
||||
try:
|
||||
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
||||
os.makedirs(os.path.dirname(tmp_filename), exist_ok=True)
|
||||
logger.debug(f"Downloading model {model_name} from {url} to {filename}...")
|
||||
with requests.get(url, stream=True) as r:
|
||||
r.raise_for_status()
|
||||
total_size = int(r.headers.get("content-length", 0))
|
||||
with open(filename, "wb") as f, tqdm(
|
||||
with open(tmp_filename, "wb") as f, tqdm(
|
||||
unit="B", # unit string to be displayed.
|
||||
unit_scale=True, # let tqdm to determine the scale in kilo, mega..etc.
|
||||
unit_divisor=1024, # is used when unit_scale is true
|
||||
@@ -35,7 +38,14 @@ def download_model(model_name):
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
progress_bar.update(len(chunk))
|
||||
|
||||
# Move the tmp file to the actual file
|
||||
os.rename(tmp_filename, filename)
|
||||
logger.debug(f"Successfully downloaded model {model_name} from {url} to {filename}")
|
||||
return GPT4All(model_name)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download model {model_name} from {url} to {filename}. Error: {e}")
|
||||
# Remove the tmp file if it exists
|
||||
if os.path.exists(tmp_filename):
|
||||
os.remove(tmp_filename)
|
||||
return None
|
||||
|
||||
@@ -19,14 +19,13 @@ Question: {query}
|
||||
)
|
||||
|
||||
system_prompt_message_llamav2 = f"""You are Khoj, a friendly, smart and helpful personal assistant.
|
||||
Using your general knowledge and our past conversations as context, answer the following question."""
|
||||
Using your general knowledge and our past conversations as context, answer the following question.
|
||||
If you do not know the answer, say 'I don't know.'"""
|
||||
|
||||
system_prompt_message_extract_questions_llamav2 = f"""You are Khoj, a kind and intelligent personal assistant.
|
||||
- When the user asks you a question, you ask follow-up questions to clarify the necessary information you need in order to answer from the user's perspective.
|
||||
- Try to be as specific as possible. For example, rather than use "they" or "it", use the name of the person or thing you are referring to.
|
||||
system_prompt_message_extract_questions_llamav2 = f"""You are Khoj, a kind and intelligent personal assistant. When the user asks you a question, you ask follow-up questions to clarify the necessary information you need in order to answer from the user's perspective.
|
||||
- Write the question as if you can search for the answer on the user's personal notes.
|
||||
- Try to be as specific as possible. Instead of saying "they" or "it" or "he", use the name of the person or thing you are referring to. For example, instead of saying "Which store did they go to?", say "Which store did Alice and Bob go to?".
|
||||
- Add as much context from the previous questions and notes as required into your search queries.
|
||||
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
|
||||
- Provide search queries as a list of questions
|
||||
What follow-up questions, if any, will you need to ask to answer the user's question?
|
||||
"""
|
||||
@@ -129,11 +128,6 @@ Answer (in second person):"""
|
||||
extract_questions_llamav2_sample = PromptTemplate.from_template(
|
||||
"""
|
||||
<s>[INST]<<SYS>>Current Date: {current_date}<</SYS>>[/INST]</s>
|
||||
<s>[INST]<<SYS>>
|
||||
Use these notes from the user's previous conversations to provide a response:
|
||||
{chat_history}
|
||||
<</SYS>>[/INST]</s>
|
||||
|
||||
<s>[INST]How was my trip to Cambodia?[/INST][]</s>
|
||||
<s>[INST]Who did I visit the temple with on that trip?[/INST]Who did I visit the temple with in Cambodia?</s>
|
||||
<s>[INST]How should I take care of my plants?[/INST]What kind of plants do I have? What issues do my plants have?</s>
|
||||
@@ -141,6 +135,10 @@ Use these notes from the user's previous conversations to provide a response:
|
||||
<s>[INST]What did I do for Christmas last year?[/INST]What did I do for Christmas {last_year} dt>='{last_christmas_date}' dt<'{next_christmas_date}'</s>
|
||||
<s>[INST]How are you feeling today?[/INST]</s>
|
||||
<s>[INST]Is Alice older than Bob?[/INST]When was Alice born? What is Bob's age?</s>
|
||||
<s>[INST]<<SYS>>
|
||||
Use these notes from the user's previous conversations to provide a response:
|
||||
{chat_history}
|
||||
<</SYS>>[/INST]</s>
|
||||
<s>[INST]{query}[/INST]
|
||||
"""
|
||||
)
|
||||
|
||||
@@ -7,13 +7,15 @@ import tiktoken
|
||||
|
||||
# External packages
|
||||
from langchain.schema import ChatMessage
|
||||
from transformers import LlamaTokenizerFast
|
||||
|
||||
# Internal Packages
|
||||
import queue
|
||||
from khoj.utils.helpers import merge_dicts
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192, "llama-2-7b-chat.ggmlv3.q4_K_S.bin": 850}
|
||||
max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192, "llama-2-7b-chat.ggmlv3.q4_K_S.bin": 1548}
|
||||
tokenizer = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "hf-internal-testing/llama-tokenizer"}
|
||||
|
||||
|
||||
class ThreadedGenerator:
|
||||
@@ -40,6 +42,10 @@ class ThreadedGenerator:
|
||||
return item
|
||||
|
||||
def send(self, data):
|
||||
if self.response == "":
|
||||
time_to_first_response = perf_counter() - self.start_time
|
||||
logger.debug(f"First response took: {time_to_first_response:.3f} seconds")
|
||||
|
||||
self.response += data
|
||||
self.queue.put(data)
|
||||
|
||||
@@ -100,30 +106,35 @@ def generate_chatml_messages_with_context(
|
||||
return messages[::-1]
|
||||
|
||||
|
||||
def truncate_messages(messages, max_prompt_size, model_name):
|
||||
def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name) -> list[ChatMessage]:
|
||||
"""Truncate messages to fit within max prompt size supported by model"""
|
||||
try:
|
||||
|
||||
if "llama" in model_name:
|
||||
encoder = LlamaTokenizerFast.from_pretrained(tokenizer[model_name])
|
||||
else:
|
||||
encoder = tiktoken.encoding_for_model(model_name)
|
||||
except KeyError:
|
||||
encoder = tiktoken.encoding_for_model("text-davinci-001")
|
||||
|
||||
system_message = messages.pop()
|
||||
system_message_tokens = len(encoder.encode(system_message.content))
|
||||
|
||||
tokens = sum([len(encoder.encode(message.content)) for message in messages])
|
||||
while tokens > max_prompt_size and len(messages) > 1:
|
||||
while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1:
|
||||
messages.pop()
|
||||
tokens = sum([len(encoder.encode(message.content)) for message in messages])
|
||||
|
||||
# Truncate last message if still over max supported prompt size by model
|
||||
if tokens > max_prompt_size:
|
||||
last_message = "\n".join(messages[-1].content.split("\n")[:-1])
|
||||
original_question = "\n".join(messages[-1].content.split("\n")[-1:])
|
||||
# Truncate current message if still over max supported prompt size by model
|
||||
if (tokens + system_message_tokens) > max_prompt_size:
|
||||
current_message = "\n".join(messages[0].content.split("\n")[:-1])
|
||||
original_question = "\n".join(messages[0].content.split("\n")[-1:])
|
||||
original_question_tokens = len(encoder.encode(original_question))
|
||||
remaining_tokens = max_prompt_size - original_question_tokens
|
||||
truncated_message = encoder.decode(encoder.encode(last_message)[:remaining_tokens]).strip()
|
||||
remaining_tokens = max_prompt_size - original_question_tokens - system_message_tokens
|
||||
truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip()
|
||||
logger.debug(
|
||||
f"Truncate last message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
|
||||
f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
|
||||
)
|
||||
messages = [ChatMessage(content=truncated_message + original_question, role=messages[-1].role)]
|
||||
messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)]
|
||||
|
||||
return messages
|
||||
return messages + [system_message]
|
||||
|
||||
|
||||
def reciprocal_conversation_to_chatml(message_pair):
|
||||
|
||||
@@ -15,9 +15,13 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def perform_chat_checks():
|
||||
if state.processor_config.conversation and (
|
||||
state.processor_config.conversation.openai_model
|
||||
or state.processor_config.conversation.gpt4all_model.loaded_model
|
||||
if (
|
||||
state.processor_config
|
||||
and state.processor_config.conversation
|
||||
and (
|
||||
state.processor_config.conversation.openai_model
|
||||
or state.processor_config.conversation.gpt4all_model.loaded_model
|
||||
)
|
||||
):
|
||||
return
|
||||
|
||||
@@ -89,37 +93,36 @@ def generate_chat_response(
|
||||
chat_response = None
|
||||
|
||||
try:
|
||||
with timer("Generating chat response took", logger):
|
||||
partial_completion = partial(
|
||||
_save_to_conversation_log,
|
||||
q,
|
||||
user_message_time=user_message_time,
|
||||
compiled_references=compiled_references,
|
||||
inferred_queries=inferred_queries,
|
||||
meta_log=meta_log,
|
||||
partial_completion = partial(
|
||||
_save_to_conversation_log,
|
||||
q,
|
||||
user_message_time=user_message_time,
|
||||
compiled_references=compiled_references,
|
||||
inferred_queries=inferred_queries,
|
||||
meta_log=meta_log,
|
||||
)
|
||||
|
||||
if state.processor_config.conversation.enable_offline_chat:
|
||||
loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
|
||||
chat_response = converse_offline(
|
||||
references=compiled_references,
|
||||
user_query=q,
|
||||
loaded_model=loaded_model,
|
||||
conversation_log=meta_log,
|
||||
completion_func=partial_completion,
|
||||
)
|
||||
|
||||
if state.processor_config.conversation.enable_offline_chat:
|
||||
loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
|
||||
chat_response = converse_offline(
|
||||
references=compiled_references,
|
||||
user_query=q,
|
||||
loaded_model=loaded_model,
|
||||
conversation_log=meta_log,
|
||||
completion_func=partial_completion,
|
||||
)
|
||||
|
||||
elif state.processor_config.conversation.openai_model:
|
||||
api_key = state.processor_config.conversation.openai_model.api_key
|
||||
chat_model = state.processor_config.conversation.openai_model.chat_model
|
||||
chat_response = converse(
|
||||
compiled_references,
|
||||
q,
|
||||
meta_log,
|
||||
model=chat_model,
|
||||
api_key=api_key,
|
||||
completion_func=partial_completion,
|
||||
)
|
||||
elif state.processor_config.conversation.openai_model:
|
||||
api_key = state.processor_config.conversation.openai_model.api_key
|
||||
chat_model = state.processor_config.conversation.openai_model.chat_model
|
||||
chat_response = converse(
|
||||
compiled_references,
|
||||
q,
|
||||
meta_log,
|
||||
model=chat_model,
|
||||
api_key=api_key,
|
||||
completion_func=partial_completion,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(e, exc_info=True)
|
||||
|
||||
@@ -8,6 +8,7 @@ from khoj.utils.helpers import resolve_absolute_path
|
||||
from khoj.utils.yaml import parse_config_from_file
|
||||
from khoj.migrations.migrate_version import migrate_config_to_version
|
||||
from khoj.migrations.migrate_processor_config_openai import migrate_processor_conversation_schema
|
||||
from khoj.migrations.migrate_offline_model import migrate_offline_model
|
||||
|
||||
|
||||
def cli(args=None):
|
||||
@@ -55,7 +56,7 @@ def cli(args=None):
|
||||
|
||||
|
||||
def run_migrations(args):
|
||||
migrations = [migrate_config_to_version, migrate_processor_conversation_schema]
|
||||
migrations = [migrate_config_to_version, migrate_processor_conversation_schema, migrate_offline_model]
|
||||
for migration in migrations:
|
||||
args = migration(args)
|
||||
return args
|
||||
|
||||
@@ -25,6 +25,7 @@ port: int = None
|
||||
cli_args: List[str] = None
|
||||
query_cache = LRU()
|
||||
config_lock = threading.Lock()
|
||||
chat_lock = threading.Lock()
|
||||
SearchType = utils_config.SearchType
|
||||
telemetry: List[Dict[str, str]] = []
|
||||
previous_query: str = None
|
||||
|
||||
Reference in New Issue
Block a user