Respair committed on
Commit
2b7074e
·
verified ·
1 Parent(s): c8d3855

Update accelerate_train_second.py

Browse files
Files changed (1) hide show
  1. accelerate_train_second.py +7 -1
accelerate_train_second.py CHANGED
@@ -424,7 +424,7 @@ def main(config_path):
424
  features=ref, # reference from the same speaker as the embedding
425
  embedding_mask_proba=0.1,
426
  num_steps=num_steps).squeeze(1)
427
- loss_diff = model.diffusion(s_trg.unsqueeze(1), embedding=bert_dur, features=ref).mean() # EDM loss
428
  loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
429
  else:
430
  s_preds = sampler(noise=torch.randn_like(s_trg).unsqueeze(1).to(device),
@@ -515,6 +515,12 @@ def main(config_path):
515
  optimizer.zero_grad()
516
  d_loss = dl(wav.detach(), y_rec.detach()).mean()
517
  accelerator.backward(d_loss)
 
 
 
 
 
 
518
  optimizer.step('msd')
519
  optimizer.step('mpd')
520
  else:
 
424
  features=ref, # reference from the same speaker as the embedding
425
  embedding_mask_proba=0.1,
426
  num_steps=num_steps).squeeze(1)
427
+ loss_diff = model.diffusion.module.diffusion(s_trg.unsqueeze(1), embedding=bert_dur, features=ref).mean() # EDM loss
428
  loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
429
  else:
430
  s_preds = sampler(noise=torch.randn_like(s_trg).unsqueeze(1).to(device),
 
515
  optimizer.zero_grad()
516
  d_loss = dl(wav.detach(), y_rec.detach()).mean()
517
  accelerator.backward(d_loss)
518
+
519
+ # the biggest culprit of causing NaNs in DDP training!
520
+ accelerator.clip_grad_norm_(model.msd.parameters(), max_norm=2.0)
521
+ accelerator.clip_grad_norm_(model.mpd.parameters(), max_norm=2.0)
522
+
523
+
524
  optimizer.step('msd')
525
  optimizer.step('mpd')
526
  else: