# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
# Based on https://github.com/NVIDIA/flowtron/blob/master/data.py
# Original license text:
###############################################################################
#
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
###############################################################################
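"""Dataset and collate utilities for mel-spectrogram TTS training.

Loads audio/transcript filelists and produces mel spectrograms, encoded text,
f0 (via pyin), voicing probabilities, average energy, and beta-binomial
attention priors, with optional LMDB caching for audio and priors.
"""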
import os
import argparse
import json
import numpy as np
import lmdb
import pickle as pkl
import torch
import torch.utils.data
from scipy.io.wavfile import read
from audio_processing import TacotronSTFT
from tts_text_processing.text_processing import TextProcessing
from scipy.stats import betabinom
from librosa import pyin
from common import update_params
from scipy.ndimage import distance_transform_edt as distance_transform
def beta_binomial_prior_distribution(phoneme_count, mel_count, scaling_factor=0.05):
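    """Beta-binomial alignment prior: for each of the `mel_count` frames,
    a Beta-Binomial(P - 1, a, b) pmf over the P phoneme positions, which
    concentrates mass near the diagonal of the alignment.
    Returns a (mel_count, phoneme_count) tensor."""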
P = phoneme_count
M = mel_count
x = np.arange(0, P)
mel_text_probs = []
for i in range(1, M + 1):
a, b = scaling_factor * i, scaling_factor * (M + 1 - i)
rv = betabinom(P - 1, a, b)
mel_i_prob = rv.pmf(x)
mel_text_probs.append(mel_i_prob)
return torch.tensor(np.array(mel_text_probs))
def load_wav_to_torch(full_path):
"""Loads wavdata into torch array"""
sampling_rate, data = read(full_path)
return torch.from_numpy(np.array(data)).float(), sampling_rate
class Data(torch.utils.data.Dataset):
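    """Loads (audio, text) pairs from filelists and produces training features:
    mel spectrogram, encoded text, and speaker id, plus optionally f0, voicing,
    average energy, and a beta-binomial attention prior (disk- or LMDB-cached)."""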
def __init__(
self,
datasets,
filter_length,
hop_length,
win_length,
sampling_rate,
n_mel_channels,
mel_fmin,
mel_fmax,
f0_min,
f0_max,
max_wav_value,
use_f0,
use_energy_avg,
use_log_f0,
use_scaled_energy,
symbol_set,
cleaner_names,
heteronyms_path,
phoneme_dict_path,
p_phoneme,
handle_phoneme="word",
handle_phoneme_ambiguous="ignore",
speaker_ids=None,
include_speakers=None,
n_frames=-1,
use_attn_prior_masking=True,
prepend_space_to_text=True,
append_space_to_text=True,
add_bos_eos_to_text=False,
betabinom_cache_path="",
betabinom_scaling_factor=0.05,
lmdb_cache_path="",
dur_min=None,
dur_max=None,
combine_speaker_and_emotion=False,
**kwargs,
):
self.combine_speaker_and_emotion = combine_speaker_and_emotion
self.max_wav_value = max_wav_value
self.audio_lmdb_dict = {} # dictionary of lmdbs for audio data
self.data = self.load_data(datasets)
self.distance_tx_unvoiced = False
if "distance_tx_unvoiced" in kwargs.keys():
self.distance_tx_unvoiced = kwargs["distance_tx_unvoiced"]
self.stft = TacotronSTFT(
filter_length=filter_length,
hop_length=hop_length,
win_length=win_length,
sampling_rate=sampling_rate,
n_mel_channels=n_mel_channels,
mel_fmin=mel_fmin,
mel_fmax=mel_fmax,
)
self.do_mel_scaling = kwargs.get("do_mel_scaling", True)
self.mel_noise_scale = kwargs.get("mel_noise_scale", 0.0)
self.filter_length = filter_length
self.hop_length = hop_length
self.win_length = win_length
self.mel_fmin = mel_fmin
self.mel_fmax = mel_fmax
self.f0_min = f0_min
self.f0_max = f0_max
self.use_f0 = use_f0
self.use_log_f0 = use_log_f0
self.use_energy_avg = use_energy_avg
self.use_scaled_energy = use_scaled_energy
self.sampling_rate = sampling_rate
self.tp = TextProcessing(
symbol_set,
cleaner_names,
heteronyms_path,
phoneme_dict_path,
p_phoneme=p_phoneme,
handle_phoneme=handle_phoneme,
handle_phoneme_ambiguous=handle_phoneme_ambiguous,
prepend_space_to_text=prepend_space_to_text,
append_space_to_text=append_space_to_text,
add_bos_eos_to_text=add_bos_eos_to_text,
)
self.dur_min = dur_min
self.dur_max = dur_max
if speaker_ids is None or speaker_ids == "":
self.speaker_ids = self.create_speaker_lookup_table(self.data)
else:
self.speaker_ids = speaker_ids
print("Number of files", len(self.data))
if include_speakers is not None:
for speaker_set, include in include_speakers:
self.filter_by_speakers_(speaker_set, include)
print("Number of files after speaker filtering", len(self.data))
if dur_min is not None and dur_max is not None:
self.filter_by_duration_(dur_min, dur_max)
print("Number of files after duration filtering", len(self.data))
self.use_attn_prior_masking = bool(use_attn_prior_masking)
self.prepend_space_to_text = bool(prepend_space_to_text)
self.append_space_to_text = bool(append_space_to_text)
self.betabinom_cache_path = betabinom_cache_path
self.betabinom_scaling_factor = betabinom_scaling_factor
self.lmdb_cache_path = lmdb_cache_path
if self.lmdb_cache_path != "":
self.cache_data_lmdb = lmdb.open(
self.lmdb_cache_path, readonly=True, max_readers=1024, lock=False
).begin()
        # make sure the caching path exists; attention priors and f0 are saved there
        if self.betabinom_cache_path != "":
            os.makedirs(self.betabinom_cache_path, exist_ok=True)
print("Dataloader initialized with no augmentations")
self.speaker_map = None
if "speaker_map" in kwargs:
self.speaker_map = kwargs["speaker_map"]
def load_data(self, datasets, split="|"):
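        """Reads each dataset's filelist (fields separated by `split`:
        audiopath|text|speaker[|emotion|duration]) into a list of example
        dicts. Datasets with a non-empty "lmdbpath" are read from LMDB
        instead of wav files on disk."""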
dataset = []
for dset_name, dset_dict in datasets.items():
folder_path = dset_dict["basedir"]
audiodir = dset_dict["audiodir"]
filename = dset_dict["filelist"]
audio_lmdb_key = None
if "lmdbpath" in dset_dict.keys() and len(dset_dict["lmdbpath"]) > 0:
self.audio_lmdb_dict[dset_name] = lmdb.open(
dset_dict["lmdbpath"], readonly=True, max_readers=256, lock=False
).begin()
audio_lmdb_key = dset_name
wav_folder_prefix = os.path.join(folder_path, audiodir)
filelist_path = os.path.join(folder_path, filename)
with open(filelist_path, encoding="utf-8") as f:
data = [line.strip().split(split) for line in f]
for d in data:
emotion = "other" if len(d) == 3 else d[3]
duration = -1 if len(d) == 3 else d[4]
dataset.append(
{
"audiopath": os.path.join(wav_folder_prefix, d[0]),
"text": d[1],
"speaker": d[2] + "-" + emotion
if self.combine_speaker_and_emotion
else d[2],
"emotion": emotion,
"duration": float(duration),
"lmdb_key": audio_lmdb_key,
}
)
return dataset
def filter_by_speakers_(self, speakers, include=True):
print("Include spaker {}: {}".format(speakers, include))
if include:
self.data = [x for x in self.data if x["speaker"] in speakers]
else:
self.data = [x for x in self.data if x["speaker"] not in speakers]
def filter_by_duration_(self, dur_min, dur_max):
self.data = [
x
for x in self.data
if x["duration"] == -1
or (x["duration"] >= dur_min and x["duration"] <= dur_max)
]
def create_speaker_lookup_table(self, data):
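        """Maps sorted unique speaker names to consecutive integer ids."""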
speaker_ids = np.sort(np.unique([x["speaker"] for x in data]))
d = {speaker_ids[i]: i for i in range(len(speaker_ids))}
print("Number of speakers:", len(d))
print("Speaker IDS", d)
return d
def f0_normalize(self, x):
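        """Maps voiced f0 values (>= f0_min) to log space; unvoiced frames stay 0."""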
if self.use_log_f0:
mask = x >= self.f0_min
x[mask] = torch.log(x[mask])
x[~mask] = 0.0
return x
def f0_denormalize(self, x):
if self.use_log_f0:
log_f0_min = np.log(self.f0_min)
mask = x >= log_f0_min
x[mask] = torch.exp(x[mask])
x[~mask] = 0.0
x[x <= 0.0] = 0.0
return x
def energy_avg_normalize(self, x):
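        """Affine rescale (x + 20) / 20 of the average mel energy; assumes the
        log-mel energy lies roughly in [-20, 0], mapping it to about [0, 1]."""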
if self.use_scaled_energy:
x = (x + 20.0) / 20.0
return x
def energy_avg_denormalize(self, x):
if self.use_scaled_energy:
x = x * 20.0 - 20.0
return x
def get_f0_pvoiced(
self,
audio,
sampling_rate=22050,
frame_length=1024,
hop_length=256,
f0_min=100,
f0_max=300,
):
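        """Estimates f0 with librosa's probabilistic YIN (pyin), zeroing
        unvoiced frames. Returns (f0, voiced_mask, p_voiced) as float tensors."""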
audio_norm = audio / self.max_wav_value
        # fmin/fmax/sr passed by keyword: they are keyword-only in librosa >= 0.10
        f0, voiced_mask, p_voiced = pyin(
            audio_norm,
            fmin=f0_min,
            fmax=f0_max,
            sr=sampling_rate,
            frame_length=frame_length,
            win_length=frame_length // 2,
            hop_length=hop_length,
        )
f0[~voiced_mask] = 0.0
f0 = torch.FloatTensor(f0)
p_voiced = torch.FloatTensor(p_voiced)
voiced_mask = torch.FloatTensor(voiced_mask)
return f0, voiced_mask, p_voiced
def get_energy_average(self, mel):
energy_avg = mel.mean(0)
energy_avg = self.energy_avg_normalize(energy_avg)
return energy_avg
def get_mel(self, audio):
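        """Computes a log-mel spectrogram (n_mel_channels x n_frames), with
        optional affine scaling and additive Gaussian noise augmentation."""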
audio_norm = audio / self.max_wav_value
audio_norm = audio_norm.unsqueeze(0)
melspec = self.stft.mel_spectrogram(audio_norm)
melspec = torch.squeeze(melspec, 0)
if self.do_mel_scaling:
melspec = (melspec + 5.5) / 2
if self.mel_noise_scale > 0:
melspec += torch.randn_like(melspec) * self.mel_noise_scale
return melspec
def get_speaker_id(self, speaker):
if self.speaker_map is not None and speaker in self.speaker_map:
speaker = self.speaker_map[speaker]
return torch.LongTensor([self.speaker_ids[speaker]])
def get_text(self, text):
text = self.tp.encode_text(text)
text = torch.LongTensor(text)
return text
def get_attention_prior(self, n_tokens, n_frames):
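        """Returns the (n_frames, n_tokens) attention prior, loading it from
        the LMDB/disk cache when available and computing + saving it otherwise.
        Falls back to an all-ones prior when prior masking is disabled."""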
# cache the entire attn_prior by filename
if self.use_attn_prior_masking:
filename = "{}_{}".format(n_tokens, n_frames)
prior_path = os.path.join(self.betabinom_cache_path, filename)
prior_path += "_prior.pth"
if self.lmdb_cache_path != "":
attn_prior = pkl.loads(
self.cache_data_lmdb.get(prior_path.encode("ascii"))
)
elif os.path.exists(prior_path):
attn_prior = torch.load(prior_path)
else:
attn_prior = beta_binomial_prior_distribution(
n_tokens, n_frames, self.betabinom_scaling_factor
)
torch.save(attn_prior, prior_path)
else:
attn_prior = torch.ones(n_frames, n_tokens) # all ones baseline
return attn_prior
def __getitem__(self, index):
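        """Loads one example (from LMDB or wav), validates the sampling rate,
        and assembles the feature dict consumed by DataCollate."""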
data = self.data[index]
audiopath, text = data["audiopath"], data["text"]
speaker_id = data["speaker"]
if data["lmdb_key"] is not None:
data_dict = pkl.loads(
self.audio_lmdb_dict[data["lmdb_key"]].get(audiopath.encode("ascii"))
)
audio = data_dict["audio"]
sampling_rate = data_dict["sampling_rate"]
else:
audio, sampling_rate = load_wav_to_torch(audiopath)
if sampling_rate != self.sampling_rate:
raise ValueError(
"{} SR doesn't match target {} SR".format(
sampling_rate, self.sampling_rate
)
)
mel = self.get_mel(audio)
f0 = None
p_voiced = None
voiced_mask = None
if self.use_f0:
filename = "_".join(audiopath.split("/")[-3:])
f0_path = os.path.join(self.betabinom_cache_path, filename)
f0_path += "_f0_sr{}_fl{}_hl{}_f0min{}_f0max{}_log{}.pt".format(
self.sampling_rate,
self.filter_length,
self.hop_length,
self.f0_min,
self.f0_max,
self.use_log_f0,
)
dikt = None
if len(self.lmdb_cache_path) > 0:
dikt = pkl.loads(self.cache_data_lmdb.get(f0_path.encode("ascii")))
f0 = dikt["f0"]
p_voiced = dikt["p_voiced"]
voiced_mask = dikt["voiced_mask"]
elif os.path.exists(f0_path):
try:
dikt = torch.load(f0_path)
            except Exception:
print(f"f0 loading from {f0_path} is broken, recomputing.")
if dikt is not None:
f0 = dikt["f0"]
p_voiced = dikt["p_voiced"]
voiced_mask = dikt["voiced_mask"]
else:
f0, voiced_mask, p_voiced = self.get_f0_pvoiced(
audio.cpu().numpy(),
self.sampling_rate,
self.filter_length,
self.hop_length,
self.f0_min,
self.f0_max,
)
print("saving f0 to {}".format(f0_path))
torch.save(
{"f0": f0, "voiced_mask": voiced_mask, "p_voiced": p_voiced},
f0_path,
)
if f0 is None:
raise Exception("STOP, BROKEN F0 {}".format(audiopath))
f0 = self.f0_normalize(f0)
if self.distance_tx_unvoiced:
mask = f0 <= 0.0
            # log-distance from each unvoiced frame to the nearest voiced frame
            distance_map = np.log(distance_transform(mask.numpy()))
            distance_map[distance_map <= 0] = 0.0
            f0 = f0 - torch.from_numpy(distance_map).float()
energy_avg = None
if self.use_energy_avg:
energy_avg = self.get_energy_average(mel)
if self.use_scaled_energy and energy_avg.min() < 0.0:
print(audiopath, "has scaled energy avg smaller than 0")
speaker_id = self.get_speaker_id(speaker_id)
text_encoded = self.get_text(text)
attn_prior = self.get_attention_prior(text_encoded.shape[0], mel.shape[1])
if not self.use_attn_prior_masking:
attn_prior = None
return {
"mel": mel,
"speaker_id": speaker_id,
"text_encoded": text_encoded,
"audiopath": audiopath,
"attn_prior": attn_prior,
"f0": f0,
"p_voiced": p_voiced,
"voiced_mask": voiced_mask,
"energy_avg": energy_avg,
}
def __len__(self):
return len(self.data)
class DataCollate:
"""Zero-pads model inputs and targets given number of steps"""
def __init__(self, n_frames_per_step=1):
self.n_frames_per_step = n_frames_per_step
def __call__(self, batch):
"""Collate from normalized data"""
# Right zero-pad all one-hot text sequences to max input length
input_lengths, ids_sorted_decreasing = torch.sort(
torch.LongTensor([len(x["text_encoded"]) for x in batch]),
dim=0,
descending=True,
)
max_input_len = input_lengths[0]
text_padded = torch.LongTensor(len(batch), max_input_len)
text_padded.zero_()
for i in range(len(ids_sorted_decreasing)):
text = batch[ids_sorted_decreasing[i]]["text_encoded"]
text_padded[i, : text.size(0)] = text
# Right zero-pad mel-spec
num_mel_channels = batch[0]["mel"].size(0)
max_target_len = max([x["mel"].size(1) for x in batch])
# include mel padded, gate padded and speaker ids
mel_padded = torch.FloatTensor(len(batch), num_mel_channels, max_target_len)
mel_padded.zero_()
f0_padded = None
p_voiced_padded = None
voiced_mask_padded = None
energy_avg_padded = None
if batch[0]["f0"] is not None:
f0_padded = torch.FloatTensor(len(batch), max_target_len)
f0_padded.zero_()
if batch[0]["p_voiced"] is not None:
p_voiced_padded = torch.FloatTensor(len(batch), max_target_len)
p_voiced_padded.zero_()
if batch[0]["voiced_mask"] is not None:
voiced_mask_padded = torch.FloatTensor(len(batch), max_target_len)
voiced_mask_padded.zero_()
if batch[0]["energy_avg"] is not None:
energy_avg_padded = torch.FloatTensor(len(batch), max_target_len)
energy_avg_padded.zero_()
attn_prior_padded = torch.FloatTensor(len(batch), max_target_len, max_input_len)
attn_prior_padded.zero_()
output_lengths = torch.LongTensor(len(batch))
speaker_ids = torch.LongTensor(len(batch))
audiopaths = []
for i in range(len(ids_sorted_decreasing)):
mel = batch[ids_sorted_decreasing[i]]["mel"]
mel_padded[i, :, : mel.size(1)] = mel
if batch[ids_sorted_decreasing[i]]["f0"] is not None:
f0 = batch[ids_sorted_decreasing[i]]["f0"]
f0_padded[i, : len(f0)] = f0
if batch[ids_sorted_decreasing[i]]["voiced_mask"] is not None:
voiced_mask = batch[ids_sorted_decreasing[i]]["voiced_mask"]
voiced_mask_padded[i, : len(f0)] = voiced_mask
if batch[ids_sorted_decreasing[i]]["p_voiced"] is not None:
p_voiced = batch[ids_sorted_decreasing[i]]["p_voiced"]
p_voiced_padded[i, : len(f0)] = p_voiced
if batch[ids_sorted_decreasing[i]]["energy_avg"] is not None:
energy_avg = batch[ids_sorted_decreasing[i]]["energy_avg"]
energy_avg_padded[i, : len(energy_avg)] = energy_avg
output_lengths[i] = mel.size(1)
speaker_ids[i] = batch[ids_sorted_decreasing[i]]["speaker_id"]
audiopath = batch[ids_sorted_decreasing[i]]["audiopath"]
audiopaths.append(audiopath)
cur_attn_prior = batch[ids_sorted_decreasing[i]]["attn_prior"]
if cur_attn_prior is None:
attn_prior_padded = None
else:
attn_prior_padded[
i, : cur_attn_prior.size(0), : cur_attn_prior.size(1)
] = cur_attn_prior
return {
"mel": mel_padded,
"speaker_ids": speaker_ids,
"text": text_padded,
"input_lengths": input_lengths,
"output_lengths": output_lengths,
"audiopaths": audiopaths,
"attn_prior": attn_prior_padded,
"f0": f0_padded,
"p_voiced": p_voiced_padded,
"voiced_mask": voiced_mask_padded,
"energy_avg": energy_avg_padded,
}
# ===================================================================
# Smoke test: builds train/val datasets from a config and iterates
# them, exercising feature extraction (and populating the caches)
# ===================================================================
if __name__ == "__main__":
    # Parse the config so this can run standalone (no Sacred experiment framework)
parser = argparse.ArgumentParser()
parser.add_argument("-c", "--config", type=str, help="JSON file for configuration")
parser.add_argument("-p", "--params", nargs="+", default=[])
args = parser.parse_args()
args.rank = 0
    # Parse configs; module-level globals are simpler for this standalone script
with open(args.config) as f:
data = f.read()
config = json.loads(data)
update_params(config, args.params)
print(config)
data_config = config["data_config"]
ignore_keys = ["training_files", "validation_files"]
trainset = Data(
data_config["training_files"],
**dict((k, v) for k, v in data_config.items() if k not in ignore_keys),
)
valset = Data(
data_config["validation_files"],
**dict((k, v) for k, v in data_config.items() if k not in ignore_keys),
speaker_ids=trainset.speaker_ids,
)
collate_fn = DataCollate()
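    # Sketch of how the collate function would typically be wired into a
    # torch DataLoader (batch_size/num_workers below are illustrative,
    # not values taken from the config):
    # train_loader = torch.utils.data.DataLoader(
    #     trainset, batch_size=8, collate_fn=collate_fn,
    #     shuffle=True, num_workers=4, drop_last=True,
    # )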
for dataset in (trainset, valset):
for i, batch in enumerate(dataset):
out = batch
print("{}/{}".format(i, len(dataset)))