From 2335f11b006c26404eaf54e57f8431873ce141db Mon Sep 17 00:00:00 2001 From: sabaimran Date: Mon, 31 Jul 2023 21:07:38 -0700 Subject: [PATCH] Add better error handling for download processes in case of failure --- .../processor/conversation/gpt4all/utils.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/src/khoj/processor/conversation/gpt4all/utils.py b/src/khoj/processor/conversation/gpt4all/utils.py index a6a5901a..86bfc0f5 100644 --- a/src/khoj/processor/conversation/gpt4all/utils.py +++ b/src/khoj/processor/conversation/gpt4all/utils.py @@ -9,7 +9,7 @@ from khoj.processor.conversation.gpt4all import model_metadata logger = logging.getLogger(__name__) -def download_model(model_name): +def download_model(model_name: str): url = model_metadata.model_name_to_url.get(model_name) if not url: logger.debug(f"Model {model_name} not found in model metadata. Skipping download.") @@ -19,13 +19,15 @@ def download_model(model_name): if os.path.exists(filename): return GPT4All(model_name) + tmp_filename = filename + ".tmp" + try: - os.makedirs(os.path.dirname(filename), exist_ok=True) - logger.debug(f"Downloading model {model_name} from {url} to {filename}...") + os.makedirs(os.path.dirname(tmp_filename), exist_ok=True) + logger.debug(f"Downloading model {model_name} from {url} to {tmp_filename}...") with requests.get(url, stream=True) as r: r.raise_for_status() total_size = int(r.headers.get("content-length", 0)) - with open(filename, "wb") as f, tqdm( + with open(tmp_filename, "wb") as f, tqdm( unit="B", # unit string to be displayed. unit_scale=True, # let tqdm to determine the scale in kilo, mega..etc. 
unit_divisor=1024, # is used when unit_scale is true @@ -35,7 +37,14 @@ def download_model(model_name): for chunk in r.iter_content(chunk_size=8192): f.write(chunk) progress_bar.update(len(chunk)) + + logger.debug(f"Successfully downloaded model {model_name} from {url} to {tmp_filename}") + # Move the tmp file to the actual file + os.rename(tmp_filename, filename) return GPT4All(model_name) except Exception as e: - logger.error(f"Failed to download model {model_name} from {url} to {filename}. Error: {e}") + logger.error(f"Failed to download model {model_name} from {url} to {tmp_filename}. Error: {e}") + # Remove the tmp file if it exists + if os.path.exists(tmp_filename): + os.remove(tmp_filename) return None