Spaces:
Sleeping
Sleeping
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
import torch | |
import torch.nn as nn | |
from torch.nn import functional as F | |
from common import get_mask_from_lengths | |
def compute_flow_loss(
    z, log_det_W_list, log_s_list, n_elements, n_dims, mask, sigma=1.0
):
    """Compute the normalized negative log-likelihood of a flow's output.

    The loss is ``prior_NLL - sum(log_s) - n_elements * sum(log_det_W)``
    divided by ``n_elements * n_dims``. Here ``prior_NLL`` is the
    isotropic-Gaussian term ``sum((z * mask) ** 2) / (2 * sigma ** 2)``.

    Args:
        z: flow output; zeroed outside ``mask`` before the prior term.
        log_det_W_list: per-step log-determinant terms (presumably of the
            flow's invertible weight matrices W -- confirm). May be empty, in
            which case the term is dropped. Otherwise entry ``i`` is used for
            every ``i`` in ``range(len(log_s_list))``. The list and its
            tensors are not modified.
        log_s_list: per-step log-scale tensors; each is multiplied by
            ``mask`` and summed. May be empty.
        n_elements: multiplier for the summed log-det term and first factor
            of the normalizer.
        n_dims: second factor of the normalizer.
        mask: tensor broadcastable against ``z`` and each ``log_s``.
        sigma: standard deviation of the Gaussian prior.

    Returns:
        ``(loss, loss_prior)``: the full normalized NLL and its prior-only
        part, both divided by ``n_elements * n_dims``.
    """
    # Start from 0.0 so an empty log_s_list contributes nothing instead of
    # leaving the total unbound.
    log_s_total = 0.0
    for log_s in log_s_list:
        log_s_total = log_s_total + torch.sum(log_s * mask)

    log_det_W_total = 0.0
    if len(log_det_W_list):
        # Accumulate out of place: ``+=`` / ``*=`` on the tensor taken from
        # log_det_W_list[0] would modify the caller's tensor in place.
        for i in range(len(log_s_list)):
            log_det_W_total = log_det_W_total + log_det_W_list[i]
        log_det_W_total = log_det_W_total * n_elements

    z = z * mask
    prior_NLL = torch.sum(z * z) / (2 * sigma * sigma)
    loss = prior_NLL - log_s_total - log_det_W_total
    denom = n_elements * n_dims
    loss = loss / denom
    loss_prior = prior_NLL / denom
    return loss, loss_prior
def compute_regression_loss(x_hat, x, mask, name=False):
    """Compute a masked regression loss for one attribute predictor.

    When ``name == "vpred"`` the loss is binary cross-entropy with
    ``x_hat`` treated as logits. Otherwise it is the summed squared error.
    Both are summed over valid positions only and divided by ``mask.sum()``.

    Args:
        x_hat: predictions (logits for ``"vpred"``).
        x: targets; a channel dim is inserted at axis 1 if ``x`` is 2-D.
        mask: validity mask; a channel dim is inserted at axis 1 if 2-D.
        name: attribute name; selects the loss and names the returned key.

    Returns:
        ``{"loss_<name>": loss}``.
    """
    x = x[:, None] if len(x.shape) == 2 else x  # add channel dim
    mask = mask[:, None] if len(mask.shape) == 2 else mask  # add channel dim
    assert len(x.shape) == len(mask.shape)
    x = x * mask
    x_hat = x_hat * mask
    if name == "vpred":
        # Zeroing logits is not enough for BCE: a 0 logit against a 0 target
        # still costs log(2). Mask the per-element loss so padded positions
        # contribute nothing to the sum.
        loss = F.binary_cross_entropy_with_logits(x_hat, x, reduction="none")
        loss = torch.sum(loss * mask)
    else:
        loss = F.mse_loss(x_hat, x, reduction="sum")
    loss = loss / mask.sum()
    return {"loss_{}".format(name): loss}
class AttributePredictionLoss(torch.nn.Module):
    """Loss for a single attribute predictor (e.g. duration, f0, energy).

    ``forward`` dispatches on the predictor's output mapping. A ``"z"`` entry
    is scored with ``compute_flow_loss``; otherwise an ``"x_hat"`` entry is
    scored with ``compute_regression_loss``.
    """

    def __init__(self, name, model_config, loss_weight, sigma=1.0):
        """Store the loss configuration.

        Args:
            name: attribute name; used in the returned loss keys and passed
                to ``compute_regression_loss``.
            model_config: mapping with a ``"name"`` entry and an
                ``"hparams"`` mapping that may define ``"n_group_size"``.
            loss_weight: weight paired with the main loss term.
            sigma: standard deviation of the Gaussian prior for flow outputs.
        """
        super(AttributePredictionLoss, self).__init__()
        self.name = name
        self.sigma = sigma
        self.model_name = model_config["name"]
        self.loss_weight = loss_weight
        # Defaults to 1 when hparams does not set it.
        self.n_group_size = 1
        if "n_group_size" in model_config["hparams"]:
            self.n_group_size = model_config["hparams"]["n_group_size"]

    def forward(self, model_output, lens):
        """Compute this attribute's loss terms.

        Args:
            model_output: mapping containing either ``"z"`` (together with
                ``"log_det_W_list"`` and ``"log_s_list"``) or ``"x_hat"``
                (together with the target ``"x"``).
            lens: per-item lengths; divided by ``n_group_size`` to build the
                mask.

        Returns:
            dict mapping loss names to ``(loss, weight)`` tuples.

        Raises:
            ValueError: if ``model_output`` has neither ``"z"`` nor
                ``"x_hat"``.
        """
        has_flow = "z" in model_output
        if not has_flow and "x_hat" not in model_output:
            raise ValueError(
                "loss not supported: '{}' model output has neither 'z' nor "
                "'x_hat'".format(self.name)
            )

        mask = get_mask_from_lengths(lens // self.n_group_size)
        mask = mask[:, None].float()

        if has_flow:
            n_elements = lens.sum() // self.n_group_size
            n_dims = model_output["z"].size(1)
            loss, loss_prior = compute_flow_loss(
                model_output["z"],
                model_output["log_det_W_list"],
                model_output["log_s_list"],
                n_elements,
                n_dims,
                mask,
                self.sigma,
            )
            return {
                "loss_{}".format(self.name): (loss, self.loss_weight),
                # Prior-only term, paired with weight 0.0.
                "loss_prior_{}".format(self.name): (loss_prior, 0.0),
            }

        regression_losses = compute_regression_loss(
            model_output["x_hat"], model_output["x"], mask, self.name
        )
        return {k: (v, self.loss_weight) for k, v in regression_losses.items()}
class AttentionCTCLoss(torch.nn.Module):
    """CTC loss computed over an attention map, one batch item at a time.

    Each query step is a CTC time step and each key position is a distinct
    label. The target for item ``b`` is ``1..in_lens[b]``, and label 0 is a
    prepended column filled with ``blank_logprob``.
    """

    def __init__(self, blank_logprob=-1):
        super(AttentionCTCLoss, self).__init__()
        self.log_softmax = torch.nn.LogSoftmax(dim=3)
        # Constant value written into the prepended blank column.
        self.blank_logprob = blank_logprob
        self.CTCLoss = nn.CTCLoss(zero_infinity=True)

    def forward(self, attn_logprob, in_lens, out_lens):
        """Return the CTC cost averaged over the batch.

        Args:
            attn_logprob: 4-D attention scores. The indexing below treats
                axis 2 as queries and axis 3 as keys; presumably the layout
                is (batch, 1, queries, keys) -- TODO confirm axis 1.
            in_lens: number of valid keys per item.
            out_lens: number of valid queries per item.
        """
        batch_size = attn_logprob.shape[0]
        # Prepend the blank column (label 0) along the last (key) axis.
        with_blank = F.pad(
            input=attn_logprob, pad=(1, 0, 0, 0, 0, 0, 0, 0), value=self.blank_logprob
        )
        per_item_costs = []
        for b in range(batch_size):
            n_keys = in_lens[b]
            n_queries = out_lens[b]
            # Every key must be emitted once, in order.
            targets = torch.arange(1, n_keys + 1).unsqueeze(0)
            # Swap the first two axes so queries lead, as CTC expects
            # (time, batch, classes), then crop to the valid region.
            item = with_blank[b].permute(1, 0, 2)
            item = item[:n_queries, :, : n_keys + 1]
            log_probs = self.log_softmax(item.unsqueeze(0)).squeeze(0)
            per_item_costs.append(
                self.CTCLoss(
                    log_probs,
                    targets,
                    input_lengths=out_lens[b : b + 1],
                    target_lengths=in_lens[b : b + 1],
                )
            )
        return sum(per_item_costs, 0.0) / batch_size
class AttentionBinarizationLoss(torch.nn.Module):
    """Penalize low soft-attention values where the hard attention is 1.

    Returns the negative sum of ``log(soft_attention)`` over positions where
    ``hard_attention == 1``, divided by ``hard_attention.sum()``.
    """

    def __init__(self):
        super(AttentionBinarizationLoss, self).__init__()

    def forward(self, hard_attention, soft_attention, eps=1e-12):
        """Compute the binarization loss.

        Args:
            hard_attention: tensor whose entries equal to 1 select positions.
            soft_attention: tensor indexed by the same selection.
            eps: lower bound applied before the log, so a zero soft-attention
                value yields a large finite penalty instead of ``inf``.

        Returns:
            Scalar tensor. It is 0 when ``hard_attention`` selects nothing
            and sums to 0.
        """
        selected = soft_attention[hard_attention == 1]
        log_sum = torch.log(torch.clamp(selected, min=eps)).sum()
        # Guard the empty selection, which would otherwise be 0 / 0 = NaN.
        return -log_sum / torch.clamp(hard_attention.sum(), min=1)
class RADTTSLoss(torch.nn.Module):
    """Combine the mel flow, attention CTC and attribute-predictor losses.

    ``forward`` returns a dict mapping loss names to ``(loss, weight)``
    tuples.
    """

    def __init__(
        self,
        sigma=1.0,
        n_group_size=1,
        dur_model_config=None,
        f0_model_config=None,
        energy_model_config=None,
        vpred_model_config=None,
        loss_weights=None,
    ):
        super(RADTTSLoss, self).__init__()
        self.sigma = sigma
        self.n_group_size = n_group_size
        # NOTE(review): loss_weights defaults to None but is dereferenced
        # unconditionally below, so callers must always pass a mapping.
        self.loss_weights = loss_weights
        self.attn_ctc_loss = AttentionCTCLoss(
            blank_logprob=loss_weights.get("blank_logprob", -1)
        )
        # (model_output key, attribute name, predictor config, weight key)
        predictor_specs = (
            ("duration_model_outputs", "duration", dur_model_config, "dur_loss_weight"),
            ("f0_model_outputs", "f0", f0_model_config, "f0_loss_weight"),
            ("energy_model_outputs", "energy", energy_model_config, "energy_loss_weight"),
            ("vpred_model_outputs", "vpred", vpred_model_config, "vpred_loss_weight"),
        )
        self.loss_fns = {}
        for output_key, attr_name, config, weight_key in predictor_specs:
            if config is None:
                continue
            self.loss_fns[output_key] = AttributePredictionLoss(
                attr_name, config, loss_weights[weight_key]
            )

    def forward(self, model_output, in_lens, out_lens):
        """Collect every loss term for one batch.

        Args:
            model_output: mapping with ``"z_mel"``, ``"log_det_W_list"``,
                ``"log_s_list"`` and ``"attn_logprob"``. Any key that matches
                a configured ``*_model_outputs`` entry is routed to that
                attribute loss.
            in_lens: per-item input lengths. Used for the CTC loss and for
                output keys containing ``"dur"``.
            out_lens: per-item output lengths.

        Returns:
            dict mapping loss names to ``(loss, weight)`` tuples.
        """
        losses = {}
        z_mel = model_output["z_mel"]
        if len(z_mel):
            mel_mask = get_mask_from_lengths(out_lens // self.n_group_size)
            mel_mask = mel_mask[:, None].float()
            mel_loss, mel_prior = compute_flow_loss(
                z_mel,
                model_output["log_det_W_list"],
                model_output["log_s_list"],
                out_lens.sum() // self.n_group_size,
                z_mel.size(1),
                mel_mask,
                self.sigma,
            )
            losses["loss_mel"] = (mel_loss, 1.0)  # (loss, weight)
            losses["loss_prior_mel"] = (mel_prior, 0.0)

        ctc = self.attn_ctc_loss(model_output["attn_logprob"], in_lens, out_lens)
        losses["loss_ctc"] = (ctc, self.loss_weights["ctc_loss_weight"])

        for key, value in model_output.items():
            loss_fn = self.loss_fns.get(key)
            # Skip keys with no configured loss, and absent or empty outputs.
            if loss_fn is None or value is None or len(value) == 0:
                continue
            lens_for_key = in_lens if "dur" in key else out_lens
            losses.update(loss_fn(value, lens_for_key))
        return losses