From aa6846395d81bdd05feb8d17db6ce345652c06cf Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Tue, 1 Aug 2023 17:08:37 -0700 Subject: [PATCH] Fix offline model migration script to run for version < 0.10.1 - Use same batch_size in extract question actor as the chat actor - Log final location the chat model is to be stored in, instead of its temp filename while it is being downloaded --- src/khoj/migrations/migrate_offline_model.py | 7 ++++--- src/khoj/processor/conversation/gpt4all/chat_model.py | 2 +- src/khoj/processor/conversation/gpt4all/utils.py | 6 +++--- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/khoj/migrations/migrate_offline_model.py b/src/khoj/migrations/migrate_offline_model.py index 15847cbb..557c2407 100644 --- a/src/khoj/migrations/migrate_offline_model.py +++ b/src/khoj/migrations/migrate_offline_model.py @@ -1,5 +1,6 @@ import os import logging +from packaging import version from khoj.utils.yaml import load_config_from_file, save_config_to_file @@ -8,10 +9,10 @@ logger = logging.getLogger(__name__) def migrate_offline_model(args): raw_config = load_config_from_file(args.config_file) - version = raw_config.get("version") + version_no = raw_config.get("version") - if version == "0.10.0" or version == None: - logger.info(f"Migrating offline model used for version {version} to latest version for {args.version_no}") + if version_no is None or version.parse(version_no) < version.parse("0.10.1"): + logger.info(f"Migrating offline model used for version {version_no} to latest version for {args.version_no}") # If the user has downloaded the offline model, remove it from the cache. 
offline_model_path = os.path.expanduser("~/.cache/gpt4all/llama-2-7b-chat.ggmlv3.q4_K_S.bin") diff --git a/src/khoj/processor/conversation/gpt4all/chat_model.py b/src/khoj/processor/conversation/gpt4all/chat_model.py index e153e0eb..fa07e59f 100644 --- a/src/khoj/processor/conversation/gpt4all/chat_model.py +++ b/src/khoj/processor/conversation/gpt4all/chat_model.py @@ -61,7 +61,7 @@ def extract_questions_offline( message = system_prompt + example_questions state.chat_lock.acquire() try: - response = gpt4all_model.generate(message, max_tokens=200, top_k=2, temp=0, n_batch=128) + response = gpt4all_model.generate(message, max_tokens=200, top_k=2, temp=0, n_batch=256) finally: state.chat_lock.release() diff --git a/src/khoj/processor/conversation/gpt4all/utils.py b/src/khoj/processor/conversation/gpt4all/utils.py index b7f953e4..a712d87e 100644 --- a/src/khoj/processor/conversation/gpt4all/utils.py +++ b/src/khoj/processor/conversation/gpt4all/utils.py @@ -24,7 +24,7 @@ def download_model(model_name: str): try: os.makedirs(os.path.dirname(tmp_filename), exist_ok=True) - logger.debug(f"Downloading model {model_name} from {url} to {tmp_filename}...") + logger.debug(f"Downloading model {model_name} from {url} to {filename}...") with requests.get(url, stream=True) as r: r.raise_for_status() total_size = int(r.headers.get("content-length", 0)) @@ -39,12 +39,12 @@ def download_model(model_name: str): f.write(chunk) progress_bar.update(len(chunk)) - logger.debug(f"Successfully downloaded model {model_name} from {url} to {tmp_filename}") # Move the tmp file to the actual file os.rename(tmp_filename, filename) + logger.debug(f"Successfully downloaded model {model_name} from {url} to {filename}") return GPT4All(model_name) except Exception as e: - logger.error(f"Failed to download model {model_name} from {url} to {tmp_filename}. Error: {e}") + logger.error(f"Failed to download model {model_name} from {url} to {filename}. Error: {e}") # Remove the tmp file if it exists if os.path.exists(tmp_filename): os.remove(tmp_filename)