Antoni Bigata committed on
Commit 4fd1a69 · 1 Parent(s): fc0dc6f

requirements

Files changed (2)
  1. WavLM.py +1 -1
  2. sgm/models/diffusion.py +5 -5
WavLM.py CHANGED
@@ -48,7 +48,7 @@ class WavLM_wrapper(nn.Module):
         )
         if not os.path.exists(model_path):
             self.download_model(model_path, model_size)
-        checkpoint = torch.load(model_path)
+        checkpoint = torch.load(model_path, weights_only=False)
         cfg = WavLMConfig(checkpoint["cfg"])
         self.cfg = cfg
         self.model = WavLM(cfg)
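Context for the change (not part of the commit itself): since PyTorch 2.6, torch.load defaults to weights_only=True, which restricts unpickling to tensors and a small allowlist of types, so checkpoints that bundle arbitrary Python objects next to the weights stop loading. Passing weights_only=False restores the old full-pickle behaviour and should only be used on trusted checkpoint files. A minimal sketch of the failure mode, using a hypothetical demo.pt rather than this repo's checkpoints:

# Minimal sketch (hypothetical file "demo.pt", not from this repo).
import argparse
import torch

ckpt = {"cfg": argparse.Namespace(hidden_dim=768), "model": {"w": torch.zeros(2)}}
torch.save(ckpt, "demo.pt")

try:
    # On PyTorch >= 2.6 the default is weights_only=True, which rejects
    # pickled objects such as argparse.Namespace.
    torch.load("demo.pt")
except Exception as err:
    print("restricted load failed:", type(err).__name__)

# Full pickle load; only acceptable here because we created demo.pt ourselves.
checkpoint = torch.load("demo.pt", weights_only=False)
print(checkpoint["cfg"].hidden_dim)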
sgm/models/diffusion.py CHANGED
@@ -119,7 +119,7 @@ class DiffusionEngine(pl.LightningModule):
             pattern_to_remove=pattern_to_remove,
         )
         if separate_unet_ckpt is not None:
-            sd = torch.load(separate_unet_ckpt)["state_dict"]
+            sd = torch.load(separate_unet_ckpt, weights_only=False)["state_dict"]
             if remove_keys_from_unet_weights is not None:
                 for k in list(sd.keys()):
                     for remove_key in remove_keys_from_unet_weights:
@@ -190,7 +190,7 @@ class DiffusionEngine(pl.LightningModule):
 
     def load_bad_model_weights(self, path: str) -> None:
         print(f"Restoring bad model from {path}")
-        state_dict = torch.load(path, map_location="cpu")
+        state_dict = torch.load(path, map_location="cpu", weights_only=False)
         new_dict = {}
         for k, v in state_dict["module"].items():
             if "learned_mask" in k:
@@ -221,13 +221,13 @@ class DiffusionEngine(pl.LightningModule):
     ) -> None:
         print(f"Restoring from {path}")
         if path.endswith("ckpt"):
-            sd = torch.load(path, map_location="cpu")["state_dict"]
+            sd = torch.load(path, map_location="cpu", weights_only=False)["state_dict"]
         elif path.endswith("pt"):
-            sd = torch.load(path, map_location="cpu")["module"]
+            sd = torch.load(path, map_location="cpu", weights_only=False)["module"]
             # Remove leading _forward_module from keys
             sd = {k.replace("_forward_module.", ""): v for k, v in sd.items()}
         elif path.endswith("bin"):
-            sd = torch.load(path, map_location="cpu")
+            sd = torch.load(path, map_location="cpu", weights_only=False)
             # Remove leading _forward_module from keys
             sd = {k.replace("_forward_module.", ""): v for k, v in sd.items()}
         elif path.endswith("safetensors"):
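For completeness, an alternative the commit does not use: when the classes stored in a checkpoint are known and trusted, they can be allowlisted with torch.serialization.add_safe_globals so the restricted loader still accepts them instead of disabling the check entirely. A sketch reusing the hypothetical demo.pt from the earlier example:

# Sketch of the allowlisting alternative (assumes the demo.pt created above
# and PyTorch >= 2.4, where add_safe_globals is available).
import argparse
import torch

torch.serialization.add_safe_globals([argparse.Namespace])
checkpoint = torch.load("demo.pt")  # restricted load, Namespace now allowed
print(checkpoint["cfg"].hidden_dim)

The safetensors branch above is untouched by this commit because safetensors files are not pickle-based and are unaffected by the weights_only default.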