From 0baed742e4854b0a180d88e099651f026c4eee83 Mon Sep 17 00:00:00 2001
From: sabaimran <65192171+sabaimran@users.noreply.github.com>
Date: Thu, 3 Aug 2023 06:26:52 +0000
Subject: [PATCH] Add checksums to verify the correct model is downloaded as
expected (#405)
* Add checksums to verify the correct model is downloaded as expected
- This should help debug issues related to corrupted model download
- If download fails, let the application continue
* If the model is not downloaded as expected, add some indicators in the settings UI
* Add exc_info to error log if/when download fails for llamav2 model
* Simplify checksum checking logic, update key name in model state for web client
---
src/khoj/interface/web/config.html | 9 ++++---
.../processor/conversation/gpt4all/utils.py | 24 +++++++++++++++++--
src/khoj/routers/web_client.py | 5 ++--
src/khoj/utils/config.py | 11 ++++++++-
4 files changed, 41 insertions(+), 8 deletions(-)
diff --git a/src/khoj/interface/web/config.html b/src/khoj/interface/web/config.html
index bd994232..9763f0da 100644
--- a/src/khoj/interface/web/config.html
+++ b/src/khoj/interface/web/config.html
@@ -191,8 +191,8 @@
Chat
{% if current_config.processor and current_config.processor.conversation.openai %}
- {% if current_model_state.conversation == False %}
-
+ {% if current_model_state.conversation_openai == False %}
+
{% else %}
{% endif %}
@@ -225,7 +225,10 @@
Offline Chat
-
+
+ {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat and not current_model_state.conversation_gpt4all %}
+
+ {% endif %}
diff --git a/src/khoj/processor/conversation/gpt4all/utils.py b/src/khoj/processor/conversation/gpt4all/utils.py
index a712d87e..25539c02 100644
--- a/src/khoj/processor/conversation/gpt4all/utils.py
+++ b/src/khoj/processor/conversation/gpt4all/utils.py
@@ -1,6 +1,8 @@
import os
import logging
import requests
+import hashlib
+
from gpt4all import GPT4All
from tqdm import tqdm
@@ -8,6 +10,16 @@ from khoj.processor.conversation.gpt4all import model_metadata
logger = logging.getLogger(__name__)
+expected_checksum = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "cfa87b15d92fb15a2d7c354b0098578b"}
+
+
+def get_md5_checksum(filename: str):
+ hash_md5 = hashlib.md5()
+ with open(filename, "rb") as f:
+ for chunk in iter(lambda: f.read(8192), b""):
+ hash_md5.update(chunk)
+ return hash_md5.hexdigest()
+
def download_model(model_name: str):
url = model_metadata.model_name_to_url.get(model_name)
@@ -33,18 +45,26 @@ def download_model(model_name: str):
unit_scale=True, # let tqdm to determine the scale in kilo, mega..etc.
unit_divisor=1024, # is used when unit_scale is true
total=total_size, # the total iteration.
- desc=filename.split("/")[-1], # prefix to be displayed on progress bar.
+ desc=model_name, # prefix to be displayed on progress bar.
) as progress_bar:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)
progress_bar.update(len(chunk))
+ # Verify the checksum
+ if expected_checksum.get(model_name) != get_md5_checksum(tmp_filename):
+ logger.error(
f"Checksum verification failed for {filename}. Removing the tmp file. Offline model will not be available."
+ )
+ os.remove(tmp_filename)
+ raise ValueError(f"Checksum verification failed for downloading {model_name} from {url}.")
+
# Move the tmp file to the actual file
os.rename(tmp_filename, filename)
logger.debug(f"Successfully downloaded model {model_name} from {url} to {filename}")
return GPT4All(model_name)
except Exception as e:
- logger.error(f"Failed to download model {model_name} from {url} to {filename}. Error: {e}")
+ logger.error(f"Failed to download model {model_name} from {url} to {filename}. Error: {e}", exc_info=True)
# Remove the tmp file if it exists
if os.path.exists(tmp_filename):
os.remove(tmp_filename)
diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py
index 663b8675..9b199050 100644
--- a/src/khoj/routers/web_client.py
+++ b/src/khoj/routers/web_client.py
@@ -47,7 +47,7 @@ if not state.demo:
"image": False,
"github": False,
"notion": False,
- "conversation": False,
+ "enable_offline_model": False,
}
if state.content_index:
@@ -65,7 +65,8 @@ if not state.demo:
if state.processor_config:
successfully_configured.update(
{
- "conversation": state.processor_config.conversation is not None,
+ "conversation_openai": state.processor_config.conversation.openai_model is not None,
+ "conversation_gpt4all": state.processor_config.conversation.gpt4all_model.loaded_model is not None,
}
)
diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py
index 7882edcf..4e254bee 100644
--- a/src/khoj/utils/config.py
+++ b/src/khoj/utils/config.py
@@ -2,6 +2,8 @@
from __future__ import annotations # to avoid quoting type hints
from enum import Enum
+import logging
+
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Union, Any
@@ -10,6 +12,8 @@ from khoj.processor.conversation.gpt4all.utils import download_model
# External Packages
import torch
+logger = logging.getLogger(__name__)
+
# Internal Packages
if TYPE_CHECKING:
from sentence_transformers import CrossEncoder
@@ -95,7 +99,12 @@ class ConversationProcessorConfigModel:
self.meta_log: dict = {}
if self.enable_offline_chat:
- self.gpt4all_model.loaded_model = download_model(self.gpt4all_model.chat_model)
+ try:
+ self.gpt4all_model.loaded_model = download_model(self.gpt4all_model.chat_model)
+ except ValueError as e:
+ self.gpt4all_model.loaded_model = None
+ logger.error(f"Error while loading offline chat model: {e}", exc_info=True)
+ self.enable_offline_chat = False
else:
self.gpt4all_model.loaded_model = None