mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-07 13:23:15 +00:00
Add checksums to verify the correct model is downloaded as expected (#405)
* Add checksums to verify the correct model is downloaded as expected - This should help debug issues related to corrupted model download - If download fails, let the application continue * If the model is not download as expected, add some indicators in the settings UI * Add exc_info to error log if/when download fails for llamav2 model * Simplify checksum checking logic, update key name in model state for web client
This commit is contained in:
@@ -191,8 +191,8 @@
|
|||||||
<h3 class="card-title">
|
<h3 class="card-title">
|
||||||
Chat
|
Chat
|
||||||
{% if current_config.processor and current_config.processor.conversation.openai %}
|
{% if current_config.processor and current_config.processor.conversation.openai %}
|
||||||
{% if current_model_state.conversation == False %}
|
{% if current_model_state.conversation_openai == False %}
|
||||||
<img id="misconfigured-icon-conversation-processor" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you just need to click Configure.">
|
<img id="misconfigured-icon-conversation-processor" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="The OpenAI configuration did not work as expected.">
|
||||||
{% else %}
|
{% else %}
|
||||||
<img id="configured-icon-conversation-processor" class="configured-icon" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
|
<img id="configured-icon-conversation-processor" class="configured-icon" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
|
||||||
{% endif %}
|
{% endif %}
|
||||||
@@ -225,7 +225,10 @@
|
|||||||
<img class="card-icon" src="/static/assets/icons/chat.svg" alt="Chat">
|
<img class="card-icon" src="/static/assets/icons/chat.svg" alt="Chat">
|
||||||
<h3 class="card-title">
|
<h3 class="card-title">
|
||||||
Offline Chat
|
Offline Chat
|
||||||
<img id="configured-icon-conversation-enable-offline-chat" class="configured-icon {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat %}enabled{% else %}disabled{% endif %}" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
|
<img id="configured-icon-conversation-enable-offline-chat" class="configured-icon {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat and current_model_state.conversation_gpt4all %}enabled{% else %}disabled{% endif %}" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
|
||||||
|
{% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat and not current_model_state.conversation_gpt4all %}
|
||||||
|
<img id="misconfigured-icon-conversation-enable-offline-chat" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="The model was not downloaded as expected.">
|
||||||
|
{% endif %}
|
||||||
</h3>
|
</h3>
|
||||||
</div>
|
</div>
|
||||||
<div class="card-description-row">
|
<div class="card-description-row">
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
import requests
|
import requests
|
||||||
|
import hashlib
|
||||||
|
|
||||||
from gpt4all import GPT4All
|
from gpt4all import GPT4All
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
@@ -8,6 +10,16 @@ from khoj.processor.conversation.gpt4all import model_metadata
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
expected_checksum = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "cfa87b15d92fb15a2d7c354b0098578b"}
|
||||||
|
|
||||||
|
|
||||||
|
def get_md5_checksum(filename: str):
|
||||||
|
hash_md5 = hashlib.md5()
|
||||||
|
with open(filename, "rb") as f:
|
||||||
|
for chunk in iter(lambda: f.read(8192), b""):
|
||||||
|
hash_md5.update(chunk)
|
||||||
|
return hash_md5.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
def download_model(model_name: str):
|
def download_model(model_name: str):
|
||||||
url = model_metadata.model_name_to_url.get(model_name)
|
url = model_metadata.model_name_to_url.get(model_name)
|
||||||
@@ -33,18 +45,26 @@ def download_model(model_name: str):
|
|||||||
unit_scale=True, # let tqdm to determine the scale in kilo, mega..etc.
|
unit_scale=True, # let tqdm to determine the scale in kilo, mega..etc.
|
||||||
unit_divisor=1024, # is used when unit_scale is true
|
unit_divisor=1024, # is used when unit_scale is true
|
||||||
total=total_size, # the total iteration.
|
total=total_size, # the total iteration.
|
||||||
desc=filename.split("/")[-1], # prefix to be displayed on progress bar.
|
desc=model_name, # prefix to be displayed on progress bar.
|
||||||
) as progress_bar:
|
) as progress_bar:
|
||||||
for chunk in r.iter_content(chunk_size=8192):
|
for chunk in r.iter_content(chunk_size=8192):
|
||||||
f.write(chunk)
|
f.write(chunk)
|
||||||
progress_bar.update(len(chunk))
|
progress_bar.update(len(chunk))
|
||||||
|
|
||||||
|
# Verify the checksum
|
||||||
|
if expected_checksum.get(model_name) != get_md5_checksum(tmp_filename):
|
||||||
|
logger.error(
|
||||||
|
f"Checksum verification failed for {filename}. Removing the tmp file. Offline model will not be available."
|
||||||
|
)
|
||||||
|
os.remove(tmp_filename)
|
||||||
|
raise ValueError(f"Checksum verification failed for downloading {model_name} from {url}.")
|
||||||
|
|
||||||
# Move the tmp file to the actual file
|
# Move the tmp file to the actual file
|
||||||
os.rename(tmp_filename, filename)
|
os.rename(tmp_filename, filename)
|
||||||
logger.debug(f"Successfully downloaded model {model_name} from {url} to {filename}")
|
logger.debug(f"Successfully downloaded model {model_name} from {url} to {filename}")
|
||||||
return GPT4All(model_name)
|
return GPT4All(model_name)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to download model {model_name} from {url} to {filename}. Error: {e}")
|
logger.error(f"Failed to download model {model_name} from {url} to {filename}. Error: {e}", exc_info=True)
|
||||||
# Remove the tmp file if it exists
|
# Remove the tmp file if it exists
|
||||||
if os.path.exists(tmp_filename):
|
if os.path.exists(tmp_filename):
|
||||||
os.remove(tmp_filename)
|
os.remove(tmp_filename)
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ if not state.demo:
|
|||||||
"image": False,
|
"image": False,
|
||||||
"github": False,
|
"github": False,
|
||||||
"notion": False,
|
"notion": False,
|
||||||
"conversation": False,
|
"enable_offline_model": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
if state.content_index:
|
if state.content_index:
|
||||||
@@ -65,7 +65,8 @@ if not state.demo:
|
|||||||
if state.processor_config:
|
if state.processor_config:
|
||||||
successfully_configured.update(
|
successfully_configured.update(
|
||||||
{
|
{
|
||||||
"conversation": state.processor_config.conversation is not None,
|
"conversation_openai": state.processor_config.conversation.openai_model is not None,
|
||||||
|
"conversation_gpt4all": state.processor_config.conversation.gpt4all_model.loaded_model is not None,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,8 @@
|
|||||||
from __future__ import annotations # to avoid quoting type hints
|
from __future__ import annotations # to avoid quoting type hints
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
import logging
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Union, Any
|
from typing import TYPE_CHECKING, Dict, List, Optional, Union, Any
|
||||||
@@ -10,6 +12,8 @@ from khoj.processor.conversation.gpt4all.utils import download_model
|
|||||||
# External Packages
|
# External Packages
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sentence_transformers import CrossEncoder
|
from sentence_transformers import CrossEncoder
|
||||||
@@ -95,7 +99,12 @@ class ConversationProcessorConfigModel:
|
|||||||
self.meta_log: dict = {}
|
self.meta_log: dict = {}
|
||||||
|
|
||||||
if self.enable_offline_chat:
|
if self.enable_offline_chat:
|
||||||
self.gpt4all_model.loaded_model = download_model(self.gpt4all_model.chat_model)
|
try:
|
||||||
|
self.gpt4all_model.loaded_model = download_model(self.gpt4all_model.chat_model)
|
||||||
|
except ValueError as e:
|
||||||
|
self.gpt4all_model.loaded_model = None
|
||||||
|
logger.error(f"Error while loading offline chat model: {e}", exc_info=True)
|
||||||
|
self.enable_offline_chat = False
|
||||||
else:
|
else:
|
||||||
self.gpt4all_model.loaded_model = None
|
self.gpt4all_model.loaded_model = None
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user