From 7ade7b740e091210145954a8f2e73be21a756d32 Mon Sep 17 00:00:00 2001
From: wl-zhao
Date: Sun, 10 Mar 2024 13:05:02 +0000
Subject: [PATCH] training code done

---
 README.md                        |   5 +-
 docs/training.md                 |  37 ++
 melo/api.py                      |   8 +-
 melo/configs/config.json         |  94 +++++
 melo/data_utils.py               | 413 ++++++++++++++++++++
 melo/download_utils.py           |  43 ++-
 melo/infer.py                    |  25 ++
 melo/losses.py                   |  58 +++
 melo/models.py                   |  26 +-
 melo/monotonic_align/__init__.py |  16 +
 melo/monotonic_align/core.py     |  46 +++
 melo/preprocess_text.py          | 135 +++++++
 melo/train.py                    | 635 +++++++++++++++++++++++++++++++
 melo/train.sh                    |  19 +
 melo/utils.py                    |  16 +-
 requirements.txt                 |   4 +-
 16 files changed, 1533 insertions(+), 47 deletions(-)
 create mode 100644 docs/training.md
 create mode 100644 melo/configs/config.json
 create mode 100644 melo/data_utils.py
 create mode 100644 melo/infer.py
 create mode 100644 melo/losses.py
 create mode 100644 melo/monotonic_align/__init__.py
 create mode 100644 melo/monotonic_align/core.py
 create mode 100644 melo/preprocess_text.py
 create mode 100644 melo/train.py
 create mode 100644 melo/train.sh

diff --git a/README.md b/README.md
index f75019a..122f586 100644
--- a/README.md
+++ b/README.md
@@ -26,6 +26,7 @@ Some other features include:
 ## Usage
 - [Use without Installation](docs/quick_use.md)
 - [Install and Use Locally](docs/install.md)
+- [Training on Custom Dataset](docs/training.md)
 
 The Python API and model cards can be found in [this repo](https://github.com/myshell-ai/MeloTTS/blob/main/docs/install.md#python-api) or on [HuggingFace](https://huggingface.co/myshell-ai).
 
@@ -57,10 +58,6 @@ If you find this work useful, please consider contributing to this repo.
 }
 ```
 
-## TODO
-
-- Training code release.
-
 ## License
 
 This library is under MIT License, which means it is free for both commercial and non-commercial use.

diff --git a/docs/training.md b/docs/training.md
new file mode 100644
index 0000000..2ac9aa2
--- /dev/null
+++ b/docs/training.md
@@ -0,0 +1,37 @@
+## Training
+
+Before training, please install MeloTTS in dev mode and go to the `melo` folder.
+```
+pip install -e .
+cd melo
+```
+
+### Data Preparation
+To train a TTS model, we need to prepare the audio files and a metadata file. We recommend using 44100 Hz audio files, and the metadata file should have the following format:
+
+```
+path/to/audio_001.wav |<speaker_name>|<language_code>|<text_001>
+path/to/audio_002.wav |<speaker_name>|<language_code>|<text_002>
+```
+The transcribed text can be obtained with an ASR model (e.g., [whisper](https://github.com/openai/whisper)). An example metadata file can be found in `data/example/metadata.list`.
+
+We can then run the preprocessing code:
+```
+python preprocess_text.py --metadata data/example/metadata.list
+```
+A config file `data/example/config.json` will be generated. Feel free to edit the hyper-parameters in that config file (for example, you may decrease the batch size if you encounter a CUDA out-of-memory error).
+
+### Training
+The training can be launched by:
+```
+bash train.sh <path/to/config> <num_of_gpus>
+```
+
+We have found that on some machines the training will sometimes crash due to a gloo [issue](https://github.com/pytorch/pytorch/issues/2530). Therefore, we add an auto-resume wrapper in `train.sh`.
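+
+For example, with the example dataset prepared above and a single GPU (the config path and GPU count below are only illustrative):
+```
+bash train.sh data/example/config.json 1
+```
+The model name is taken from the parent folder of the config file (see `train.sh`); checkpoints are saved periodically as `G_<step>.pth`, `D_<step>.pth`, and `DUR_<step>.pth` in the corresponding model directory, which can also be monitored with TensorBoard.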
+ +### Inference +Simply run: +``` +python infer.py --text "" -m /path/to/checkpoint/G_.pth -o +``` + diff --git a/melo/api.py b/melo/api.py index 1f2f125..236ea8f 100644 --- a/melo/api.py +++ b/melo/api.py @@ -21,7 +21,9 @@ class TTS(nn.Module): def __init__(self, language, device='auto', - use_hf=True): + use_hf=True, + config_path=None, + ckpt_path=None): super().__init__() if device == 'auto': device = 'cpu' @@ -31,7 +33,7 @@ class TTS(nn.Module): assert torch.cuda.is_available() # config_path = - hps = load_or_download_config(language, use_hf=use_hf) + hps = load_or_download_config(language, use_hf=use_hf, config_path=config_path) num_languages = hps.num_languages num_tones = hps.num_tones @@ -54,7 +56,7 @@ class TTS(nn.Module): self.device = device # load state_dict - checkpoint_dict = load_or_download_model(language, device, use_hf=use_hf) + checkpoint_dict = load_or_download_model(language, device, use_hf=use_hf, ckpt_path=ckpt_path) self.model.load_state_dict(checkpoint_dict['model'], strict=True) language = language.split('_')[0] diff --git a/melo/configs/config.json b/melo/configs/config.json new file mode 100644 index 0000000..f93ce66 --- /dev/null +++ b/melo/configs/config.json @@ -0,0 +1,94 @@ +{ + "train": { + "log_interval": 200, + "eval_interval": 1000, + "seed": 52, + "epochs": 10000, + "learning_rate": 0.0003, + "betas": [ + 0.8, + 0.99 + ], + "eps": 1e-09, + "batch_size": 6, + "fp16_run": false, + "lr_decay": 0.999875, + "segment_size": 16384, + "init_lr_ratio": 1, + "warmup_epochs": 0, + "c_mel": 45, + "c_kl": 1.0, + "skip_optimizer": true + }, + "data": { + "training_files": "", + "validation_files": "", + "max_wav_value": 32768.0, + "sampling_rate": 44100, + "filter_length": 2048, + "hop_length": 512, + "win_length": 2048, + "n_mel_channels": 128, + "mel_fmin": 0.0, + "mel_fmax": null, + "add_blank": true, + "n_speakers": 256, + "cleaned_text": true, + "spk2id": {} + }, + "model": { + "use_spk_conditioned_encoder": true, + "use_noise_scaled_mas": true, + "use_mel_posterior_encoder": false, + "use_duration_discriminator": true, + "inter_channels": 192, + "hidden_channels": 192, + "filter_channels": 768, + "n_heads": 2, + "n_layers": 6, + "n_layers_trans_flow": 3, + "kernel_size": 3, + "p_dropout": 0.1, + "resblock": "1", + "resblock_kernel_sizes": [ + 3, + 7, + 11 + ], + "resblock_dilation_sizes": [ + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ] + ], + "upsample_rates": [ + 8, + 8, + 2, + 2, + 2 + ], + "upsample_initial_channel": 512, + "upsample_kernel_sizes": [ + 16, + 16, + 8, + 2, + 2 + ], + "n_layers_q": 3, + "use_spectral_norm": false, + "gin_channels": 256 + } +} diff --git a/melo/data_utils.py b/melo/data_utils.py new file mode 100644 index 0000000..179749d --- /dev/null +++ b/melo/data_utils.py @@ -0,0 +1,413 @@ +import os +import random +import torch +import torch.utils.data +from tqdm import tqdm +from loguru import logger +import commons +from mel_processing import spectrogram_torch, mel_spectrogram_torch +from utils import load_filepaths_and_text +from utils import load_wav_to_torch_librosa as load_wav_to_torch +from text import cleaned_text_to_sequence, get_bert +import numpy as np + +"""Multi speaker version""" + + +class TextAudioSpeakerLoader(torch.utils.data.Dataset): + """ + 1) loads audio, speaker_id, text pairs + 2) normalizes text and converts them to sequences of integers + 3) computes spectrograms from audio files. 
+ """ + + def __init__(self, audiopaths_sid_text, hparams): + self.audiopaths_sid_text = load_filepaths_and_text(audiopaths_sid_text) + self.max_wav_value = hparams.max_wav_value + self.sampling_rate = hparams.sampling_rate + self.filter_length = hparams.filter_length + self.hop_length = hparams.hop_length + self.win_length = hparams.win_length + self.sampling_rate = hparams.sampling_rate + self.spk_map = hparams.spk2id + self.hparams = hparams + self.disable_bert = getattr(hparams, "disable_bert", False) + + self.use_mel_spec_posterior = getattr( + hparams, "use_mel_posterior_encoder", False + ) + if self.use_mel_spec_posterior: + self.n_mel_channels = getattr(hparams, "n_mel_channels", 80) + + self.cleaned_text = getattr(hparams, "cleaned_text", False) + + self.add_blank = hparams.add_blank + self.min_text_len = getattr(hparams, "min_text_len", 1) + self.max_text_len = getattr(hparams, "max_text_len", 300) + + random.seed(1234) + random.shuffle(self.audiopaths_sid_text) + self._filter() + + + def _filter(self): + """ + Filter text & store spec lengths + """ + # Store spectrogram lengths for Bucketing + # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2) + # spec_length = wav_length // hop_length + + audiopaths_sid_text_new = [] + lengths = [] + skipped = 0 + logger.info("Init dataset...") + for item in tqdm( + self.audiopaths_sid_text + ): + try: + _id, spk, language, text, phones, tone, word2ph = item + except: + print(item) + raise + audiopath = f"{_id}" + if self.min_text_len <= len(phones) and len(phones) <= self.max_text_len: + phones = phones.split(" ") + tone = [int(i) for i in tone.split(" ")] + word2ph = [int(i) for i in word2ph.split(" ")] + audiopaths_sid_text_new.append( + [audiopath, spk, language, text, phones, tone, word2ph] + ) + lengths.append(os.path.getsize(audiopath) // (2 * self.hop_length)) + else: + skipped += 1 + logger.info(f'min: {min(lengths)}; max: {max(lengths)}' ) + logger.info( + "skipped: " + + str(skipped) + + ", total: " + + str(len(self.audiopaths_sid_text)) + ) + self.audiopaths_sid_text = audiopaths_sid_text_new + self.lengths = lengths + + def get_audio_text_speaker_pair(self, audiopath_sid_text): + # separate filename, speaker_id and text + audiopath, sid, language, text, phones, tone, word2ph = audiopath_sid_text + + bert, ja_bert, phones, tone, language = self.get_text( + text, word2ph, phones, tone, language, audiopath + ) + + spec, wav = self.get_audio(audiopath) + sid = int(getattr(self.spk_map, sid, '0')) + sid = torch.LongTensor([sid]) + return (phones, spec, wav, sid, tone, language, bert, ja_bert) + + def get_audio(self, filename): + audio_norm, sampling_rate = load_wav_to_torch(filename, self.sampling_rate) + if sampling_rate != self.sampling_rate: + raise ValueError( + "{} {} SR doesn't match target {} SR".format( + filename, sampling_rate, self.sampling_rate + ) + ) + # NOTE: normalize has been achieved by torchaudio + # audio_norm = audio / self.max_wav_value + audio_norm = audio_norm.unsqueeze(0) + spec_filename = filename.replace(".wav", ".spec.pt") + if self.use_mel_spec_posterior: + spec_filename = spec_filename.replace(".spec.pt", ".mel.pt") + try: + spec = torch.load(spec_filename) + assert False + except: + if self.use_mel_spec_posterior: + spec = mel_spectrogram_torch( + audio_norm, + self.filter_length, + self.n_mel_channels, + self.sampling_rate, + self.hop_length, + self.win_length, + self.hparams.mel_fmin, + self.hparams.mel_fmax, + center=False, + ) + else: + spec = spectrogram_torch( + 
audio_norm, + self.filter_length, + self.sampling_rate, + self.hop_length, + self.win_length, + center=False, + ) + spec = torch.squeeze(spec, 0) + torch.save(spec, spec_filename) + return spec, audio_norm + + def get_text(self, text, word2ph, phone, tone, language_str, wav_path): + phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str) + if self.add_blank: + phone = commons.intersperse(phone, 0) + tone = commons.intersperse(tone, 0) + language = commons.intersperse(language, 0) + for i in range(len(word2ph)): + word2ph[i] = word2ph[i] * 2 + word2ph[0] += 1 + bert_path = wav_path.replace(".wav", ".bert.pt") + try: + bert = torch.load(bert_path) + assert bert.shape[-1] == len(phone) + except Exception as e: + print(e, wav_path, bert_path, bert.shape, len(phone)) + bert = get_bert(text, word2ph, language_str) + torch.save(bert, bert_path) + assert bert.shape[-1] == len(phone), phone + + if self.disable_bert: + bert = torch.zeros(1024, len(phone)) + ja_bert = torch.zeros(768, len(phone)) + else: + if language_str in ["ZH"]: + bert = bert + ja_bert = torch.zeros(768, len(phone)) + elif language_str in ["JP", "EN", "ZH_MIX_EN", "KR", 'SP', 'ES', 'FR', 'DE', 'RU']: + ja_bert = bert + bert = torch.zeros(1024, len(phone)) + else: + raise + bert = torch.zeros(1024, len(phone)) + ja_bert = torch.zeros(768, len(phone)) + assert bert.shape[-1] == len(phone) + phone = torch.LongTensor(phone) + tone = torch.LongTensor(tone) + language = torch.LongTensor(language) + return bert, ja_bert, phone, tone, language + + def get_sid(self, sid): + sid = torch.LongTensor([int(sid)]) + return sid + + def __getitem__(self, index): + return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index]) + + def __len__(self): + return len(self.audiopaths_sid_text) + + +class TextAudioSpeakerCollate: + """Zero-pads model inputs and targets""" + + def __init__(self, return_ids=False): + self.return_ids = return_ids + + def __call__(self, batch): + """Collate's training batch from normalized text, audio and speaker identities + PARAMS + ------ + batch: [text_normalized, spec_normalized, wav_normalized, sid] + """ + # Right zero-pad all one-hot text sequences to max input length + _, ids_sorted_decreasing = torch.sort( + torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True + ) + + max_text_len = max([len(x[0]) for x in batch]) + max_spec_len = max([x[1].size(1) for x in batch]) + max_wav_len = max([x[2].size(1) for x in batch]) + + text_lengths = torch.LongTensor(len(batch)) + spec_lengths = torch.LongTensor(len(batch)) + wav_lengths = torch.LongTensor(len(batch)) + sid = torch.LongTensor(len(batch)) + + text_padded = torch.LongTensor(len(batch), max_text_len) + tone_padded = torch.LongTensor(len(batch), max_text_len) + language_padded = torch.LongTensor(len(batch), max_text_len) + bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len) + ja_bert_padded = torch.FloatTensor(len(batch), 768, max_text_len) + + spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len) + wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len) + text_padded.zero_() + tone_padded.zero_() + language_padded.zero_() + spec_padded.zero_() + wav_padded.zero_() + bert_padded.zero_() + ja_bert_padded.zero_() + for i in range(len(ids_sorted_decreasing)): + row = batch[ids_sorted_decreasing[i]] + + text = row[0] + text_padded[i, : text.size(0)] = text + text_lengths[i] = text.size(0) + + spec = row[1] + spec_padded[i, :, : spec.size(1)] = spec + spec_lengths[i] = spec.size(1) + + 
wav = row[2] + wav_padded[i, :, : wav.size(1)] = wav + wav_lengths[i] = wav.size(1) + + sid[i] = row[3] + + tone = row[4] + tone_padded[i, : tone.size(0)] = tone + + language = row[5] + language_padded[i, : language.size(0)] = language + + bert = row[6] + bert_padded[i, :, : bert.size(1)] = bert + + ja_bert = row[7] + ja_bert_padded[i, :, : ja_bert.size(1)] = ja_bert + + return ( + text_padded, + text_lengths, + spec_padded, + spec_lengths, + wav_padded, + wav_lengths, + sid, + tone_padded, + language_padded, + bert_padded, + ja_bert_padded, + ) + + +class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler): + """ + Maintain similar input lengths in a batch. + Length groups are specified by boundaries. + Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}. + + It removes samples which are not included in the boundaries. + Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded. + """ + + def __init__( + self, + dataset, + batch_size, + boundaries, + num_replicas=None, + rank=None, + shuffle=True, + ): + super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) + self.lengths = dataset.lengths + self.batch_size = batch_size + self.boundaries = boundaries + + self.buckets, self.num_samples_per_bucket = self._create_buckets() + self.total_size = sum(self.num_samples_per_bucket) + self.num_samples = self.total_size // self.num_replicas + print('buckets:', self.num_samples_per_bucket) + + def _create_buckets(self): + buckets = [[] for _ in range(len(self.boundaries) - 1)] + for i in range(len(self.lengths)): + length = self.lengths[i] + idx_bucket = self._bisect(length) + if idx_bucket != -1: + buckets[idx_bucket].append(i) + + try: + for i in range(len(buckets) - 1, 0, -1): + if len(buckets[i]) == 0: + buckets.pop(i) + self.boundaries.pop(i + 1) + assert all(len(bucket) > 0 for bucket in buckets) + # When one bucket is not traversed + except Exception as e: + print("Bucket warning ", e) + for i in range(len(buckets) - 1, -1, -1): + if len(buckets[i]) == 0: + buckets.pop(i) + self.boundaries.pop(i + 1) + + num_samples_per_bucket = [] + for i in range(len(buckets)): + len_bucket = len(buckets[i]) + total_batch_size = self.num_replicas * self.batch_size + rem = ( + total_batch_size - (len_bucket % total_batch_size) + ) % total_batch_size + num_samples_per_bucket.append(len_bucket + rem) + return buckets, num_samples_per_bucket + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + + indices = [] + if self.shuffle: + for bucket in self.buckets: + indices.append(torch.randperm(len(bucket), generator=g).tolist()) + else: + for bucket in self.buckets: + indices.append(list(range(len(bucket)))) + + batches = [] + for i in range(len(self.buckets)): + bucket = self.buckets[i] + len_bucket = len(bucket) + if len_bucket == 0: + continue + ids_bucket = indices[i] + num_samples_bucket = self.num_samples_per_bucket[i] + + # add extra samples to make it evenly divisible + rem = num_samples_bucket - len_bucket + ids_bucket = ( + ids_bucket + + ids_bucket * (rem // len_bucket) + + ids_bucket[: (rem % len_bucket)] + ) + + # subsample + ids_bucket = ids_bucket[self.rank :: self.num_replicas] + + # batching + for j in range(len(ids_bucket) // self.batch_size): + batch = [ + bucket[idx] + for idx in ids_bucket[ + j * self.batch_size : (j + 1) * self.batch_size + ] + ] + batches.append(batch) + + if 
self.shuffle: + batch_ids = torch.randperm(len(batches), generator=g).tolist() + batches = [batches[i] for i in batch_ids] + self.batches = batches + + assert len(self.batches) * self.batch_size == self.num_samples + return iter(self.batches) + + def _bisect(self, x, lo=0, hi=None): + if hi is None: + hi = len(self.boundaries) - 1 + + if hi > lo: + mid = (hi + lo) // 2 + if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]: + return mid + elif x <= self.boundaries[mid]: + return self._bisect(x, lo, mid) + else: + return self._bisect(x, mid + 1, hi) + else: + return -1 + + def __len__(self): + return self.num_samples // self.batch_size diff --git a/melo/download_utils.py b/melo/download_utils.py index da41592..87e8a34 100644 --- a/melo/download_utils.py +++ b/melo/download_utils.py @@ -24,6 +24,12 @@ DOWNLOAD_CONFIG_URLS = { 'KR': 'https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/basespeakers/KR/config.json', } +PRETRAINED_MODELS = { + 'G.pth': 'https://cloud.tsinghua.edu.cn/f/91346812c11746e1b67b/?dl=1', + 'D.pth': 'https://cloud.tsinghua.edu.cn/f/4734a5281219424199f1/?dl=1', + 'DUR.pth': 'https://cloud.tsinghua.edu.cn/f/750feac7585f49ce96d7/?dl=1', +} + LANG_TO_HF_REPO_ID = { 'EN': 'myshell-ai/MeloTTS-English', 'EN_V2': 'myshell-ai/MeloTTS-English-v2', @@ -34,22 +40,27 @@ LANG_TO_HF_REPO_ID = { 'KR': 'myshell-ai/MeloTTS-Korean', } -def load_or_download_config(locale, use_hf=True): - language = locale.split('-')[0].upper() - if use_hf: - assert language in LANG_TO_HF_REPO_ID - config_path = hf_hub_download(repo_id=LANG_TO_HF_REPO_ID[language], filename="config.json") - else: - assert language in DOWNLOAD_CONFIG_URLS - config_path = cached_path(DOWNLOAD_CONFIG_URLS[language]) +def load_or_download_config(locale, use_hf=True, config_path=None): + if config_path is None: + language = locale.split('-')[0].upper() + if use_hf: + assert language in LANG_TO_HF_REPO_ID + config_path = hf_hub_download(repo_id=LANG_TO_HF_REPO_ID[language], filename="config.json") + else: + assert language in DOWNLOAD_CONFIG_URLS + config_path = cached_path(DOWNLOAD_CONFIG_URLS[language]) return utils.get_hparams_from_file(config_path) -def load_or_download_model(locale, device, use_hf=True): - language = locale.split('-')[0].upper() - if use_hf: - assert language in LANG_TO_HF_REPO_ID - ckpt_path = hf_hub_download(repo_id=LANG_TO_HF_REPO_ID[language], filename="checkpoint.pth") - else: - assert language in DOWNLOAD_CKPT_URLS - ckpt_path = cached_path(DOWNLOAD_CKPT_URLS[language]) +def load_or_download_model(locale, device, use_hf=True, ckpt_path=None): + if ckpt_path is None: + language = locale.split('-')[0].upper() + if use_hf: + assert language in LANG_TO_HF_REPO_ID + ckpt_path = hf_hub_download(repo_id=LANG_TO_HF_REPO_ID[language], filename="checkpoint.pth") + else: + assert language in DOWNLOAD_CKPT_URLS + ckpt_path = cached_path(DOWNLOAD_CKPT_URLS[language]) return torch.load(ckpt_path, map_location=device) + +def load_pretrain_model(): + return [cached_path(url) for url in PRETRAINED_MODELS.values()] \ No newline at end of file diff --git a/melo/infer.py b/melo/infer.py new file mode 100644 index 0000000..7ac1de9 --- /dev/null +++ b/melo/infer.py @@ -0,0 +1,25 @@ +import os +import click +from melo.api import TTS + + + +@click.command() +@click.option('--ckpt_path', '-m', type=str, default=None, help="Path to the checkpoint file") +@click.option('--text', '-t', type=str, default=None, help="Text to speak") +@click.option('--language', '-l', type=str, default="EN", help="Language of the 
model") +@click.option('--output_dir', '-o', type=str, default="outputs", help="Path to the output") +def main(ckpt_path, text, language, output_dir): + if ckpt_path is None: + raise ValueError("The model_path must be specified") + + config_path = os.path.join(os.path.dirname(ckpt_path), 'config.json') + model = TTS(language=language, config_path=config_path, ckpt_path=ckpt_path) + + for spk_name, spk_id in model.hps.data.spk2id.items(): + save_path = f'{output_dir}/{spk_name}/output.wav' + os.makedirs(os.path.dirname(save_path), exist_ok=True) + model.tts_to_file(text, spk_id, save_path) + +if __name__ == "__main__": + main() diff --git a/melo/losses.py b/melo/losses.py new file mode 100644 index 0000000..b1b263e --- /dev/null +++ b/melo/losses.py @@ -0,0 +1,58 @@ +import torch + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + rl = rl.float().detach() + gl = gl.float() + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + dr = dr.float() + dg = dg.float() + r_loss = torch.mean((1 - dr) ** 2) + g_loss = torch.mean(dg**2) + loss += r_loss + g_loss + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_loss(disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + dg = dg.float() + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses + + +def kl_loss(z_p, logs_q, m_p, logs_p, z_mask): + """ + z_p, logs_q: [b, h, t_t] + m_p, logs_p: [b, h, t_t] + """ + z_p = z_p.float() + logs_q = logs_q.float() + m_p = m_p.float() + logs_p = logs_p.float() + z_mask = z_mask.float() + + kl = logs_p - logs_q - 0.5 + kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p) + kl = torch.sum(kl * z_mask) + l = kl / torch.sum(z_mask) + return l diff --git a/melo/models.py b/melo/models.py index 8079142..46c5853 100644 --- a/melo/models.py +++ b/melo/models.py @@ -3,14 +3,15 @@ import torch from torch import nn from torch.nn import functional as F -from . import commons -from . import modules -from . 
import attentions +from melo import commons +from melo import modules +from melo import attentions from torch.nn import Conv1d, ConvTranspose1d, Conv2d from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm -from .commons import init_weights, get_padding +from melo.commons import init_weights, get_padding +import melo.monotonic_align as monotonic_align class DurationDiscriminator(nn.Module): # vits2 @@ -782,7 +783,6 @@ class SynthesizerTrn(nn.Module): num_languages=None, num_tones=None, norm_refenc=False, - use_se=False, **kwargs ): super().__init__() @@ -878,16 +878,12 @@ class SynthesizerTrn(nn.Module): hidden_channels, 256, 3, 0.5, gin_channels=gin_channels ) - if n_speakers > 1: - if use_se: - emb_dim = 512 - self.emb_g = nn.Linear(emb_dim, gin_channels) - else: - self.emb_g = nn.Embedding(n_speakers, gin_channels) + if n_speakers > 0: + self.emb_g = nn.Embedding(n_speakers, gin_channels) else: self.ref_enc = ReferenceEncoder(spec_channels, gin_channels, layernorm=norm_refenc) self.use_vc = use_vc - self.use_se = use_se + def forward(self, x, x_lengths, y, y_lengths, sid, tone, language, bert, ja_bert): if self.n_speakers > 0: @@ -1024,11 +1020,7 @@ class SynthesizerTrn(nn.Module): # print('max/min of o:', o.max(), o.min()) return o, attn, y_mask, (z, z_p, m_p, logs_p) - def voice_conversion(self, y, y_lengths, sid_src, sid_tgt, tau=1.0): - if self.use_se: - sid_src = self.emb_g(sid_src).unsqueeze(-1) - sid_tgt = self.emb_g(sid_tgt).unsqueeze(-1) - + def voice_conversion(self, y, y_lengths, sid_src, sid_tgt, tau=1.0): g_src = sid_src g_tgt = sid_tgt z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src, tau=tau) diff --git a/melo/monotonic_align/__init__.py b/melo/monotonic_align/__init__.py new file mode 100644 index 0000000..15d8e60 --- /dev/null +++ b/melo/monotonic_align/__init__.py @@ -0,0 +1,16 @@ +from numpy import zeros, int32, float32 +from torch import from_numpy + +from .core import maximum_path_jit + + +def maximum_path(neg_cent, mask): + device = neg_cent.device + dtype = neg_cent.dtype + neg_cent = neg_cent.data.cpu().numpy().astype(float32) + path = zeros(neg_cent.shape, dtype=int32) + + t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32) + t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32) + maximum_path_jit(path, neg_cent, t_t_max, t_s_max) + return from_numpy(path).to(device=device, dtype=dtype) diff --git a/melo/monotonic_align/core.py b/melo/monotonic_align/core.py new file mode 100644 index 0000000..ffa489d --- /dev/null +++ b/melo/monotonic_align/core.py @@ -0,0 +1,46 @@ +import numba + + +@numba.jit( + numba.void( + numba.int32[:, :, ::1], + numba.float32[:, :, ::1], + numba.int32[::1], + numba.int32[::1], + ), + nopython=True, + nogil=True, +) +def maximum_path_jit(paths, values, t_ys, t_xs): + b = paths.shape[0] + max_neg_val = -1e9 + for i in range(int(b)): + path = paths[i] + value = values[i] + t_y = t_ys[i] + t_x = t_xs[i] + + v_prev = v_cur = 0.0 + index = t_x - 1 + + for y in range(t_y): + for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): + if x == y: + v_cur = max_neg_val + else: + v_cur = value[y - 1, x] + if x == 0: + if y == 0: + v_prev = 0.0 + else: + v_prev = max_neg_val + else: + v_prev = value[y - 1, x - 1] + value[y, x] += max(v_prev, v_cur) + + for y in range(t_y - 1, -1, -1): + path[y, index] = 1 + if index != 0 and ( + index == y or value[y - 1, index] < value[y - 1, index - 1] + ): + index = index - 1 diff --git a/melo/preprocess_text.py b/melo/preprocess_text.py new file mode 100644 index 
0000000..19dfb7c --- /dev/null +++ b/melo/preprocess_text.py @@ -0,0 +1,135 @@ +import json +from collections import defaultdict +from random import shuffle +from typing import Optional + +from tqdm import tqdm +import click +from text.cleaner import clean_text_bert +import os +import torch +from text.symbols import symbols, num_languages, num_tones + +@click.command() +@click.option( + "--metadata", + default="data/example/metadata.list", + type=click.Path(exists=True, file_okay=True, dir_okay=False), +) +@click.option("--cleaned-path", default=None) +@click.option("--train-path", default=None) +@click.option("--val-path", default=None) +@click.option( + "--config_path", + default="configs/config.json", + type=click.Path(exists=True, file_okay=True, dir_okay=False), +) +@click.option("--val-per-spk", default=4) +@click.option("--max-val-total", default=8) +@click.option("--clean/--no-clean", default=True) +def main( + metadata: str, + cleaned_path: Optional[str], + train_path: str, + val_path: str, + config_path: str, + val_per_spk: int, + max_val_total: int, + clean: bool, +): + if train_path is None: + train_path = os.path.join(os.path.dirname(metadata), 'train.list') + if val_path is None: + val_path = os.path.join(os.path.dirname(metadata), 'val.list') + out_config_path = os.path.join(os.path.dirname(metadata), 'config.json') + + if cleaned_path is None: + cleaned_path = metadata + ".cleaned" + + if clean: + out_file = open(cleaned_path, "w", encoding="utf-8") + new_symbols = [] + for line in tqdm(open(metadata, encoding="utf-8").readlines()): + try: + utt, spk, language, text = line.strip().split("|") + norm_text, phones, tones, word2ph, bert = clean_text_bert(text, language, device='cuda:0') + for ph in phones: + if ph not in symbols and ph not in new_symbols: + new_symbols.append(ph) + print('update!, now symbols:') + print(new_symbols) + with open(f'{language}_symbol.txt', 'w') as f: + f.write(f'{new_symbols}') + + assert len(phones) == len(tones) + assert len(phones) == sum(word2ph) + out_file.write( + "{}|{}|{}|{}|{}|{}|{}\n".format( + utt, + spk, + language, + norm_text, + " ".join(phones), + " ".join([str(i) for i in tones]), + " ".join([str(i) for i in word2ph]), + ) + ) + bert_path = utt.replace(".wav", ".bert.pt") + os.makedirs(os.path.dirname(bert_path), exist_ok=True) + torch.save(bert.cpu(), bert_path) + except Exception as error: + print("err!", line, error) + + out_file.close() + + metadata = cleaned_path + + spk_utt_map = defaultdict(list) + spk_id_map = {} + current_sid = 0 + + with open(metadata, encoding="utf-8") as f: + for line in f.readlines(): + utt, spk, language, text, phones, tones, word2ph = line.strip().split("|") + spk_utt_map[spk].append(line) + + if spk not in spk_id_map.keys(): + spk_id_map[spk] = current_sid + current_sid += 1 + + train_list = [] + val_list = [] + + for spk, utts in spk_utt_map.items(): + shuffle(utts) + val_list += utts[:val_per_spk] + train_list += utts[val_per_spk:] + + if len(val_list) > max_val_total: + train_list += val_list[max_val_total:] + val_list = val_list[:max_val_total] + + with open(train_path, "w", encoding="utf-8") as f: + for line in train_list: + f.write(line) + + with open(val_path, "w", encoding="utf-8") as f: + for line in val_list: + f.write(line) + + config = json.load(open(config_path, encoding="utf-8")) + config["data"]["spk2id"] = spk_id_map + + config["data"]["training_files"] = train_path + config["data"]["validation_files"] = val_path + config["data"]["n_speakers"] = len(spk_id_map) + 
config["num_languages"] = num_languages + config["num_tones"] = num_tones + config["symbols"] = symbols + + with open(out_config_path, "w", encoding="utf-8") as f: + json.dump(config, f, indent=2, ensure_ascii=False) + + +if __name__ == "__main__": + main() diff --git a/melo/train.py b/melo/train.py new file mode 100644 index 0000000..88b4fe6 --- /dev/null +++ b/melo/train.py @@ -0,0 +1,635 @@ +# flake8: noqa: E402 + +import os +import torch +from torch.nn import functional as F +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.cuda.amp import autocast, GradScaler +from tqdm import tqdm +import logging + +logging.getLogger("numba").setLevel(logging.WARNING) +import commons +import utils +from data_utils import ( + TextAudioSpeakerLoader, + TextAudioSpeakerCollate, + DistributedBucketSampler, +) +from models import ( + SynthesizerTrn, + MultiPeriodDiscriminator, + DurationDiscriminator, +) +from losses import generator_loss, discriminator_loss, feature_loss, kl_loss +from mel_processing import mel_spectrogram_torch, spec_to_mel_torch +from text.symbols import symbols +from melo.download_utils import load_pretrain_model + +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = ( + True # If encontered training problem,please try to disable TF32. +) +torch.set_float32_matmul_precision("medium") + + +torch.backends.cudnn.benchmark = True +torch.backends.cuda.sdp_kernel("flash") +torch.backends.cuda.enable_flash_sdp(True) +# torch.backends.cuda.enable_mem_efficient_sdp( +# True +# ) # Not available if torch version is lower than 2.0 +torch.backends.cuda.enable_math_sdp(True) +global_step = 0 + + +def run(): + hps = utils.get_hparams() + local_rank = int(os.environ["LOCAL_RANK"]) + dist.init_process_group( + backend="gloo", + init_method="env://", # Due to some training problem,we proposed to use gloo instead of nccl. + rank=local_rank, + ) # Use torchrun instead of mp.spawn + rank = dist.get_rank() + n_gpus = dist.get_world_size() + + torch.manual_seed(hps.train.seed) + torch.cuda.set_device(rank) + global global_step + if rank == 0: + logger = utils.get_logger(hps.model_dir) + logger.info(hps) + utils.check_git_hash(hps.model_dir) + writer = SummaryWriter(log_dir=hps.model_dir) + writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval")) + train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps.data) + train_sampler = DistributedBucketSampler( + train_dataset, + hps.train.batch_size, + [32, 300, 400, 500, 600, 700, 800, 900, 1000], + num_replicas=n_gpus, + rank=rank, + shuffle=True, + ) + collate_fn = TextAudioSpeakerCollate() + train_loader = DataLoader( + train_dataset, + num_workers=16, + shuffle=False, + pin_memory=True, + collate_fn=collate_fn, + batch_sampler=train_sampler, + persistent_workers=True, + prefetch_factor=4, + ) # DataLoader config could be adjusted. 
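+    # NOTE: num_workers, prefetch_factor and persistent_workers above trade CPU and host memory
+    # for data-loading throughput; consider lowering them on machines with few cores or limited RAM.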
+ if rank == 0: + eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data) + eval_loader = DataLoader( + eval_dataset, + num_workers=0, + shuffle=False, + batch_size=1, + pin_memory=True, + drop_last=False, + collate_fn=collate_fn, + ) + if ( + "use_noise_scaled_mas" in hps.model.keys() + and hps.model.use_noise_scaled_mas is True + ): + print("Using noise scaled MAS for VITS2") + mas_noise_scale_initial = 0.01 + noise_scale_delta = 2e-6 + else: + print("Using normal MAS for VITS1") + mas_noise_scale_initial = 0.0 + noise_scale_delta = 0.0 + if ( + "use_duration_discriminator" in hps.model.keys() + and hps.model.use_duration_discriminator is True + ): + print("Using duration discriminator for VITS2") + net_dur_disc = DurationDiscriminator( + hps.model.hidden_channels, + hps.model.hidden_channels, + 3, + 0.1, + gin_channels=hps.model.gin_channels if hps.data.n_speakers != 0 else 0, + ).cuda(rank) + if ( + "use_spk_conditioned_encoder" in hps.model.keys() + and hps.model.use_spk_conditioned_encoder is True + ): + if hps.data.n_speakers == 0: + raise ValueError( + "n_speakers must be > 0 when using spk conditioned encoder to train multi-speaker model" + ) + else: + print("Using normal encoder for VITS1") + + net_g = SynthesizerTrn( + len(symbols), + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + mas_noise_scale_initial=mas_noise_scale_initial, + noise_scale_delta=noise_scale_delta, + **hps.model, + ).cuda(rank) + + net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) + optim_g = torch.optim.AdamW( + filter(lambda p: p.requires_grad, net_g.parameters()), + hps.train.learning_rate, + betas=hps.train.betas, + eps=hps.train.eps, + ) + optim_d = torch.optim.AdamW( + net_d.parameters(), + hps.train.learning_rate, + betas=hps.train.betas, + eps=hps.train.eps, + ) + if net_dur_disc is not None: + optim_dur_disc = torch.optim.AdamW( + net_dur_disc.parameters(), + hps.train.learning_rate, + betas=hps.train.betas, + eps=hps.train.eps, + ) + else: + optim_dur_disc = None + net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True) + net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True) + + pretrain_G, pretrain_D, pretrain_dur = load_pretrain_model() + hps.pretrain_G = hps.pretrain_G or pretrain_G + hps.pretrain_D = hps.pretrain_D or pretrain_D + hps.pretrain_dur = hps.pretrain_dur or pretrain_dur + + if hps.pretrain_G: + utils.load_checkpoint( + hps.pretrain_G, + net_g, + None, + skip_optimizer=True + ) + if hps.pretrain_D: + utils.load_checkpoint( + hps.pretrain_D, + net_d, + None, + skip_optimizer=True + ) + + + if net_dur_disc is not None: + net_dur_disc = DDP(net_dur_disc, device_ids=[rank], find_unused_parameters=True) + if hps.pretrain_dur: + utils.load_checkpoint( + hps.pretrain_dur, + net_dur_disc, + None, + skip_optimizer=True + ) + + try: + if net_dur_disc is not None: + _, _, dur_resume_lr, epoch_str = utils.load_checkpoint( + utils.latest_checkpoint_path(hps.model_dir, "DUR_*.pth"), + net_dur_disc, + optim_dur_disc, + skip_optimizer=hps.train.skip_optimizer + if "skip_optimizer" in hps.train + else True, + ) + _, optim_g, g_resume_lr, epoch_str = utils.load_checkpoint( + utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), + net_g, + optim_g, + skip_optimizer=hps.train.skip_optimizer + if "skip_optimizer" in hps.train + else True, + ) + _, optim_d, d_resume_lr, epoch_str = utils.load_checkpoint( + utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), + net_d, 
+ optim_d, + skip_optimizer=hps.train.skip_optimizer + if "skip_optimizer" in hps.train + else True, + ) + if not optim_g.param_groups[0].get("initial_lr"): + optim_g.param_groups[0]["initial_lr"] = g_resume_lr + if not optim_d.param_groups[0].get("initial_lr"): + optim_d.param_groups[0]["initial_lr"] = d_resume_lr + if not optim_dur_disc.param_groups[0].get("initial_lr"): + optim_dur_disc.param_groups[0]["initial_lr"] = dur_resume_lr + + epoch_str = max(epoch_str, 1) + global_step = (epoch_str - 1) * len(train_loader) + except Exception as e: + print(e) + epoch_str = 1 + global_step = 0 + + scheduler_g = torch.optim.lr_scheduler.ExponentialLR( + optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2 + ) + scheduler_d = torch.optim.lr_scheduler.ExponentialLR( + optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2 + ) + if net_dur_disc is not None: + scheduler_dur_disc = torch.optim.lr_scheduler.ExponentialLR( + optim_dur_disc, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2 + ) + else: + scheduler_dur_disc = None + scaler = GradScaler(enabled=hps.train.fp16_run) + + for epoch in range(epoch_str, hps.train.epochs + 1): + try: + if rank == 0: + train_and_evaluate( + rank, + epoch, + hps, + [net_g, net_d, net_dur_disc], + [optim_g, optim_d, optim_dur_disc], + [scheduler_g, scheduler_d, scheduler_dur_disc], + scaler, + [train_loader, eval_loader], + logger, + [writer, writer_eval], + ) + else: + train_and_evaluate( + rank, + epoch, + hps, + [net_g, net_d, net_dur_disc], + [optim_g, optim_d, optim_dur_disc], + [scheduler_g, scheduler_d, scheduler_dur_disc], + scaler, + [train_loader, None], + None, + None, + ) + except Exception as e: + print(e) + torch.cuda.empty_cache() + scheduler_g.step() + scheduler_d.step() + if net_dur_disc is not None: + scheduler_dur_disc.step() + + +def train_and_evaluate( + rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers +): + net_g, net_d, net_dur_disc = nets + optim_g, optim_d, optim_dur_disc = optims + scheduler_g, scheduler_d, scheduler_dur_disc = schedulers + train_loader, eval_loader = loaders + if writers is not None: + writer, writer_eval = writers + + train_loader.batch_sampler.set_epoch(epoch) + global global_step + + net_g.train() + net_d.train() + if net_dur_disc is not None: + net_dur_disc.train() + for batch_idx, ( + x, + x_lengths, + spec, + spec_lengths, + y, + y_lengths, + speakers, + tone, + language, + bert, + ja_bert, + ) in enumerate(tqdm(train_loader)): + if net_g.module.use_noise_scaled_mas: + current_mas_noise_scale = ( + net_g.module.mas_noise_scale_initial + - net_g.module.noise_scale_delta * global_step + ) + net_g.module.current_mas_noise_scale = max(current_mas_noise_scale, 0.0) + x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda( + rank, non_blocking=True + ) + spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda( + rank, non_blocking=True + ) + y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda( + rank, non_blocking=True + ) + speakers = speakers.cuda(rank, non_blocking=True) + tone = tone.cuda(rank, non_blocking=True) + language = language.cuda(rank, non_blocking=True) + bert = bert.cuda(rank, non_blocking=True) + ja_bert = ja_bert.cuda(rank, non_blocking=True) + + with autocast(enabled=hps.train.fp16_run): + ( + y_hat, + l_length, + attn, + ids_slice, + x_mask, + z_mask, + (z, z_p, m_p, logs_p, m_q, logs_q), + (hidden_x, logw, logw_), + ) = net_g( + x, + x_lengths, + spec, + spec_lengths, + speakers, + tone, + language, + bert, + ja_bert, + ) + 
mel = spec_to_mel_torch( + spec, + hps.data.filter_length, + hps.data.n_mel_channels, + hps.data.sampling_rate, + hps.data.mel_fmin, + hps.data.mel_fmax, + ) + y_mel = commons.slice_segments( + mel, ids_slice, hps.train.segment_size // hps.data.hop_length + ) + y_hat_mel = mel_spectrogram_torch( + y_hat.squeeze(1), + hps.data.filter_length, + hps.data.n_mel_channels, + hps.data.sampling_rate, + hps.data.hop_length, + hps.data.win_length, + hps.data.mel_fmin, + hps.data.mel_fmax, + ) + + y = commons.slice_segments( + y, ids_slice * hps.data.hop_length, hps.train.segment_size + ) # slice + + # Discriminator + y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach()) + with autocast(enabled=False): + loss_disc, losses_disc_r, losses_disc_g = discriminator_loss( + y_d_hat_r, y_d_hat_g + ) + loss_disc_all = loss_disc + if net_dur_disc is not None: + y_dur_hat_r, y_dur_hat_g = net_dur_disc( + hidden_x.detach(), x_mask.detach(), logw.detach(), logw_.detach() + ) + with autocast(enabled=False): + # TODO: I think need to mean using the mask, but for now, just mean all + ( + loss_dur_disc, + losses_dur_disc_r, + losses_dur_disc_g, + ) = discriminator_loss(y_dur_hat_r, y_dur_hat_g) + loss_dur_disc_all = loss_dur_disc + optim_dur_disc.zero_grad() + scaler.scale(loss_dur_disc_all).backward() + scaler.unscale_(optim_dur_disc) + commons.clip_grad_value_(net_dur_disc.parameters(), None) + scaler.step(optim_dur_disc) + + optim_d.zero_grad() + scaler.scale(loss_disc_all).backward() + scaler.unscale_(optim_d) + grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None) + scaler.step(optim_d) + + with autocast(enabled=hps.train.fp16_run): + # Generator + y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat) + if net_dur_disc is not None: + y_dur_hat_r, y_dur_hat_g = net_dur_disc(hidden_x, x_mask, logw, logw_) + with autocast(enabled=False): + loss_dur = torch.sum(l_length.float()) + loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel + loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl + + loss_fm = feature_loss(fmap_r, fmap_g) + loss_gen, losses_gen = generator_loss(y_d_hat_g) + loss_gen_all = loss_gen + loss_fm + loss_mel + loss_dur + loss_kl + if net_dur_disc is not None: + loss_dur_gen, losses_dur_gen = generator_loss(y_dur_hat_g) + loss_gen_all += loss_dur_gen + optim_g.zero_grad() + scaler.scale(loss_gen_all).backward() + scaler.unscale_(optim_g) + grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None) + scaler.step(optim_g) + scaler.update() + + if rank == 0: + if global_step % hps.train.log_interval == 0: + lr = optim_g.param_groups[0]["lr"] + losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_dur, loss_kl] + logger.info( + "Train Epoch: {} [{:.0f}%]".format( + epoch, 100.0 * batch_idx / len(train_loader) + ) + ) + logger.info([x.item() for x in losses] + [global_step, lr]) + + scalar_dict = { + "loss/g/total": loss_gen_all, + "loss/d/total": loss_disc_all, + "learning_rate": lr, + "grad_norm_d": grad_norm_d, + "grad_norm_g": grad_norm_g, + } + scalar_dict.update( + { + "loss/g/fm": loss_fm, + "loss/g/mel": loss_mel, + "loss/g/dur": loss_dur, + "loss/g/kl": loss_kl, + } + ) + scalar_dict.update( + {"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)} + ) + scalar_dict.update( + {"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)} + ) + scalar_dict.update( + {"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)} + ) + + image_dict = { + "slice/mel_org": utils.plot_spectrogram_to_numpy( + y_mel[0].data.cpu().numpy() + ), + 
"slice/mel_gen": utils.plot_spectrogram_to_numpy( + y_hat_mel[0].data.cpu().numpy() + ), + "all/mel": utils.plot_spectrogram_to_numpy( + mel[0].data.cpu().numpy() + ), + "all/attn": utils.plot_alignment_to_numpy( + attn[0, 0].data.cpu().numpy() + ), + } + utils.summarize( + writer=writer, + global_step=global_step, + images=image_dict, + scalars=scalar_dict, + ) + + if global_step % hps.train.eval_interval == 0: + evaluate(hps, net_g, eval_loader, writer_eval) + utils.save_checkpoint( + net_g, + optim_g, + hps.train.learning_rate, + epoch, + os.path.join(hps.model_dir, "G_{}.pth".format(global_step)), + ) + utils.save_checkpoint( + net_d, + optim_d, + hps.train.learning_rate, + epoch, + os.path.join(hps.model_dir, "D_{}.pth".format(global_step)), + ) + if net_dur_disc is not None: + utils.save_checkpoint( + net_dur_disc, + optim_dur_disc, + hps.train.learning_rate, + epoch, + os.path.join(hps.model_dir, "DUR_{}.pth".format(global_step)), + ) + keep_ckpts = getattr(hps.train, "keep_ckpts", 5) + if keep_ckpts > 0: + utils.clean_checkpoints( + path_to_models=hps.model_dir, + n_ckpts_to_keep=keep_ckpts, + sort_by_time=True, + ) + + global_step += 1 + + if rank == 0: + logger.info("====> Epoch: {}".format(epoch)) + torch.cuda.empty_cache() + + +def evaluate(hps, generator, eval_loader, writer_eval): + generator.eval() + image_dict = {} + audio_dict = {} + print("Evaluating ...") + with torch.no_grad(): + for batch_idx, ( + x, + x_lengths, + spec, + spec_lengths, + y, + y_lengths, + speakers, + tone, + language, + bert, + ja_bert, + ) in enumerate(eval_loader): + x, x_lengths = x.cuda(), x_lengths.cuda() + spec, spec_lengths = spec.cuda(), spec_lengths.cuda() + y, y_lengths = y.cuda(), y_lengths.cuda() + speakers = speakers.cuda() + bert = bert.cuda() + ja_bert = ja_bert.cuda() + tone = tone.cuda() + language = language.cuda() + for use_sdp in [True, False]: + y_hat, attn, mask, *_ = generator.module.infer( + x, + x_lengths, + speakers, + tone, + language, + bert, + ja_bert, + y=spec, + max_len=1000, + sdp_ratio=0.0 if not use_sdp else 1.0, + ) + y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length + + mel = spec_to_mel_torch( + spec, + hps.data.filter_length, + hps.data.n_mel_channels, + hps.data.sampling_rate, + hps.data.mel_fmin, + hps.data.mel_fmax, + ) + y_hat_mel = mel_spectrogram_torch( + y_hat.squeeze(1).float(), + hps.data.filter_length, + hps.data.n_mel_channels, + hps.data.sampling_rate, + hps.data.hop_length, + hps.data.win_length, + hps.data.mel_fmin, + hps.data.mel_fmax, + ) + image_dict.update( + { + f"gen/mel_{batch_idx}": utils.plot_spectrogram_to_numpy( + y_hat_mel[0].cpu().numpy() + ) + } + ) + audio_dict.update( + { + f"gen/audio_{batch_idx}_{use_sdp}": y_hat[ + 0, :, : y_hat_lengths[0] + ] + } + ) + image_dict.update( + { + f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy( + mel[0].cpu().numpy() + ) + } + ) + audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]}) + + utils.summarize( + writer=writer_eval, + global_step=global_step, + images=image_dict, + audios=audio_dict, + audio_sampling_rate=hps.data.sampling_rate, + ) + generator.train() + print('Evauate done') + torch.cuda.empty_cache() + + +if __name__ == "__main__": + run() diff --git a/melo/train.sh b/melo/train.sh new file mode 100644 index 0000000..fd9e9ff --- /dev/null +++ b/melo/train.sh @@ -0,0 +1,19 @@ +CONFIG=$1 +GPUS=$2 +MODEL_NAME=$(basename "$(dirname $CONFIG)") + +PORT=10902 + +while : # auto-resume: the code sometimes crash due to bug of gloo on some gpus +do +torchrun 
--nproc_per_node=$GPUS \ + --master_port=$PORT \ + train.py --c $CONFIG --model $MODEL_NAME + +for PID in $(ps -aux | grep $CONFIG | grep python | awk '{print $2}') +do + echo $PID + kill -9 $PID +done +sleep 30 +done \ No newline at end of file diff --git a/melo/utils.py b/melo/utils.py index f1198f6..bafca5a 100644 --- a/melo/utils.py +++ b/melo/utils.py @@ -9,9 +9,9 @@ from scipy.io.wavfile import read import torch import torchaudio import librosa -from .text import cleaned_text_to_sequence, get_bert -from .text.cleaner import clean_text -from . import commons +from melo.text import cleaned_text_to_sequence, get_bert +from melo.text.cleaner import clean_text +from melo import commons MATPLOTLIB_FLAG = False @@ -60,8 +60,8 @@ def get_text_for_tts_infer(text, language_str, hps, device, symbol_to_id=None): def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False): assert os.path.isfile(checkpoint_path) checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") - iteration = checkpoint_dict["iteration"] - learning_rate = checkpoint_dict["learning_rate"] + iteration = checkpoint_dict.get("iteration", 0) + learning_rate = checkpoint_dict.get("learning_rate", 0.) if ( optimizer is not None and not skip_optimizer @@ -92,6 +92,7 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False v.shape, ) except Exception as e: + print(e) # For upgrading from the old version if "ja_bert_proj" in k: v = torch.zeros_like(v) @@ -249,7 +250,9 @@ def get_hparams(init=True): default="./configs/base.json", help="JSON file for configuration", ) - parser.add_argument('--local-rank', type=int, default=0) + parser.add_argument('--local_rank', type=int, default=0) + parser.add_argument('--world-size', type=int, default=1) + parser.add_argument('--port', type=int, default=10000) parser.add_argument("-m", "--model", type=str, required=True, help="Model name") parser.add_argument('--pretrain_G', type=str, default=None, help='pretrain model') @@ -280,6 +283,7 @@ def get_hparams(init=True): hparams.pretrain_G = args.pretrain_G hparams.pretrain_D = args.pretrain_D hparams.pretrain_dur = args.pretrain_dur + hparams.port = args.port return hparams diff --git a/requirements.txt b/requirements.txt index 1ec5e6c..af4cb60 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,4 +25,6 @@ cn2an==0.5.22 jieba==0.42.1 gradio langid==1.1.6 -tqdm \ No newline at end of file +tqdm +tensorboard==2.16.2 +loguru==0.7.2