Add checksums to verify the correct model is downloaded as expected (#405)

* Add checksums to verify the correct model is downloaded as expected
- This should help debug issues related to corrupted model download
- If download fails, let the application continue
* If the model is not downloaded as expected, add some indicators in the settings UI
* Add exc_info to error log if/when download fails for llamav2 model
* Simplify checksum checking logic, update key name in model state for web client
This commit is contained in:
sabaimran
2023-08-03 06:26:52 +00:00
committed by GitHub
parent 6aa998e047
commit 0baed742e4
4 changed files with 41 additions and 8 deletions

View File

@@ -191,8 +191,8 @@
<h3 class="card-title"> <h3 class="card-title">
Chat Chat
{% if current_config.processor and current_config.processor.conversation.openai %} {% if current_config.processor and current_config.processor.conversation.openai %}
{% if current_model_state.conversation == False %} {% if current_model_state.conversation_openai == False %}
<img id="misconfigured-icon-conversation-processor" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you just need to click Configure."> <img id="misconfigured-icon-conversation-processor" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="The OpenAI configuration did not work as expected.">
{% else %} {% else %}
<img id="configured-icon-conversation-processor" class="configured-icon" src="/static/assets/icons/confirm-icon.svg" alt="Configured"> <img id="configured-icon-conversation-processor" class="configured-icon" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
{% endif %} {% endif %}
@@ -225,7 +225,10 @@
<img class="card-icon" src="/static/assets/icons/chat.svg" alt="Chat"> <img class="card-icon" src="/static/assets/icons/chat.svg" alt="Chat">
<h3 class="card-title"> <h3 class="card-title">
Offline Chat Offline Chat
<img id="configured-icon-conversation-enable-offline-chat" class="configured-icon {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat %}enabled{% else %}disabled{% endif %}" src="/static/assets/icons/confirm-icon.svg" alt="Configured"> <img id="configured-icon-conversation-enable-offline-chat" class="configured-icon {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat and current_model_state.conversation_gpt4all %}enabled{% else %}disabled{% endif %}" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
{% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat and not current_model_state.conversation_gpt4all %}
<img id="misconfigured-icon-conversation-enable-offline-chat" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="The model was not downloaded as expected.">
{% endif %}
</h3> </h3>
</div> </div>
<div class="card-description-row"> <div class="card-description-row">

View File

@@ -1,6 +1,8 @@
import os import os
import logging import logging
import requests import requests
import hashlib
from gpt4all import GPT4All from gpt4all import GPT4All
from tqdm import tqdm from tqdm import tqdm
@@ -8,6 +10,16 @@ from khoj.processor.conversation.gpt4all import model_metadata
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Expected MD5 hex digest of each model file, keyed by model name.
# Compared against the digest of the freshly downloaded file to detect a corrupted download.
expected_checksum = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "cfa87b15d92fb15a2d7c354b0098578b"}
def get_md5_checksum(filename: str):
    """Return the hexadecimal MD5 digest of the file at ``filename``.

    The file is opened in binary mode and fed to the hash in 8192-byte
    pieces until a read returns no more data.
    """
    digest = hashlib.md5()
    with open(filename, "rb") as model_file:
        while True:
            block = model_file.read(8192)
            if not block:
                # An empty read marks end of file
                break
            digest.update(block)
    return digest.hexdigest()
def download_model(model_name: str): def download_model(model_name: str):
url = model_metadata.model_name_to_url.get(model_name) url = model_metadata.model_name_to_url.get(model_name)
@@ -33,18 +45,26 @@ def download_model(model_name: str):
unit_scale=True, # let tqdm to determine the scale in kilo, mega..etc. unit_scale=True, # let tqdm to determine the scale in kilo, mega..etc.
unit_divisor=1024, # is used when unit_scale is true unit_divisor=1024, # is used when unit_scale is true
total=total_size, # the total iteration. total=total_size, # the total iteration.
desc=filename.split("/")[-1], # prefix to be displayed on progress bar. desc=model_name, # prefix to be displayed on progress bar.
) as progress_bar: ) as progress_bar:
for chunk in r.iter_content(chunk_size=8192): for chunk in r.iter_content(chunk_size=8192):
f.write(chunk) f.write(chunk)
progress_bar.update(len(chunk)) progress_bar.update(len(chunk))
# Verify the checksum
if expected_checksum.get(model_name) != get_md5_checksum(tmp_filename):
logger.error(
f"Checksum verification failed for {filename}. Removing the tmp file. Offline model will not be available."
)
os.remove(tmp_filename)
raise ValueError(f"Checksum verification failed for downloading {model_name} from {url}.")
# Move the tmp file to the actual file # Move the tmp file to the actual file
os.rename(tmp_filename, filename) os.rename(tmp_filename, filename)
logger.debug(f"Successfully downloaded model {model_name} from {url} to {filename}") logger.debug(f"Successfully downloaded model {model_name} from {url} to {filename}")
return GPT4All(model_name) return GPT4All(model_name)
except Exception as e: except Exception as e:
logger.error(f"Failed to download model {model_name} from {url} to {filename}. Error: {e}") logger.error(f"Failed to download model {model_name} from {url} to {filename}. Error: {e}", exc_info=True)
# Remove the tmp file if it exists # Remove the tmp file if it exists
if os.path.exists(tmp_filename): if os.path.exists(tmp_filename):
os.remove(tmp_filename) os.remove(tmp_filename)

View File

@@ -47,7 +47,7 @@ if not state.demo:
"image": False, "image": False,
"github": False, "github": False,
"notion": False, "notion": False,
"conversation": False, "enable_offline_model": False,
} }
if state.content_index: if state.content_index:
@@ -65,7 +65,8 @@ if not state.demo:
if state.processor_config: if state.processor_config:
successfully_configured.update( successfully_configured.update(
{ {
"conversation": state.processor_config.conversation is not None, "conversation_openai": state.processor_config.conversation.openai_model is not None,
"conversation_gpt4all": state.processor_config.conversation.gpt4all_model.loaded_model is not None,
} }
) )

View File

@@ -2,6 +2,8 @@
from __future__ import annotations # to avoid quoting type hints from __future__ import annotations # to avoid quoting type hints
from enum import Enum from enum import Enum
import logging
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Union, Any from typing import TYPE_CHECKING, Dict, List, Optional, Union, Any
@@ -10,6 +12,8 @@ from khoj.processor.conversation.gpt4all.utils import download_model
# External Packages # External Packages
import torch import torch
logger = logging.getLogger(__name__)
# Internal Packages # Internal Packages
if TYPE_CHECKING: if TYPE_CHECKING:
from sentence_transformers import CrossEncoder from sentence_transformers import CrossEncoder
@@ -95,7 +99,12 @@ class ConversationProcessorConfigModel:
self.meta_log: dict = {} self.meta_log: dict = {}
if self.enable_offline_chat: if self.enable_offline_chat:
try:
self.gpt4all_model.loaded_model = download_model(self.gpt4all_model.chat_model) self.gpt4all_model.loaded_model = download_model(self.gpt4all_model.chat_model)
except ValueError as e:
self.gpt4all_model.loaded_model = None
logger.error(f"Error while loading offline chat model: {e}", exc_info=True)
self.enable_offline_chat = False
else: else:
self.gpt4all_model.loaded_model = None self.gpt4all_model.loaded_model = None