From ad495c616920f93875ac3088b1b33bd38f1a9325 Mon Sep 17 00:00:00 2001 From: mrfakename Date: Mon, 26 Feb 2024 15:58:53 -0800 Subject: [PATCH] Add progress bar support --- melo/api.py | 14 ++++++++++++-- requirements.txt | 1 + 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/melo/api.py b/melo/api.py index 203d4a2..6e236f9 100644 --- a/melo/api.py +++ b/melo/api.py @@ -7,6 +7,7 @@ import soundfile import torchaudio import numpy as np import torch.nn as nn +from tqdm import tqdm from . import utils from . import commons @@ -71,11 +72,20 @@ class TTS(nn.Module): print(" > ===========================") return texts - def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_scale=0.6, noise_scale_w=0.8, speed=1.0, quiet=False, format=None): + def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_scale=0.6, noise_scale_w=0.8, speed=1.0, pbar=None, format=None, position=None, quiet=False,): language = self.language texts = self.split_sentences_into_pieces(text, language, quiet) audio_list = [] - for t in texts: + if pbar: + tx = pbar(texts) + else: + if position: + tx = tqdm(texts, position=position) + elif quiet: + tx = texts + else: + tx = tqdm(texts) + for t in tx: if language in ['EN', 'ZH_MIX_EN']: t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t) device = self.device diff --git a/requirements.txt b/requirements.txt index 94648ae..4a85b5c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,3 +25,4 @@ cn2an==0.5.22 jieba==0.42.1 gradio==3.48.0 langid==1.1.6 +tqdm \ No newline at end of file