add HF hub compatibility
This commit is contained in:
@@ -20,7 +20,8 @@ from .download_utils import load_or_download_config, load_or_download_model
|
||||
class TTS(nn.Module):
|
||||
def __init__(self,
|
||||
language,
|
||||
device='auto'):
|
||||
device='auto',
|
||||
use_hf=True):
|
||||
super().__init__()
|
||||
if device == 'auto':
|
||||
device = 'cpu'
|
||||
@@ -30,7 +31,7 @@ class TTS(nn.Module):
|
||||
assert torch.cuda.is_available()
|
||||
|
||||
# config_path =
|
||||
hps = load_or_download_config(language)
|
||||
hps = load_or_download_config(language, use_hf=use_hf)
|
||||
|
||||
num_languages = hps.num_languages
|
||||
num_tones = hps.num_tones
|
||||
@@ -53,7 +54,7 @@ class TTS(nn.Module):
|
||||
self.device = device
|
||||
|
||||
# load state_dict
|
||||
checkpoint_dict = load_or_download_model(language, device)
|
||||
checkpoint_dict = load_or_download_model(language, device, use_hf=use_hf)
|
||||
self.model.load_state_dict(checkpoint_dict['model'], strict=True)
|
||||
|
||||
language = language.split('_')[0]
|
||||
|
||||
Reference in New Issue
Block a user