Update accelerate_train_second.py
Browse files
accelerate_train_second.py
CHANGED
@@ -424,7 +424,7 @@ def main(config_path):
                                   features=ref, # reference from the same speaker as the embedding
                                   embedding_mask_proba=0.1,
                                   num_steps=num_steps).squeeze(1)
-                loss_diff = model.diffusion(s_trg.unsqueeze(1), embedding=bert_dur, features=ref).mean()
+                loss_diff = model.diffusion.module.diffusion(s_trg.unsqueeze(1), embedding=bert_dur, features=ref).mean() # EDM loss
                 loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
             else:
                 s_preds = sampler(noise=torch.randn_like(s_trg).unsqueeze(1).to(device),
@@ -515,6 +515,12 @@ def main(config_path):
             optimizer.zero_grad()
             d_loss = dl(wav.detach(), y_rec.detach()).mean()
             accelerator.backward(d_loss)
+
+            # the biggest culprit of causing NaNs in DDP training!
+            accelerator.clip_grad_norm_(model.msd.parameters(), max_norm=2.0)
+            accelerator.clip_grad_norm_(model.mpd.parameters(), max_norm=2.0)
+
+
             optimizer.step('msd')
             optimizer.step('mpd')
         else: