mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 21:29:11 +00:00
Add better error handling for download processes incase of failure
This commit is contained in:
@@ -9,7 +9,7 @@ from khoj.processor.conversation.gpt4all import model_metadata
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def download_model(model_name):
|
def download_model(model_name: str):
|
||||||
url = model_metadata.model_name_to_url.get(model_name)
|
url = model_metadata.model_name_to_url.get(model_name)
|
||||||
if not url:
|
if not url:
|
||||||
logger.debug(f"Model {model_name} not found in model metadata. Skipping download.")
|
logger.debug(f"Model {model_name} not found in model metadata. Skipping download.")
|
||||||
@@ -19,13 +19,15 @@ def download_model(model_name):
|
|||||||
if os.path.exists(filename):
|
if os.path.exists(filename):
|
||||||
return GPT4All(model_name)
|
return GPT4All(model_name)
|
||||||
|
|
||||||
|
tmp_filename = filename + ".tmp"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
os.makedirs(os.path.dirname(tmp_filename), exist_ok=True)
|
||||||
logger.debug(f"Downloading model {model_name} from {url} to {filename}...")
|
logger.debug(f"Downloading model {model_name} from {url} to {tmp_filename}...")
|
||||||
with requests.get(url, stream=True) as r:
|
with requests.get(url, stream=True) as r:
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
total_size = int(r.headers.get("content-length", 0))
|
total_size = int(r.headers.get("content-length", 0))
|
||||||
with open(filename, "wb") as f, tqdm(
|
with open(tmp_filename, "wb") as f, tqdm(
|
||||||
unit="B", # unit string to be displayed.
|
unit="B", # unit string to be displayed.
|
||||||
unit_scale=True, # let tqdm to determine the scale in kilo, mega..etc.
|
unit_scale=True, # let tqdm to determine the scale in kilo, mega..etc.
|
||||||
unit_divisor=1024, # is used when unit_scale is true
|
unit_divisor=1024, # is used when unit_scale is true
|
||||||
@@ -35,7 +37,14 @@ def download_model(model_name):
|
|||||||
for chunk in r.iter_content(chunk_size=8192):
|
for chunk in r.iter_content(chunk_size=8192):
|
||||||
f.write(chunk)
|
f.write(chunk)
|
||||||
progress_bar.update(len(chunk))
|
progress_bar.update(len(chunk))
|
||||||
|
|
||||||
|
logger.debug(f"Successfully downloaded model {model_name} from {url} to {tmp_filename}")
|
||||||
|
# Move the tmp file to the actual file
|
||||||
|
os.rename(tmp_filename, filename)
|
||||||
return GPT4All(model_name)
|
return GPT4All(model_name)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to download model {model_name} from {url} to {filename}. Error: {e}")
|
logger.error(f"Failed to download model {model_name} from {url} to {tmp_filename}. Error: {e}")
|
||||||
|
# Remove the tmp file if it exists
|
||||||
|
if os.path.exists(tmp_filename):
|
||||||
|
os.remove(tmp_filename)
|
||||||
return None
|
return None
|
||||||
|
|||||||
Reference in New Issue
Block a user