add HF hub compatibility

This commit is contained in:
Yoach Lacombe
2024-02-29 15:55:41 +00:00
parent 69983af7b1
commit b3ed38400f
3 changed files with 73 additions and 9 deletions

View File

@@ -20,7 +20,8 @@ from .download_utils import load_or_download_config, load_or_download_model
class TTS(nn.Module):
def __init__(self,
language,
device='auto'):
device='auto',
use_hf=True):
super().__init__()
if device == 'auto':
device = 'cpu'
@@ -30,7 +31,7 @@ class TTS(nn.Module):
assert torch.cuda.is_available()
# config_path =
hps = load_or_download_config(language)
hps = load_or_download_config(language, use_hf=use_hf)
num_languages = hps.num_languages
num_tones = hps.num_tones
@@ -53,7 +54,7 @@ class TTS(nn.Module):
self.device = device
# load state_dict
checkpoint_dict = load_or_download_model(language, device)
checkpoint_dict = load_or_download_model(language, device, use_hf=use_hf)
self.model.load_state_dict(checkpoint_dict['model'], strict=True)
language = language.split('_')[0]