From 0baed742e4854b0a180d88e099651f026c4eee83 Mon Sep 17 00:00:00 2001 From: sabaimran <65192171+sabaimran@users.noreply.github.com> Date: Thu, 3 Aug 2023 06:26:52 +0000 Subject: [PATCH] Add checksums to verify the correct model is downloaded as expected (#405) * Add checksums to verify the correct model is downloaded as expected - This should help debug issues related to corrupted model download - If download fails, let the application continue * If the model is not downloaded as expected, add some indicators in the settings UI * Add exc_info to error log if/when download fails for llamav2 model * Simplify checksum checking logic, update key name in model state for web client --- src/khoj/interface/web/config.html | 9 ++++--- .../processor/conversation/gpt4all/utils.py | 24 +++++++++++++++++-- src/khoj/routers/web_client.py | 5 ++-- src/khoj/utils/config.py | 11 ++++++++- 4 files changed, 41 insertions(+), 8 deletions(-) diff --git a/src/khoj/interface/web/config.html b/src/khoj/interface/web/config.html index bd994232..9763f0da 100644 --- a/src/khoj/interface/web/config.html +++ b/src/khoj/interface/web/config.html @@ -191,8 +191,8 @@

Chat {% if current_config.processor and current_config.processor.conversation.openai %} - {% if current_model_state.conversation == False %} - Not Configured + {% if current_model_state.conversation_openai == False %} + Not Configured {% else %} Configured {% endif %} @@ -225,7 +225,10 @@ Chat

Offline Chat - Configured + Configured + {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat and not current_model_state.conversation_gpt4all %} + Not Configured + {% endif %}

diff --git a/src/khoj/processor/conversation/gpt4all/utils.py b/src/khoj/processor/conversation/gpt4all/utils.py index a712d87e..25539c02 100644 --- a/src/khoj/processor/conversation/gpt4all/utils.py +++ b/src/khoj/processor/conversation/gpt4all/utils.py @@ -1,6 +1,8 @@ import os import logging import requests +import hashlib + from gpt4all import GPT4All from tqdm import tqdm @@ -8,6 +10,16 @@ from khoj.processor.conversation.gpt4all import model_metadata logger = logging.getLogger(__name__) +expected_checksum = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "cfa87b15d92fb15a2d7c354b0098578b"} + + +def get_md5_checksum(filename: str): + hash_md5 = hashlib.md5() + with open(filename, "rb") as f: + for chunk in iter(lambda: f.read(8192), b""): + hash_md5.update(chunk) + return hash_md5.hexdigest() + def download_model(model_name: str): url = model_metadata.model_name_to_url.get(model_name) @@ -33,18 +45,26 @@ def download_model(model_name: str): unit_scale=True, # let tqdm to determine the scale in kilo, mega..etc. unit_divisor=1024, # is used when unit_scale is true total=total_size, # the total iteration. - desc=filename.split("/")[-1], # prefix to be displayed on progress bar. + desc=model_name, # prefix to be displayed on progress bar. ) as progress_bar: for chunk in r.iter_content(chunk_size=8192): f.write(chunk) progress_bar.update(len(chunk)) + # Verify the checksum + if expected_checksum.get(model_name) != get_md5_checksum(tmp_filename): + logger.error( + f"Checksum verification failed for {filename}. Removing the tmp file. Offline model will not be available." + ) + os.remove(tmp_filename) + raise ValueError(f"Checksum verification failed for downloading {model_name} from {url}.") + # Move the tmp file to the actual file os.rename(tmp_filename, filename) logger.debug(f"Successfully downloaded model {model_name} from {url} to {filename}") return GPT4All(model_name) except Exception as e: - logger.error(f"Failed to download model {model_name} from {url} to {filename}. Error: {e}") + logger.error(f"Failed to download model {model_name} from {url} to {filename}. Error: {e}", exc_info=True) # Remove the tmp file if it exists if os.path.exists(tmp_filename): os.remove(tmp_filename) diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py index 663b8675..9b199050 100644 --- a/src/khoj/routers/web_client.py +++ b/src/khoj/routers/web_client.py @@ -47,7 +47,7 @@ if not state.demo: "image": False, "github": False, "notion": False, - "conversation": False, + "enable_offline_model": False, } if state.content_index: @@ -65,7 +65,8 @@ if not state.demo: if state.processor_config: successfully_configured.update( { - "conversation": state.processor_config.conversation is not None, + "conversation_openai": state.processor_config.conversation.openai_model is not None, + "conversation_gpt4all": state.processor_config.conversation.gpt4all_model.loaded_model is not None, } ) diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py index 7882edcf..4e254bee 100644 --- a/src/khoj/utils/config.py +++ b/src/khoj/utils/config.py @@ -2,6 +2,8 @@ from __future__ import annotations # to avoid quoting type hints from enum import Enum +import logging + from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING, Dict, List, Optional, Union, Any @@ -10,6 +12,8 @@ from khoj.processor.conversation.gpt4all.utils import download_model # External Packages import torch +logger = logging.getLogger(__name__) + # Internal Packages if TYPE_CHECKING: from sentence_transformers import CrossEncoder @@ -95,7 +99,12 @@ class ConversationProcessorConfigModel: self.meta_log: dict = {} if self.enable_offline_chat: - self.gpt4all_model.loaded_model = download_model(self.gpt4all_model.chat_model) + try: + self.gpt4all_model.loaded_model = download_model(self.gpt4all_model.chat_model) + except ValueError as e: + self.gpt4all_model.loaded_model = None + logger.error(f"Error while loading offline chat model: {e}", exc_info=True) + self.enable_offline_chat = False else: self.gpt4all_model.loaded_model = None