|
import numpy as np |
|
import soundfile |
|
import onnxruntime as ort |
|
import argparse |
|
import time |
|
from utils import * |
|
from rknnlite.api import RKNNLite |
|
|
|
|
|
|
|
def get_args():
    """Parse the command-line options of the ``melotts`` script.

    Every option is optional and falls back to the default listed in the
    table below.

    Returns:
        argparse.Namespace: with attributes ``sentence``, ``wav``,
        ``encoder``, ``decoder``, ``sample_rate``, ``speed``, ``lexicon``
        and ``token``.
    """
    cli = argparse.ArgumentParser(
        prog="melotts",
        description="Run TTS on input sentence",
    )

    # (option strings, value type, default) -- registered in this order so the
    # generated help text lists the options in the same sequence as before.
    option_table = (
        (("--sentence", "-s"), str, "爱芯元智半导体股份有限公司,致力于打造世界领先的人工智能感知与边缘计算芯片。服务智慧城市、智能驾驶、机器人的海量普惠的应用"),
        (("--wav", "-w"), str, "output.wav"),
        (("--encoder", "-e"), str, "encoder.onnx"),
        (("--decoder", "-d"), str, "decoder.rknn"),
        (("--sample_rate", "-sr"), int, 44100),
        (("--speed",), float, 0.8),
        (("--lexicon",), str, "lexicon.txt"),
        (("--token",), str, "tokens.txt"),
    )
    for flags, value_type, fallback in option_table:
        cli.add_argument(*flags, type=value_type, required=False, default=fallback)

    return cli.parse_args()
|
|
|
|
|
def audio_numpy_concat(segment_data_list, sr, speed=1.):
    """Join audio segments into one waveform, adding a short pause after each.

    Each segment is flattened to 1-D and followed by
    ``int((sr * 0.05) / speed)`` zero samples.

    Args:
        segment_data_list: iterable of array-likes holding audio samples.
        sr: sample rate used to size the pause.
        speed: speaking-speed factor; larger values shorten the pause.

    Returns:
        np.ndarray: 1-D ``float32`` array. Empty when no segments are given.
    """
    segments = list(segment_data_list)
    if not segments:
        # np.concatenate rejects an empty sequence; mirror the old result.
        return np.zeros(0, dtype=np.float32)

    # A negative count (e.g. negative speed) used to yield no padding because
    # ``[0] * negative`` is an empty list; clamp to keep that behaviour.
    gap_samples = max(int((sr * 0.05) / speed), 0)
    gap = np.zeros(gap_samples, dtype=np.float32)

    # Stay in NumPy instead of round-tripping every sample through Python
    # lists, which costs far more memory and time on long audio.
    pieces = []
    for segment in segments:
        pieces.append(np.asarray(segment, dtype=np.float32).reshape(-1))
        pieces.append(gap)
    return np.concatenate(pieces)
|
|
|
|
|
def merge_sub_audio(sub_audio_list, pad_size, audio_len):
    """Concatenate decoded audio pieces, cross-averaging overlapping edges.

    When ``pad_size > 0`` the last ``pad_size`` samples of each piece (except
    the final one) are averaged with the first ``pad_size`` samples of the
    next piece, and every piece other than the first and the last then has
    its first ``pad_size`` samples dropped.

    Args:
        sub_audio_list: sequence of 1-D sample arrays. Not modified.
        pad_size: number of overlapping samples between neighbours; ``0``
            means plain concatenation.
        audio_len: the result is truncated to this many samples.

    Returns:
        np.ndarray: the merged audio, at most ``audio_len`` samples long.
        An empty ``float32`` array when ``sub_audio_list`` is empty.
    """
    pieces = list(sub_audio_list)
    if not pieces:
        # np.concatenate raises on an empty sequence.
        return np.zeros(0, dtype=np.float32)

    if pad_size > 0:
        # Work on private copies so the caller's arrays and list are left
        # untouched by the in-place averaging below.
        pieces = [np.array(piece, copy=True) for piece in pieces]
        for i in range(len(pieces) - 1):
            tail = pieces[i][-pad_size:]  # view into the private copy
            tail += pieces[i + 1][:pad_size]
            tail /= 2
            if i > 0:
                pieces[i] = pieces[i][pad_size:]
        # NOTE(review): the final piece keeps its head even though that head
        # was already averaged into the previous tail -- looks like the
        # overlap is emitted twice there; confirm intent before changing.

    merged = np.concatenate(pieces, axis=-1)
    return merged[:audio_len]
|
|
|
|
|
def generate_pronounce_slice(yinjie_num):
    """Turn per-word symbol counts into consecutive ``slice`` objects.

    Args:
        yinjie_num: iterable of counts, one per word.

    Returns:
        list[slice]: slice ``k`` starts where slice ``k - 1`` stopped and
        spans ``yinjie_num[k]`` positions; the first slice starts at 0.
    """
    # Running boundaries: 0, n0, n0 + n1, ...
    boundaries = [0]
    for count in yinjie_num:
        boundaries.append(boundaries[-1] + count)

    return [slice(lo, hi) for lo, hi in zip(boundaries, boundaries[1:])]
|
|
|
|
|
def generate_word_pron_num(pron_lens, pron_slices):
    """Sum ``pron_lens`` over each slice.

    Args:
        pron_lens: array supporting slice indexing and ``.sum()``.
        pron_slices: iterable of slices into ``pron_lens``.

    Returns:
        list: one total per slice, in the order of ``pron_slices``.
    """
    return [pron_lens[word_span].sum() for word_span in pron_slices]
|
|
|
|
|
def decode_long_word(sess_dec, z_p, g, dec_len):
    """Decode a latent sequence longer than one decoder window, chunk by chunk.

    ``z_p`` is cut along its last axis into windows of ``dec_len`` frames.
    A short final window is zero-padded up to ``dec_len`` before decoding,
    and each decoded chunk is trimmed to 512 samples per real (unpadded)
    frame.

    Args:
        sess_dec: object exposing ``inference(inputs=[...])`` whose first
            output is the decoded audio.
        z_p: latent array; the last axis is split into windows.
        g: second decoder input, passed through unchanged.
        dec_len: number of frames per decoder call.

    Returns:
        list: one flattened audio array per window, in order.
    """
    total_frames = z_p.shape[-1]
    window_count = int(np.ceil(total_frames / dec_len))
    decoded_chunks = []

    for idx in range(window_count):
        lo = idx * dec_len
        window = z_p[..., lo : lo + dec_len]
        real_frames = window.shape[-1]

        # Only the last window can come up short; fill it with zeros.
        shortfall = dec_len - real_frames
        if shortfall > 0:
            filler = np.zeros((*window.shape[:-1], shortfall), dtype=np.float32)
            window = np.concatenate((window, filler), axis=-1)

        t0 = time.time()
        waveform = sess_dec.inference(inputs=[window, g])[0].flatten()
        # Drop the samples that were produced from the zero padding.
        waveform = waveform[: 512 * real_frames]
        print(f"Long word slice[{idx}]: decoder run take {1000 * (time.time() - t0):.2f}ms")
        decoded_chunks.append(waveform)

    return decoded_chunks
|
|
|
|
|
def generate_decode_slices(pron_num, dec_len):
    """Group words into decoder-sized chunks, optionally overlapping them.

    Words are consumed left to right. A chunk takes as many consecutive
    words as fit within ``dec_len`` frames. If the previous chunk held more
    than two words and the next word is shorter than ``dec_len``, the new
    chunk starts two words early (re-using the last two words of the
    previous chunk) and is flagged for head stripping, while the previous
    chunk is flagged for tail stripping unless it was a long word. A single
    word that alone exceeds ``dec_len`` becomes its own chunk, flagged long.

    Args:
        pron_num: per-word frame counts.
        dec_len: frame capacity of one decoder call.

    Returns:
        tuple of five parallel lists, one entry per chunk:
        ``[start, end]`` word index ranges, ``[start, end]`` frame ranges,
        ``[strip_head, strip_tail]`` flags, word counts, and long-word flags.
    """
    word_ranges = []
    frame_ranges = []
    strip_flags = []
    words_per_chunk = []
    long_flags = []

    total_words = len(pron_num)
    cursor = 0  # index of the next word not yet placed at a chunk's end
    frame_begin, frame_end = 0, 0

    while cursor < total_words:
        overlap_with_previous = (
            len(words_per_chunk) > 0
            and words_per_chunk[-1] > 2
            and pron_num[cursor] < dec_len
        )

        if overlap_with_previous:
            # Re-include the last two words of the previous chunk.
            prev_stop = word_ranges[-1][1]
            overlap_frames = pron_num[prev_stop - 1] + pron_num[prev_stop - 2]
            first_word = cursor - 2
            frame_begin = frame_end - overlap_frames
            strip_head = True
            # The previous chunk drops its tail unless it was a long word.
            strip_flags[-1][1] = not long_flags[-1]
        else:
            first_word = cursor
            frame_begin = frame_end
            strip_head = False
            if strip_flags:
                strip_flags[-1][1] = False

        # Start from the overlap size (0 without overlap) and add words
        # until the window is full or the words run out.
        span = abs(frame_end - frame_begin)
        while cursor < total_words and span < dec_len:
            span += pron_num[cursor]
            cursor += 1

        if cursor - first_word == 1 and span > dec_len:
            # One word that does not fit a single window on its own.
            long_flags.append(True)
            if strip_flags:
                strip_flags[-1][1] = False
        else:
            long_flags.append(False)
            if span > dec_len:
                # The last word overflowed the window; hand it back.
                span -= pron_num[cursor - 1]
                cursor -= 1

        frame_end = frame_begin + span

        word_ranges.append([first_word, cursor])
        frame_ranges.append([frame_begin, frame_end])
        # strip_tail starts False; a later iteration may switch it on.
        strip_flags.append([strip_head, False])
        words_per_chunk.append(cursor - first_word)

    return word_ranges, frame_ranges, strip_flags, words_per_chunk, long_flags
|
|
|
|
|
def _load_decoder(dec_model):
    """Load ``dec_model`` into an RKNNLite runtime; exit with its error code on failure."""
    decoder = RKNNLite()
    ret = decoder.load_rknn(dec_model)
    if ret != 0:
        print('Load decoder RKNN model failed')
        # SystemExit instead of the site-provided ``exit`` helper, which is
        # not guaranteed to exist (e.g. under ``python -S``).
        raise SystemExit(ret)
    ret = decoder.init_runtime()
    if ret != 0:
        print('Init runtime failed')
        raise SystemExit(ret)
    return decoder


def _synthesize_sentence(n, se, lexicon, sess_enc, decoder, g, dec_len, speed):
    """Run encoder and decoder for sentence ``se`` (index ``n``) and return its audio."""
    # The phoneme strings returned first are not needed for synthesis.
    _, yinjie_num, phones, tones = lexicon.convert(se)

    phones = np.array(intersperse(phones, 0), dtype=np.int32)
    tones = np.array(intersperse(tones, 0), dtype=np.int32)
    # presumably intersperse() puts a 0 around every symbol (2 * len + 1), so
    # each word's count doubles and the first word takes the leading blank
    # -- TODO confirm against utils.intersperse.
    yinjie_num = np.array(yinjie_num, dtype=np.int32) * 2
    yinjie_num[0] += 1
    if yinjie_num.sum() != phones.shape[0]:
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        raise ValueError(
            f"per-word symbol counts sum to {yinjie_num.sum()} "
            f"but there are {phones.shape[0]} phones"
        )
    pron_slices = generate_pronounce_slice(yinjie_num)

    phone_len = phones.shape[-1]
    language = np.array([3] * phone_len, dtype=np.int32)

    start = time.time()
    z_p, pronoun_lens, audio_len = sess_enc.run(None, input_feed={
        'phone': phones, 'g': g,
        'tone': tones, 'language': language,
        'noise_scale': np.array([0], dtype=np.float32),
        'length_scale': np.array([1.0 / speed], dtype=np.float32),
        'noise_scale_w': np.array([0], dtype=np.float32),
        'sdp_ratio': np.array([0], dtype=np.float32)})
    print(f"encoder run take {1000 * (time.time() - start):.2f}ms")

    audio_len = audio_len[0]
    actual_size = z_p.shape[-1]
    dec_slice_num = int(np.ceil(actual_size / dec_len))
    # Zero-pad the last axis up to a whole number of decoder windows.
    z_p = np.pad(
        z_p,
        pad_width=((0, 0), (0, 0), (0, dec_slice_num * dec_len - actual_size)),
        mode="constant",
        constant_values=0,
    )

    # Frames per word, then the chunking plan for the decoder.
    pron_num = generate_word_pron_num(pronoun_lens, pron_slices)
    pron_num_slices, zp_slices, strip_flags, _, is_long = \
        generate_decode_slices(pron_num, dec_len)

    sub_audio_list = []
    for i, ((pron_start, pron_end), (zp_start, zp_end)) in enumerate(
            zip(pron_num_slices, zp_slices)):
        zp_slice = z_p[..., zp_start:zp_end]

        if is_long[i]:
            sub_audio_list.extend(decode_long_word(decoder, zp_slice, g, dec_len))
            continue

        sub_dec_len = zp_end - zp_start
        sub_audio_len = 512 * sub_dec_len  # 512 output samples per frame

        if zp_slice.shape[-1] < dec_len:
            filler = np.zeros(
                (*zp_slice.shape[:-1], dec_len - zp_slice.shape[-1]),
                dtype=np.float32,
            )
            zp_slice = np.concatenate((zp_slice, filler), axis=-1)

        start = time.time()
        outputs = decoder.inference(inputs=[zp_slice, g])
        audio = outputs[0].flatten()
        audio = audio[:sub_audio_len]
        # ``n`` is the sentence index; it used to be clobbered by an inner
        # debug loop, so this line reported a word index instead.
        print(f"Sentence[{n}] Slice[{i}]: decoder run take {1000 * (time.time() - start):.2f}ms")

        strip_head, strip_tail = strip_flags[i]
        if strip_head:
            # Drop the first overlapped word; the previous chunk keeps it.
            head = 512 * pron_num[pron_start]
            audio = audio[head:]
        if strip_tail:
            # Drop the last overlapped word; the next chunk keeps it.
            tail = 512 * pron_num[pron_end - 1]
            # ``audio[:-tail]`` would return nothing when tail == 0.
            audio = audio[:max(len(audio) - tail, 0)]

        sub_audio_list.append(audio)

    return merge_sub_audio(sub_audio_list, 0, audio_len)


def main():
    """Synthesize the requested sentence and write it to a WAV file.

    Reads the options from ``get_args()``, splits the text into sentences,
    runs the ONNX encoder and the RKNN decoder on each one, joins the
    per-sentence audio and saves it to ``args.wav``. The speaker embedding
    is read from ``g.bin`` in the current working directory.

    Raises:
        SystemExit: with the RKNN return code when the decoder model cannot
            be loaded or its runtime cannot be initialised.
        ValueError: when the lexicon's per-word counts do not match the
            number of phones.
    """
    args = get_args()
    sentence = args.sentence
    sample_rate = args.sample_rate
    lexicon_filename = args.lexicon
    token_filename = args.token
    enc_model = args.encoder
    dec_model = args.decoder

    print(f"sentence: {sentence}")
    print(f"sample_rate: {sample_rate}")
    print(f"lexicon: {lexicon_filename}")
    print(f"token: {token_filename}")
    print(f"encoder: {enc_model}")
    print(f"decoder: {dec_model}")

    sens = split_sentences_zh(sentence)
    lexicon = Lexicon(lexicon_filename, token_filename)

    start = time.time()
    sess_enc = ort.InferenceSession(
        enc_model,
        providers=["CPUExecutionProvider"],
        sess_options=ort.SessionOptions(),
    )
    decoder = _load_decoder(dec_model)

    try:
        # presumably the decoder emits 65536 samples per call at 512 samples
        # per frame -- TODO confirm against the exported decoder model.
        dec_len = 65536 // 512
        print(f"load models take {1000 * (time.time() - start)}ms")

        g = np.fromfile("g.bin", dtype=np.float32).reshape(1, 256, 1)

        audio_list = []
        for n, se in enumerate(sens):
            print(f"\nSentence[{n}]: {se}")
            audio_list.append(
                _synthesize_sentence(n, se, lexicon, sess_enc, decoder, g, dec_len, args.speed)
            )

        audio = audio_numpy_concat(audio_list, sr=sample_rate, speed=args.speed)
        soundfile.write(args.wav, audio, sample_rate)
        print(f"Save to {args.wav}")
    finally:
        # Release the runtime even when synthesis fails part-way.
        decoder.release()
|
|
|
# Run the TTS pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|