training code done

This commit is contained in:
wl-zhao
2024-03-10 13:05:02 +00:00
parent c9c57a17f4
commit 7ade7b740e
16 changed files with 1533 additions and 47 deletions

View File

@@ -26,6 +26,7 @@ Some other features include:
## Usage
- [Use without Installation](docs/quick_use.md)
- [Install and Use Locally](docs/install.md)
- [Training on Custom Dataset](docs/training.md)
The Python API and model cards can be found in [this repo](https://github.com/myshell-ai/MeloTTS/blob/main/docs/install.md#python-api) or on [HuggingFace](https://huggingface.co/myshell-ai).
@@ -57,10 +58,6 @@ If you find this work useful, please consider contributing to this repo.
}
```
## TODO
- Training code release.
## License
This library is under MIT License, which means it is free for both commercial and non-commercial use.

37
docs/training.md Normal file
View File

@@ -0,0 +1,37 @@
## Training
Before training, please install MeloTTS in dev mode and go to the `melo` folder.
```
pip install -e .
cd melo
```
### Data Preparation
To train a TTS model, we need to prepare the audio files and a metadata file. We recommend using 44100Hz audio files and the metadata file should have the following format:
```
path/to/audio_001.wav |<speaker_name>|<language_code>|<text_001>
path/to/audio_002.wav |<speaker_name>|<language_code>|<text_002>
```
The transcribed text can be obtained by ASR model, (e.g., [whisper](https://github.com/openai/whisper)). An example metadata can be found in `data/example/metadata.list`
We can then run the preprocessing code:
```
python preprocess_text.py --metadata data/example/metadata.list
```
A config file `data/example/config.json` will be generated. Feel free to edit some hyper-parameters in that config file (for example, you may decrease the batch size if you have encountered the CUDA out-of-memory issue).
### Training
The training can be launched by:
```
bash train.sh <path/to/config.json> <num_of_gpus>
```
We have found for some machine the training will sometimes crash due to an [issue](https://github.com/pytorch/pytorch/issues/2530) of gloo. Therefore, we add an auto-resume wrapper in the `train.sh`.
### Inference
Simply run:
```
python infer.py --text "<some text here>" -m /path/to/checkpoint/G_<iter>.pth -o <output_dir>
```

View File

@@ -21,7 +21,9 @@ class TTS(nn.Module):
def __init__(self,
language,
device='auto',
use_hf=True):
use_hf=True,
config_path=None,
ckpt_path=None):
super().__init__()
if device == 'auto':
device = 'cpu'
@@ -31,7 +33,7 @@ class TTS(nn.Module):
assert torch.cuda.is_available()
# config_path =
hps = load_or_download_config(language, use_hf=use_hf)
hps = load_or_download_config(language, use_hf=use_hf, config_path=config_path)
num_languages = hps.num_languages
num_tones = hps.num_tones
@@ -54,7 +56,7 @@ class TTS(nn.Module):
self.device = device
# load state_dict
checkpoint_dict = load_or_download_model(language, device, use_hf=use_hf)
checkpoint_dict = load_or_download_model(language, device, use_hf=use_hf, ckpt_path=ckpt_path)
self.model.load_state_dict(checkpoint_dict['model'], strict=True)
language = language.split('_')[0]

94
melo/configs/config.json Normal file
View File

@@ -0,0 +1,94 @@
{
"train": {
"log_interval": 200,
"eval_interval": 1000,
"seed": 52,
"epochs": 10000,
"learning_rate": 0.0003,
"betas": [
0.8,
0.99
],
"eps": 1e-09,
"batch_size": 6,
"fp16_run": false,
"lr_decay": 0.999875,
"segment_size": 16384,
"init_lr_ratio": 1,
"warmup_epochs": 0,
"c_mel": 45,
"c_kl": 1.0,
"skip_optimizer": true
},
"data": {
"training_files": "",
"validation_files": "",
"max_wav_value": 32768.0,
"sampling_rate": 44100,
"filter_length": 2048,
"hop_length": 512,
"win_length": 2048,
"n_mel_channels": 128,
"mel_fmin": 0.0,
"mel_fmax": null,
"add_blank": true,
"n_speakers": 256,
"cleaned_text": true,
"spk2id": {}
},
"model": {
"use_spk_conditioned_encoder": true,
"use_noise_scaled_mas": true,
"use_mel_posterior_encoder": false,
"use_duration_discriminator": true,
"inter_channels": 192,
"hidden_channels": 192,
"filter_channels": 768,
"n_heads": 2,
"n_layers": 6,
"n_layers_trans_flow": 3,
"kernel_size": 3,
"p_dropout": 0.1,
"resblock": "1",
"resblock_kernel_sizes": [
3,
7,
11
],
"resblock_dilation_sizes": [
[
1,
3,
5
],
[
1,
3,
5
],
[
1,
3,
5
]
],
"upsample_rates": [
8,
8,
2,
2,
2
],
"upsample_initial_channel": 512,
"upsample_kernel_sizes": [
16,
16,
8,
2,
2
],
"n_layers_q": 3,
"use_spectral_norm": false,
"gin_channels": 256
}
}

413
melo/data_utils.py Normal file
View File

@@ -0,0 +1,413 @@
import os
import random
import torch
import torch.utils.data
from tqdm import tqdm
from loguru import logger
import commons
from mel_processing import spectrogram_torch, mel_spectrogram_torch
from utils import load_filepaths_and_text
from utils import load_wav_to_torch_librosa as load_wav_to_torch
from text import cleaned_text_to_sequence, get_bert
import numpy as np
"""Multi speaker version"""
class TextAudioSpeakerLoader(torch.utils.data.Dataset):
"""
1) loads audio, speaker_id, text pairs
2) normalizes text and converts them to sequences of integers
3) computes spectrograms from audio files.
"""
def __init__(self, audiopaths_sid_text, hparams):
self.audiopaths_sid_text = load_filepaths_and_text(audiopaths_sid_text)
self.max_wav_value = hparams.max_wav_value
self.sampling_rate = hparams.sampling_rate
self.filter_length = hparams.filter_length
self.hop_length = hparams.hop_length
self.win_length = hparams.win_length
self.sampling_rate = hparams.sampling_rate
self.spk_map = hparams.spk2id
self.hparams = hparams
self.disable_bert = getattr(hparams, "disable_bert", False)
self.use_mel_spec_posterior = getattr(
hparams, "use_mel_posterior_encoder", False
)
if self.use_mel_spec_posterior:
self.n_mel_channels = getattr(hparams, "n_mel_channels", 80)
self.cleaned_text = getattr(hparams, "cleaned_text", False)
self.add_blank = hparams.add_blank
self.min_text_len = getattr(hparams, "min_text_len", 1)
self.max_text_len = getattr(hparams, "max_text_len", 300)
random.seed(1234)
random.shuffle(self.audiopaths_sid_text)
self._filter()
def _filter(self):
"""
Filter text & store spec lengths
"""
# Store spectrogram lengths for Bucketing
# wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2)
# spec_length = wav_length // hop_length
audiopaths_sid_text_new = []
lengths = []
skipped = 0
logger.info("Init dataset...")
for item in tqdm(
self.audiopaths_sid_text
):
try:
_id, spk, language, text, phones, tone, word2ph = item
except:
print(item)
raise
audiopath = f"{_id}"
if self.min_text_len <= len(phones) and len(phones) <= self.max_text_len:
phones = phones.split(" ")
tone = [int(i) for i in tone.split(" ")]
word2ph = [int(i) for i in word2ph.split(" ")]
audiopaths_sid_text_new.append(
[audiopath, spk, language, text, phones, tone, word2ph]
)
lengths.append(os.path.getsize(audiopath) // (2 * self.hop_length))
else:
skipped += 1
logger.info(f'min: {min(lengths)}; max: {max(lengths)}' )
logger.info(
"skipped: "
+ str(skipped)
+ ", total: "
+ str(len(self.audiopaths_sid_text))
)
self.audiopaths_sid_text = audiopaths_sid_text_new
self.lengths = lengths
def get_audio_text_speaker_pair(self, audiopath_sid_text):
# separate filename, speaker_id and text
audiopath, sid, language, text, phones, tone, word2ph = audiopath_sid_text
bert, ja_bert, phones, tone, language = self.get_text(
text, word2ph, phones, tone, language, audiopath
)
spec, wav = self.get_audio(audiopath)
sid = int(getattr(self.spk_map, sid, '0'))
sid = torch.LongTensor([sid])
return (phones, spec, wav, sid, tone, language, bert, ja_bert)
def get_audio(self, filename):
audio_norm, sampling_rate = load_wav_to_torch(filename, self.sampling_rate)
if sampling_rate != self.sampling_rate:
raise ValueError(
"{} {} SR doesn't match target {} SR".format(
filename, sampling_rate, self.sampling_rate
)
)
# NOTE: normalize has been achieved by torchaudio
# audio_norm = audio / self.max_wav_value
audio_norm = audio_norm.unsqueeze(0)
spec_filename = filename.replace(".wav", ".spec.pt")
if self.use_mel_spec_posterior:
spec_filename = spec_filename.replace(".spec.pt", ".mel.pt")
try:
spec = torch.load(spec_filename)
assert False
except:
if self.use_mel_spec_posterior:
spec = mel_spectrogram_torch(
audio_norm,
self.filter_length,
self.n_mel_channels,
self.sampling_rate,
self.hop_length,
self.win_length,
self.hparams.mel_fmin,
self.hparams.mel_fmax,
center=False,
)
else:
spec = spectrogram_torch(
audio_norm,
self.filter_length,
self.sampling_rate,
self.hop_length,
self.win_length,
center=False,
)
spec = torch.squeeze(spec, 0)
torch.save(spec, spec_filename)
return spec, audio_norm
def get_text(self, text, word2ph, phone, tone, language_str, wav_path):
phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
if self.add_blank:
phone = commons.intersperse(phone, 0)
tone = commons.intersperse(tone, 0)
language = commons.intersperse(language, 0)
for i in range(len(word2ph)):
word2ph[i] = word2ph[i] * 2
word2ph[0] += 1
bert_path = wav_path.replace(".wav", ".bert.pt")
try:
bert = torch.load(bert_path)
assert bert.shape[-1] == len(phone)
except Exception as e:
print(e, wav_path, bert_path, bert.shape, len(phone))
bert = get_bert(text, word2ph, language_str)
torch.save(bert, bert_path)
assert bert.shape[-1] == len(phone), phone
if self.disable_bert:
bert = torch.zeros(1024, len(phone))
ja_bert = torch.zeros(768, len(phone))
else:
if language_str in ["ZH"]:
bert = bert
ja_bert = torch.zeros(768, len(phone))
elif language_str in ["JP", "EN", "ZH_MIX_EN", "KR", 'SP', 'ES', 'FR', 'DE', 'RU']:
ja_bert = bert
bert = torch.zeros(1024, len(phone))
else:
raise
bert = torch.zeros(1024, len(phone))
ja_bert = torch.zeros(768, len(phone))
assert bert.shape[-1] == len(phone)
phone = torch.LongTensor(phone)
tone = torch.LongTensor(tone)
language = torch.LongTensor(language)
return bert, ja_bert, phone, tone, language
def get_sid(self, sid):
sid = torch.LongTensor([int(sid)])
return sid
def __getitem__(self, index):
return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index])
def __len__(self):
return len(self.audiopaths_sid_text)
class TextAudioSpeakerCollate:
"""Zero-pads model inputs and targets"""
def __init__(self, return_ids=False):
self.return_ids = return_ids
def __call__(self, batch):
"""Collate's training batch from normalized text, audio and speaker identities
PARAMS
------
batch: [text_normalized, spec_normalized, wav_normalized, sid]
"""
# Right zero-pad all one-hot text sequences to max input length
_, ids_sorted_decreasing = torch.sort(
torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True
)
max_text_len = max([len(x[0]) for x in batch])
max_spec_len = max([x[1].size(1) for x in batch])
max_wav_len = max([x[2].size(1) for x in batch])
text_lengths = torch.LongTensor(len(batch))
spec_lengths = torch.LongTensor(len(batch))
wav_lengths = torch.LongTensor(len(batch))
sid = torch.LongTensor(len(batch))
text_padded = torch.LongTensor(len(batch), max_text_len)
tone_padded = torch.LongTensor(len(batch), max_text_len)
language_padded = torch.LongTensor(len(batch), max_text_len)
bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
ja_bert_padded = torch.FloatTensor(len(batch), 768, max_text_len)
spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
text_padded.zero_()
tone_padded.zero_()
language_padded.zero_()
spec_padded.zero_()
wav_padded.zero_()
bert_padded.zero_()
ja_bert_padded.zero_()
for i in range(len(ids_sorted_decreasing)):
row = batch[ids_sorted_decreasing[i]]
text = row[0]
text_padded[i, : text.size(0)] = text
text_lengths[i] = text.size(0)
spec = row[1]
spec_padded[i, :, : spec.size(1)] = spec
spec_lengths[i] = spec.size(1)
wav = row[2]
wav_padded[i, :, : wav.size(1)] = wav
wav_lengths[i] = wav.size(1)
sid[i] = row[3]
tone = row[4]
tone_padded[i, : tone.size(0)] = tone
language = row[5]
language_padded[i, : language.size(0)] = language
bert = row[6]
bert_padded[i, :, : bert.size(1)] = bert
ja_bert = row[7]
ja_bert_padded[i, :, : ja_bert.size(1)] = ja_bert
return (
text_padded,
text_lengths,
spec_padded,
spec_lengths,
wav_padded,
wav_lengths,
sid,
tone_padded,
language_padded,
bert_padded,
ja_bert_padded,
)
class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
"""
Maintain similar input lengths in a batch.
Length groups are specified by boundaries.
Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}.
It removes samples which are not included in the boundaries.
Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
"""
def __init__(
self,
dataset,
batch_size,
boundaries,
num_replicas=None,
rank=None,
shuffle=True,
):
super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
self.lengths = dataset.lengths
self.batch_size = batch_size
self.boundaries = boundaries
self.buckets, self.num_samples_per_bucket = self._create_buckets()
self.total_size = sum(self.num_samples_per_bucket)
self.num_samples = self.total_size // self.num_replicas
print('buckets:', self.num_samples_per_bucket)
def _create_buckets(self):
buckets = [[] for _ in range(len(self.boundaries) - 1)]
for i in range(len(self.lengths)):
length = self.lengths[i]
idx_bucket = self._bisect(length)
if idx_bucket != -1:
buckets[idx_bucket].append(i)
try:
for i in range(len(buckets) - 1, 0, -1):
if len(buckets[i]) == 0:
buckets.pop(i)
self.boundaries.pop(i + 1)
assert all(len(bucket) > 0 for bucket in buckets)
# When one bucket is not traversed
except Exception as e:
print("Bucket warning ", e)
for i in range(len(buckets) - 1, -1, -1):
if len(buckets[i]) == 0:
buckets.pop(i)
self.boundaries.pop(i + 1)
num_samples_per_bucket = []
for i in range(len(buckets)):
len_bucket = len(buckets[i])
total_batch_size = self.num_replicas * self.batch_size
rem = (
total_batch_size - (len_bucket % total_batch_size)
) % total_batch_size
num_samples_per_bucket.append(len_bucket + rem)
return buckets, num_samples_per_bucket
def __iter__(self):
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
indices = []
if self.shuffle:
for bucket in self.buckets:
indices.append(torch.randperm(len(bucket), generator=g).tolist())
else:
for bucket in self.buckets:
indices.append(list(range(len(bucket))))
batches = []
for i in range(len(self.buckets)):
bucket = self.buckets[i]
len_bucket = len(bucket)
if len_bucket == 0:
continue
ids_bucket = indices[i]
num_samples_bucket = self.num_samples_per_bucket[i]
# add extra samples to make it evenly divisible
rem = num_samples_bucket - len_bucket
ids_bucket = (
ids_bucket
+ ids_bucket * (rem // len_bucket)
+ ids_bucket[: (rem % len_bucket)]
)
# subsample
ids_bucket = ids_bucket[self.rank :: self.num_replicas]
# batching
for j in range(len(ids_bucket) // self.batch_size):
batch = [
bucket[idx]
for idx in ids_bucket[
j * self.batch_size : (j + 1) * self.batch_size
]
]
batches.append(batch)
if self.shuffle:
batch_ids = torch.randperm(len(batches), generator=g).tolist()
batches = [batches[i] for i in batch_ids]
self.batches = batches
assert len(self.batches) * self.batch_size == self.num_samples
return iter(self.batches)
def _bisect(self, x, lo=0, hi=None):
if hi is None:
hi = len(self.boundaries) - 1
if hi > lo:
mid = (hi + lo) // 2
if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
return mid
elif x <= self.boundaries[mid]:
return self._bisect(x, lo, mid)
else:
return self._bisect(x, mid + 1, hi)
else:
return -1
def __len__(self):
return self.num_samples // self.batch_size

View File

@@ -24,6 +24,12 @@ DOWNLOAD_CONFIG_URLS = {
'KR': 'https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/basespeakers/KR/config.json',
}
PRETRAINED_MODELS = {
'G.pth': 'https://cloud.tsinghua.edu.cn/f/91346812c11746e1b67b/?dl=1',
'D.pth': 'https://cloud.tsinghua.edu.cn/f/4734a5281219424199f1/?dl=1',
'DUR.pth': 'https://cloud.tsinghua.edu.cn/f/750feac7585f49ce96d7/?dl=1',
}
LANG_TO_HF_REPO_ID = {
'EN': 'myshell-ai/MeloTTS-English',
'EN_V2': 'myshell-ai/MeloTTS-English-v2',
@@ -34,22 +40,27 @@ LANG_TO_HF_REPO_ID = {
'KR': 'myshell-ai/MeloTTS-Korean',
}
def load_or_download_config(locale, use_hf=True):
language = locale.split('-')[0].upper()
if use_hf:
assert language in LANG_TO_HF_REPO_ID
config_path = hf_hub_download(repo_id=LANG_TO_HF_REPO_ID[language], filename="config.json")
else:
assert language in DOWNLOAD_CONFIG_URLS
config_path = cached_path(DOWNLOAD_CONFIG_URLS[language])
def load_or_download_config(locale, use_hf=True, config_path=None):
if config_path is None:
language = locale.split('-')[0].upper()
if use_hf:
assert language in LANG_TO_HF_REPO_ID
config_path = hf_hub_download(repo_id=LANG_TO_HF_REPO_ID[language], filename="config.json")
else:
assert language in DOWNLOAD_CONFIG_URLS
config_path = cached_path(DOWNLOAD_CONFIG_URLS[language])
return utils.get_hparams_from_file(config_path)
def load_or_download_model(locale, device, use_hf=True):
language = locale.split('-')[0].upper()
if use_hf:
assert language in LANG_TO_HF_REPO_ID
ckpt_path = hf_hub_download(repo_id=LANG_TO_HF_REPO_ID[language], filename="checkpoint.pth")
else:
assert language in DOWNLOAD_CKPT_URLS
ckpt_path = cached_path(DOWNLOAD_CKPT_URLS[language])
def load_or_download_model(locale, device, use_hf=True, ckpt_path=None):
if ckpt_path is None:
language = locale.split('-')[0].upper()
if use_hf:
assert language in LANG_TO_HF_REPO_ID
ckpt_path = hf_hub_download(repo_id=LANG_TO_HF_REPO_ID[language], filename="checkpoint.pth")
else:
assert language in DOWNLOAD_CKPT_URLS
ckpt_path = cached_path(DOWNLOAD_CKPT_URLS[language])
return torch.load(ckpt_path, map_location=device)
def load_pretrain_model():
return [cached_path(url) for url in PRETRAINED_MODELS.values()]

25
melo/infer.py Normal file
View File

@@ -0,0 +1,25 @@
import os
import click
from melo.api import TTS
@click.command()
@click.option('--ckpt_path', '-m', type=str, default=None, help="Path to the checkpoint file")
@click.option('--text', '-t', type=str, default=None, help="Text to speak")
@click.option('--language', '-l', type=str, default="EN", help="Language of the model")
@click.option('--output_dir', '-o', type=str, default="outputs", help="Path to the output")
def main(ckpt_path, text, language, output_dir):
if ckpt_path is None:
raise ValueError("The model_path must be specified")
config_path = os.path.join(os.path.dirname(ckpt_path), 'config.json')
model = TTS(language=language, config_path=config_path, ckpt_path=ckpt_path)
for spk_name, spk_id in model.hps.data.spk2id.items():
save_path = f'{output_dir}/{spk_name}/output.wav'
os.makedirs(os.path.dirname(save_path), exist_ok=True)
model.tts_to_file(text, spk_id, save_path)
if __name__ == "__main__":
main()

58
melo/losses.py Normal file
View File

@@ -0,0 +1,58 @@
import torch
def feature_loss(fmap_r, fmap_g):
loss = 0
for dr, dg in zip(fmap_r, fmap_g):
for rl, gl in zip(dr, dg):
rl = rl.float().detach()
gl = gl.float()
loss += torch.mean(torch.abs(rl - gl))
return loss * 2
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
loss = 0
r_losses = []
g_losses = []
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
dr = dr.float()
dg = dg.float()
r_loss = torch.mean((1 - dr) ** 2)
g_loss = torch.mean(dg**2)
loss += r_loss + g_loss
r_losses.append(r_loss.item())
g_losses.append(g_loss.item())
return loss, r_losses, g_losses
def generator_loss(disc_outputs):
loss = 0
gen_losses = []
for dg in disc_outputs:
dg = dg.float()
l = torch.mean((1 - dg) ** 2)
gen_losses.append(l)
loss += l
return loss, gen_losses
def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
"""
z_p, logs_q: [b, h, t_t]
m_p, logs_p: [b, h, t_t]
"""
z_p = z_p.float()
logs_q = logs_q.float()
m_p = m_p.float()
logs_p = logs_p.float()
z_mask = z_mask.float()
kl = logs_p - logs_q - 0.5
kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
kl = torch.sum(kl * z_mask)
l = kl / torch.sum(z_mask)
return l

View File

@@ -3,14 +3,15 @@ import torch
from torch import nn
from torch.nn import functional as F
from . import commons
from . import modules
from . import attentions
from melo import commons
from melo import modules
from melo import attentions
from torch.nn import Conv1d, ConvTranspose1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from .commons import init_weights, get_padding
from melo.commons import init_weights, get_padding
import melo.monotonic_align as monotonic_align
class DurationDiscriminator(nn.Module): # vits2
@@ -782,7 +783,6 @@ class SynthesizerTrn(nn.Module):
num_languages=None,
num_tones=None,
norm_refenc=False,
use_se=False,
**kwargs
):
super().__init__()
@@ -878,16 +878,12 @@ class SynthesizerTrn(nn.Module):
hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
)
if n_speakers > 1:
if use_se:
emb_dim = 512
self.emb_g = nn.Linear(emb_dim, gin_channels)
else:
self.emb_g = nn.Embedding(n_speakers, gin_channels)
if n_speakers > 0:
self.emb_g = nn.Embedding(n_speakers, gin_channels)
else:
self.ref_enc = ReferenceEncoder(spec_channels, gin_channels, layernorm=norm_refenc)
self.use_vc = use_vc
self.use_se = use_se
def forward(self, x, x_lengths, y, y_lengths, sid, tone, language, bert, ja_bert):
if self.n_speakers > 0:
@@ -1024,11 +1020,7 @@ class SynthesizerTrn(nn.Module):
# print('max/min of o:', o.max(), o.min())
return o, attn, y_mask, (z, z_p, m_p, logs_p)
def voice_conversion(self, y, y_lengths, sid_src, sid_tgt, tau=1.0):
if self.use_se:
sid_src = self.emb_g(sid_src).unsqueeze(-1)
sid_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
def voice_conversion(self, y, y_lengths, sid_src, sid_tgt, tau=1.0):
g_src = sid_src
g_tgt = sid_tgt
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src, tau=tau)

View File

@@ -0,0 +1,16 @@
from numpy import zeros, int32, float32
from torch import from_numpy
from .core import maximum_path_jit
def maximum_path(neg_cent, mask):
device = neg_cent.device
dtype = neg_cent.dtype
neg_cent = neg_cent.data.cpu().numpy().astype(float32)
path = zeros(neg_cent.shape, dtype=int32)
t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32)
t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32)
maximum_path_jit(path, neg_cent, t_t_max, t_s_max)
return from_numpy(path).to(device=device, dtype=dtype)

View File

@@ -0,0 +1,46 @@
import numba
@numba.jit(
numba.void(
numba.int32[:, :, ::1],
numba.float32[:, :, ::1],
numba.int32[::1],
numba.int32[::1],
),
nopython=True,
nogil=True,
)
def maximum_path_jit(paths, values, t_ys, t_xs):
b = paths.shape[0]
max_neg_val = -1e9
for i in range(int(b)):
path = paths[i]
value = values[i]
t_y = t_ys[i]
t_x = t_xs[i]
v_prev = v_cur = 0.0
index = t_x - 1
for y in range(t_y):
for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
if x == y:
v_cur = max_neg_val
else:
v_cur = value[y - 1, x]
if x == 0:
if y == 0:
v_prev = 0.0
else:
v_prev = max_neg_val
else:
v_prev = value[y - 1, x - 1]
value[y, x] += max(v_prev, v_cur)
for y in range(t_y - 1, -1, -1):
path[y, index] = 1
if index != 0 and (
index == y or value[y - 1, index] < value[y - 1, index - 1]
):
index = index - 1

135
melo/preprocess_text.py Normal file
View File

@@ -0,0 +1,135 @@
import json
from collections import defaultdict
from random import shuffle
from typing import Optional
from tqdm import tqdm
import click
from text.cleaner import clean_text_bert
import os
import torch
from text.symbols import symbols, num_languages, num_tones
@click.command()
@click.option(
"--metadata",
default="data/example/metadata.list",
type=click.Path(exists=True, file_okay=True, dir_okay=False),
)
@click.option("--cleaned-path", default=None)
@click.option("--train-path", default=None)
@click.option("--val-path", default=None)
@click.option(
"--config_path",
default="configs/config.json",
type=click.Path(exists=True, file_okay=True, dir_okay=False),
)
@click.option("--val-per-spk", default=4)
@click.option("--max-val-total", default=8)
@click.option("--clean/--no-clean", default=True)
def main(
metadata: str,
cleaned_path: Optional[str],
train_path: str,
val_path: str,
config_path: str,
val_per_spk: int,
max_val_total: int,
clean: bool,
):
if train_path is None:
train_path = os.path.join(os.path.dirname(metadata), 'train.list')
if val_path is None:
val_path = os.path.join(os.path.dirname(metadata), 'val.list')
out_config_path = os.path.join(os.path.dirname(metadata), 'config.json')
if cleaned_path is None:
cleaned_path = metadata + ".cleaned"
if clean:
out_file = open(cleaned_path, "w", encoding="utf-8")
new_symbols = []
for line in tqdm(open(metadata, encoding="utf-8").readlines()):
try:
utt, spk, language, text = line.strip().split("|")
norm_text, phones, tones, word2ph, bert = clean_text_bert(text, language, device='cuda:0')
for ph in phones:
if ph not in symbols and ph not in new_symbols:
new_symbols.append(ph)
print('update!, now symbols:')
print(new_symbols)
with open(f'{language}_symbol.txt', 'w') as f:
f.write(f'{new_symbols}')
assert len(phones) == len(tones)
assert len(phones) == sum(word2ph)
out_file.write(
"{}|{}|{}|{}|{}|{}|{}\n".format(
utt,
spk,
language,
norm_text,
" ".join(phones),
" ".join([str(i) for i in tones]),
" ".join([str(i) for i in word2ph]),
)
)
bert_path = utt.replace(".wav", ".bert.pt")
os.makedirs(os.path.dirname(bert_path), exist_ok=True)
torch.save(bert.cpu(), bert_path)
except Exception as error:
print("err!", line, error)
out_file.close()
metadata = cleaned_path
spk_utt_map = defaultdict(list)
spk_id_map = {}
current_sid = 0
with open(metadata, encoding="utf-8") as f:
for line in f.readlines():
utt, spk, language, text, phones, tones, word2ph = line.strip().split("|")
spk_utt_map[spk].append(line)
if spk not in spk_id_map.keys():
spk_id_map[spk] = current_sid
current_sid += 1
train_list = []
val_list = []
for spk, utts in spk_utt_map.items():
shuffle(utts)
val_list += utts[:val_per_spk]
train_list += utts[val_per_spk:]
if len(val_list) > max_val_total:
train_list += val_list[max_val_total:]
val_list = val_list[:max_val_total]
with open(train_path, "w", encoding="utf-8") as f:
for line in train_list:
f.write(line)
with open(val_path, "w", encoding="utf-8") as f:
for line in val_list:
f.write(line)
config = json.load(open(config_path, encoding="utf-8"))
config["data"]["spk2id"] = spk_id_map
config["data"]["training_files"] = train_path
config["data"]["validation_files"] = val_path
config["data"]["n_speakers"] = len(spk_id_map)
config["num_languages"] = num_languages
config["num_tones"] = num_tones
config["symbols"] = symbols
with open(out_config_path, "w", encoding="utf-8") as f:
json.dump(config, f, indent=2, ensure_ascii=False)
if __name__ == "__main__":
main()

635
melo/train.py Normal file
View File

@@ -0,0 +1,635 @@
# flake8: noqa: E402
import os
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import logging
logging.getLogger("numba").setLevel(logging.WARNING)
import commons
import utils
from data_utils import (
TextAudioSpeakerLoader,
TextAudioSpeakerCollate,
DistributedBucketSampler,
)
from models import (
SynthesizerTrn,
MultiPeriodDiscriminator,
DurationDiscriminator,
)
from losses import generator_loss, discriminator_loss, feature_loss, kl_loss
from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from text.symbols import symbols
from melo.download_utils import load_pretrain_model
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = (
True # If encontered training problem,please try to disable TF32.
)
torch.set_float32_matmul_precision("medium")
torch.backends.cudnn.benchmark = True
torch.backends.cuda.sdp_kernel("flash")
torch.backends.cuda.enable_flash_sdp(True)
# torch.backends.cuda.enable_mem_efficient_sdp(
# True
# ) # Not available if torch version is lower than 2.0
torch.backends.cuda.enable_math_sdp(True)
global_step = 0
def run():
hps = utils.get_hparams()
local_rank = int(os.environ["LOCAL_RANK"])
dist.init_process_group(
backend="gloo",
init_method="env://", # Due to some training problem,we proposed to use gloo instead of nccl.
rank=local_rank,
) # Use torchrun instead of mp.spawn
rank = dist.get_rank()
n_gpus = dist.get_world_size()
torch.manual_seed(hps.train.seed)
torch.cuda.set_device(rank)
global global_step
if rank == 0:
logger = utils.get_logger(hps.model_dir)
logger.info(hps)
utils.check_git_hash(hps.model_dir)
writer = SummaryWriter(log_dir=hps.model_dir)
writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))
train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps.data)
train_sampler = DistributedBucketSampler(
train_dataset,
hps.train.batch_size,
[32, 300, 400, 500, 600, 700, 800, 900, 1000],
num_replicas=n_gpus,
rank=rank,
shuffle=True,
)
collate_fn = TextAudioSpeakerCollate()
train_loader = DataLoader(
train_dataset,
num_workers=16,
shuffle=False,
pin_memory=True,
collate_fn=collate_fn,
batch_sampler=train_sampler,
persistent_workers=True,
prefetch_factor=4,
) # DataLoader config could be adjusted.
if rank == 0:
eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data)
eval_loader = DataLoader(
eval_dataset,
num_workers=0,
shuffle=False,
batch_size=1,
pin_memory=True,
drop_last=False,
collate_fn=collate_fn,
)
if (
"use_noise_scaled_mas" in hps.model.keys()
and hps.model.use_noise_scaled_mas is True
):
print("Using noise scaled MAS for VITS2")
mas_noise_scale_initial = 0.01
noise_scale_delta = 2e-6
else:
print("Using normal MAS for VITS1")
mas_noise_scale_initial = 0.0
noise_scale_delta = 0.0
if (
"use_duration_discriminator" in hps.model.keys()
and hps.model.use_duration_discriminator is True
):
print("Using duration discriminator for VITS2")
net_dur_disc = DurationDiscriminator(
hps.model.hidden_channels,
hps.model.hidden_channels,
3,
0.1,
gin_channels=hps.model.gin_channels if hps.data.n_speakers != 0 else 0,
).cuda(rank)
if (
"use_spk_conditioned_encoder" in hps.model.keys()
and hps.model.use_spk_conditioned_encoder is True
):
if hps.data.n_speakers == 0:
raise ValueError(
"n_speakers must be > 0 when using spk conditioned encoder to train multi-speaker model"
)
else:
print("Using normal encoder for VITS1")
net_g = SynthesizerTrn(
len(symbols),
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
mas_noise_scale_initial=mas_noise_scale_initial,
noise_scale_delta=noise_scale_delta,
**hps.model,
).cuda(rank)
net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
optim_g = torch.optim.AdamW(
filter(lambda p: p.requires_grad, net_g.parameters()),
hps.train.learning_rate,
betas=hps.train.betas,
eps=hps.train.eps,
)
optim_d = torch.optim.AdamW(
net_d.parameters(),
hps.train.learning_rate,
betas=hps.train.betas,
eps=hps.train.eps,
)
if net_dur_disc is not None:
optim_dur_disc = torch.optim.AdamW(
net_dur_disc.parameters(),
hps.train.learning_rate,
betas=hps.train.betas,
eps=hps.train.eps,
)
else:
optim_dur_disc = None
net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
pretrain_G, pretrain_D, pretrain_dur = load_pretrain_model()
hps.pretrain_G = hps.pretrain_G or pretrain_G
hps.pretrain_D = hps.pretrain_D or pretrain_D
hps.pretrain_dur = hps.pretrain_dur or pretrain_dur
if hps.pretrain_G:
utils.load_checkpoint(
hps.pretrain_G,
net_g,
None,
skip_optimizer=True
)
if hps.pretrain_D:
utils.load_checkpoint(
hps.pretrain_D,
net_d,
None,
skip_optimizer=True
)
if net_dur_disc is not None:
net_dur_disc = DDP(net_dur_disc, device_ids=[rank], find_unused_parameters=True)
if hps.pretrain_dur:
utils.load_checkpoint(
hps.pretrain_dur,
net_dur_disc,
None,
skip_optimizer=True
)
try:
if net_dur_disc is not None:
_, _, dur_resume_lr, epoch_str = utils.load_checkpoint(
utils.latest_checkpoint_path(hps.model_dir, "DUR_*.pth"),
net_dur_disc,
optim_dur_disc,
skip_optimizer=hps.train.skip_optimizer
if "skip_optimizer" in hps.train
else True,
)
_, optim_g, g_resume_lr, epoch_str = utils.load_checkpoint(
utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"),
net_g,
optim_g,
skip_optimizer=hps.train.skip_optimizer
if "skip_optimizer" in hps.train
else True,
)
_, optim_d, d_resume_lr, epoch_str = utils.load_checkpoint(
utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"),
net_d,
optim_d,
skip_optimizer=hps.train.skip_optimizer
if "skip_optimizer" in hps.train
else True,
)
if not optim_g.param_groups[0].get("initial_lr"):
optim_g.param_groups[0]["initial_lr"] = g_resume_lr
if not optim_d.param_groups[0].get("initial_lr"):
optim_d.param_groups[0]["initial_lr"] = d_resume_lr
if not optim_dur_disc.param_groups[0].get("initial_lr"):
optim_dur_disc.param_groups[0]["initial_lr"] = dur_resume_lr
epoch_str = max(epoch_str, 1)
global_step = (epoch_str - 1) * len(train_loader)
except Exception as e:
print(e)
epoch_str = 1
global_step = 0
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
)
if net_dur_disc is not None:
scheduler_dur_disc = torch.optim.lr_scheduler.ExponentialLR(
optim_dur_disc, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
)
else:
scheduler_dur_disc = None
scaler = GradScaler(enabled=hps.train.fp16_run)
for epoch in range(epoch_str, hps.train.epochs + 1):
try:
if rank == 0:
train_and_evaluate(
rank,
epoch,
hps,
[net_g, net_d, net_dur_disc],
[optim_g, optim_d, optim_dur_disc],
[scheduler_g, scheduler_d, scheduler_dur_disc],
scaler,
[train_loader, eval_loader],
logger,
[writer, writer_eval],
)
else:
train_and_evaluate(
rank,
epoch,
hps,
[net_g, net_d, net_dur_disc],
[optim_g, optim_d, optim_dur_disc],
[scheduler_g, scheduler_d, scheduler_dur_disc],
scaler,
[train_loader, None],
None,
None,
)
except Exception as e:
print(e)
torch.cuda.empty_cache()
scheduler_g.step()
scheduler_d.step()
if net_dur_disc is not None:
scheduler_dur_disc.step()
def train_and_evaluate(
rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
):
net_g, net_d, net_dur_disc = nets
optim_g, optim_d, optim_dur_disc = optims
scheduler_g, scheduler_d, scheduler_dur_disc = schedulers
train_loader, eval_loader = loaders
if writers is not None:
writer, writer_eval = writers
train_loader.batch_sampler.set_epoch(epoch)
global global_step
net_g.train()
net_d.train()
if net_dur_disc is not None:
net_dur_disc.train()
for batch_idx, (
x,
x_lengths,
spec,
spec_lengths,
y,
y_lengths,
speakers,
tone,
language,
bert,
ja_bert,
) in enumerate(tqdm(train_loader)):
if net_g.module.use_noise_scaled_mas:
current_mas_noise_scale = (
net_g.module.mas_noise_scale_initial
- net_g.module.noise_scale_delta * global_step
)
net_g.module.current_mas_noise_scale = max(current_mas_noise_scale, 0.0)
x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda(
rank, non_blocking=True
)
spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
rank, non_blocking=True
)
y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(
rank, non_blocking=True
)
speakers = speakers.cuda(rank, non_blocking=True)
tone = tone.cuda(rank, non_blocking=True)
language = language.cuda(rank, non_blocking=True)
bert = bert.cuda(rank, non_blocking=True)
ja_bert = ja_bert.cuda(rank, non_blocking=True)
with autocast(enabled=hps.train.fp16_run):
(
y_hat,
l_length,
attn,
ids_slice,
x_mask,
z_mask,
(z, z_p, m_p, logs_p, m_q, logs_q),
(hidden_x, logw, logw_),
) = net_g(
x,
x_lengths,
spec,
spec_lengths,
speakers,
tone,
language,
bert,
ja_bert,
)
mel = spec_to_mel_torch(
spec,
hps.data.filter_length,
hps.data.n_mel_channels,
hps.data.sampling_rate,
hps.data.mel_fmin,
hps.data.mel_fmax,
)
y_mel = commons.slice_segments(
mel, ids_slice, hps.train.segment_size // hps.data.hop_length
)
y_hat_mel = mel_spectrogram_torch(
y_hat.squeeze(1),
hps.data.filter_length,
hps.data.n_mel_channels,
hps.data.sampling_rate,
hps.data.hop_length,
hps.data.win_length,
hps.data.mel_fmin,
hps.data.mel_fmax,
)
y = commons.slice_segments(
y, ids_slice * hps.data.hop_length, hps.train.segment_size
) # slice
# Discriminator
y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
with autocast(enabled=False):
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
y_d_hat_r, y_d_hat_g
)
loss_disc_all = loss_disc
if net_dur_disc is not None:
y_dur_hat_r, y_dur_hat_g = net_dur_disc(
hidden_x.detach(), x_mask.detach(), logw.detach(), logw_.detach()
)
with autocast(enabled=False):
# TODO: I think need to mean using the mask, but for now, just mean all
(
loss_dur_disc,
losses_dur_disc_r,
losses_dur_disc_g,
) = discriminator_loss(y_dur_hat_r, y_dur_hat_g)
loss_dur_disc_all = loss_dur_disc
optim_dur_disc.zero_grad()
scaler.scale(loss_dur_disc_all).backward()
scaler.unscale_(optim_dur_disc)
commons.clip_grad_value_(net_dur_disc.parameters(), None)
scaler.step(optim_dur_disc)
optim_d.zero_grad()
scaler.scale(loss_disc_all).backward()
scaler.unscale_(optim_d)
grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
scaler.step(optim_d)
with autocast(enabled=hps.train.fp16_run):
# Generator
y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
if net_dur_disc is not None:
y_dur_hat_r, y_dur_hat_g = net_dur_disc(hidden_x, x_mask, logw, logw_)
with autocast(enabled=False):
loss_dur = torch.sum(l_length.float())
loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
loss_fm = feature_loss(fmap_r, fmap_g)
loss_gen, losses_gen = generator_loss(y_d_hat_g)
loss_gen_all = loss_gen + loss_fm + loss_mel + loss_dur + loss_kl
if net_dur_disc is not None:
loss_dur_gen, losses_dur_gen = generator_loss(y_dur_hat_g)
loss_gen_all += loss_dur_gen
optim_g.zero_grad()
scaler.scale(loss_gen_all).backward()
scaler.unscale_(optim_g)
grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
scaler.step(optim_g)
scaler.update()
if rank == 0:
if global_step % hps.train.log_interval == 0:
lr = optim_g.param_groups[0]["lr"]
losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_dur, loss_kl]
logger.info(
"Train Epoch: {} [{:.0f}%]".format(
epoch, 100.0 * batch_idx / len(train_loader)
)
)
logger.info([x.item() for x in losses] + [global_step, lr])
scalar_dict = {
"loss/g/total": loss_gen_all,
"loss/d/total": loss_disc_all,
"learning_rate": lr,
"grad_norm_d": grad_norm_d,
"grad_norm_g": grad_norm_g,
}
scalar_dict.update(
{
"loss/g/fm": loss_fm,
"loss/g/mel": loss_mel,
"loss/g/dur": loss_dur,
"loss/g/kl": loss_kl,
}
)
scalar_dict.update(
{"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}
)
scalar_dict.update(
{"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}
)
scalar_dict.update(
{"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}
)
image_dict = {
"slice/mel_org": utils.plot_spectrogram_to_numpy(
y_mel[0].data.cpu().numpy()
),
"slice/mel_gen": utils.plot_spectrogram_to_numpy(
y_hat_mel[0].data.cpu().numpy()
),
"all/mel": utils.plot_spectrogram_to_numpy(
mel[0].data.cpu().numpy()
),
"all/attn": utils.plot_alignment_to_numpy(
attn[0, 0].data.cpu().numpy()
),
}
utils.summarize(
writer=writer,
global_step=global_step,
images=image_dict,
scalars=scalar_dict,
)
if global_step % hps.train.eval_interval == 0:
evaluate(hps, net_g, eval_loader, writer_eval)
utils.save_checkpoint(
net_g,
optim_g,
hps.train.learning_rate,
epoch,
os.path.join(hps.model_dir, "G_{}.pth".format(global_step)),
)
utils.save_checkpoint(
net_d,
optim_d,
hps.train.learning_rate,
epoch,
os.path.join(hps.model_dir, "D_{}.pth".format(global_step)),
)
if net_dur_disc is not None:
utils.save_checkpoint(
net_dur_disc,
optim_dur_disc,
hps.train.learning_rate,
epoch,
os.path.join(hps.model_dir, "DUR_{}.pth".format(global_step)),
)
keep_ckpts = getattr(hps.train, "keep_ckpts", 5)
if keep_ckpts > 0:
utils.clean_checkpoints(
path_to_models=hps.model_dir,
n_ckpts_to_keep=keep_ckpts,
sort_by_time=True,
)
global_step += 1
if rank == 0:
logger.info("====> Epoch: {}".format(epoch))
torch.cuda.empty_cache()
def evaluate(hps, generator, eval_loader, writer_eval):
generator.eval()
image_dict = {}
audio_dict = {}
print("Evaluating ...")
with torch.no_grad():
for batch_idx, (
x,
x_lengths,
spec,
spec_lengths,
y,
y_lengths,
speakers,
tone,
language,
bert,
ja_bert,
) in enumerate(eval_loader):
x, x_lengths = x.cuda(), x_lengths.cuda()
spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
y, y_lengths = y.cuda(), y_lengths.cuda()
speakers = speakers.cuda()
bert = bert.cuda()
ja_bert = ja_bert.cuda()
tone = tone.cuda()
language = language.cuda()
for use_sdp in [True, False]:
y_hat, attn, mask, *_ = generator.module.infer(
x,
x_lengths,
speakers,
tone,
language,
bert,
ja_bert,
y=spec,
max_len=1000,
sdp_ratio=0.0 if not use_sdp else 1.0,
)
y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length
mel = spec_to_mel_torch(
spec,
hps.data.filter_length,
hps.data.n_mel_channels,
hps.data.sampling_rate,
hps.data.mel_fmin,
hps.data.mel_fmax,
)
y_hat_mel = mel_spectrogram_torch(
y_hat.squeeze(1).float(),
hps.data.filter_length,
hps.data.n_mel_channels,
hps.data.sampling_rate,
hps.data.hop_length,
hps.data.win_length,
hps.data.mel_fmin,
hps.data.mel_fmax,
)
image_dict.update(
{
f"gen/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(
y_hat_mel[0].cpu().numpy()
)
}
)
audio_dict.update(
{
f"gen/audio_{batch_idx}_{use_sdp}": y_hat[
0, :, : y_hat_lengths[0]
]
}
)
image_dict.update(
{
f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(
mel[0].cpu().numpy()
)
}
)
audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]})
utils.summarize(
writer=writer_eval,
global_step=global_step,
images=image_dict,
audios=audio_dict,
audio_sampling_rate=hps.data.sampling_rate,
)
generator.train()
print('Evauate done')
torch.cuda.empty_cache()
if __name__ == "__main__":
run()

19
melo/train.sh Normal file
View File

@@ -0,0 +1,19 @@
CONFIG=$1
GPUS=$2
MODEL_NAME=$(basename "$(dirname $CONFIG)")
PORT=10902
while : # auto-resume: the code sometimes crash due to bug of gloo on some gpus
do
torchrun --nproc_per_node=$GPUS \
--master_port=$PORT \
train.py --c $CONFIG --model $MODEL_NAME
for PID in $(ps -aux | grep $CONFIG | grep python | awk '{print $2}')
do
echo $PID
kill -9 $PID
done
sleep 30
done

View File

@@ -9,9 +9,9 @@ from scipy.io.wavfile import read
import torch
import torchaudio
import librosa
from .text import cleaned_text_to_sequence, get_bert
from .text.cleaner import clean_text
from . import commons
from melo.text import cleaned_text_to_sequence, get_bert
from melo.text.cleaner import clean_text
from melo import commons
MATPLOTLIB_FLAG = False
@@ -60,8 +60,8 @@ def get_text_for_tts_infer(text, language_str, hps, device, symbol_to_id=None):
def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
assert os.path.isfile(checkpoint_path)
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
iteration = checkpoint_dict["iteration"]
learning_rate = checkpoint_dict["learning_rate"]
iteration = checkpoint_dict.get("iteration", 0)
learning_rate = checkpoint_dict.get("learning_rate", 0.)
if (
optimizer is not None
and not skip_optimizer
@@ -92,6 +92,7 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
v.shape,
)
except Exception as e:
print(e)
# For upgrading from the old version
if "ja_bert_proj" in k:
v = torch.zeros_like(v)
@@ -249,7 +250,9 @@ def get_hparams(init=True):
default="./configs/base.json",
help="JSON file for configuration",
)
parser.add_argument('--local-rank', type=int, default=0)
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument('--world-size', type=int, default=1)
parser.add_argument('--port', type=int, default=10000)
parser.add_argument("-m", "--model", type=str, required=True, help="Model name")
parser.add_argument('--pretrain_G', type=str, default=None,
help='pretrain model')
@@ -280,6 +283,7 @@ def get_hparams(init=True):
hparams.pretrain_G = args.pretrain_G
hparams.pretrain_D = args.pretrain_D
hparams.pretrain_dur = args.pretrain_dur
hparams.port = args.port
return hparams

View File

@@ -25,4 +25,6 @@ cn2an==0.5.22
jieba==0.42.1
gradio
langid==1.1.6
tqdm
tqdm
tensorboard==2.16.2
loguru==0.7.2