From 99f902d86c1b8ddbd1be9a7f090c23175a0b20f0 Mon Sep 17 00:00:00 2001 From: mrfakename Date: Mon, 26 Feb 2024 15:53:10 -0800 Subject: [PATCH] Use cached_path for better caching --- melo/download_utils.py | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/melo/download_utils.py b/melo/download_utils.py index e7f4afe..5d538ef 100644 --- a/melo/download_utils.py +++ b/melo/download_utils.py @@ -1,7 +1,7 @@ import torch import os from . import utils - +from cached_path import cached_path DOWNLOAD_CKPT_URLS = { 'EN': 'https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/basespeakers/EN/checkpoint.pth', 'EN_V2': 'https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/basespeakers/EN_V2/checkpoint.pth', @@ -25,23 +25,11 @@ DOWNLOAD_CONFIG_URLS = { def load_or_download_config(locale): language = locale.split('-')[0].upper() assert language in DOWNLOAD_CONFIG_URLS - config_path = os.path.expanduser(f'~/.local/share/openvoice/basespeakers/{language}/config.json') - try: - return utils.get_hparams_from_file(config_path) - except: - # download - os.makedirs(os.path.dirname(config_path), exist_ok=True) - os.system(f'wget {DOWNLOAD_CONFIG_URLS[language]} -O {config_path}') + config_path = cached_path(DOWNLOAD_CONFIG_URLS[language]) return utils.get_hparams_from_file(config_path) def load_or_download_model(locale, device): language = locale.split('-')[0].upper() assert language in DOWNLOAD_CKPT_URLS - ckpt_path = os.path.expanduser(f'~/.local/share/openvoice/basespeakers/{language}/checkpoint.pth') - try: - return torch.load(ckpt_path, map_location=device) - except: - # download - os.makedirs(os.path.dirname(ckpt_path), exist_ok=True) - os.system(f'wget {DOWNLOAD_CKPT_URLS[language]} -O {ckpt_path}') - return torch.load(ckpt_path, map_location=device) \ No newline at end of file + ckpt_path = cached_path(DOWNLOAD_CKPT_URLS[language]) + return torch.load(ckpt_path, map_location=device)