Spaces:
Running
on
Zero
Running
on
Zero
Antoni Bigata
commited on
Commit
·
4fd1a69
1
Parent(s):
fc0dc6f
requirements
Browse files- WavLM.py +1 -1
- sgm/models/diffusion.py +5 -5
WavLM.py
CHANGED
@@ -48,7 +48,7 @@ class WavLM_wrapper(nn.Module):
|
|
48 |
)
|
49 |
if not os.path.exists(model_path):
|
50 |
self.download_model(model_path, model_size)
|
51 |
-
checkpoint = torch.load(model_path)
|
52 |
cfg = WavLMConfig(checkpoint["cfg"])
|
53 |
self.cfg = cfg
|
54 |
self.model = WavLM(cfg)
|
|
|
48 |
)
|
49 |
if not os.path.exists(model_path):
|
50 |
self.download_model(model_path, model_size)
|
51 |
+
checkpoint = torch.load(model_path, weights_only=False)
|
52 |
cfg = WavLMConfig(checkpoint["cfg"])
|
53 |
self.cfg = cfg
|
54 |
self.model = WavLM(cfg)
|
sgm/models/diffusion.py
CHANGED
@@ -119,7 +119,7 @@ class DiffusionEngine(pl.LightningModule):
|
|
119 |
pattern_to_remove=pattern_to_remove,
|
120 |
)
|
121 |
if separate_unet_ckpt is not None:
|
122 |
-
sd = torch.load(separate_unet_ckpt)["state_dict"]
|
123 |
if remove_keys_from_unet_weights is not None:
|
124 |
for k in list(sd.keys()):
|
125 |
for remove_key in remove_keys_from_unet_weights:
|
@@ -190,7 +190,7 @@ class DiffusionEngine(pl.LightningModule):
|
|
190 |
|
191 |
def load_bad_model_weights(self, path: str) -> None:
|
192 |
print(f"Restoring bad model from {path}")
|
193 |
-
state_dict = torch.load(path, map_location="cpu")
|
194 |
new_dict = {}
|
195 |
for k, v in state_dict["module"].items():
|
196 |
if "learned_mask" in k:
|
@@ -221,13 +221,13 @@ class DiffusionEngine(pl.LightningModule):
|
|
221 |
) -> None:
|
222 |
print(f"Restoring from {path}")
|
223 |
if path.endswith("ckpt"):
|
224 |
-
sd = torch.load(path, map_location="cpu")["state_dict"]
|
225 |
elif path.endswith("pt"):
|
226 |
-
sd = torch.load(path, map_location="cpu")["module"]
|
227 |
# Remove leading _forward_module from keys
|
228 |
sd = {k.replace("_forward_module.", ""): v for k, v in sd.items()}
|
229 |
elif path.endswith("bin"):
|
230 |
-
sd = torch.load(path, map_location="cpu")
|
231 |
# Remove leading _forward_module from keys
|
232 |
sd = {k.replace("_forward_module.", ""): v for k, v in sd.items()}
|
233 |
elif path.endswith("safetensors"):
|
|
|
119 |
pattern_to_remove=pattern_to_remove,
|
120 |
)
|
121 |
if separate_unet_ckpt is not None:
|
122 |
+
sd = torch.load(separate_unet_ckpt, weights_only=False)["state_dict"]
|
123 |
if remove_keys_from_unet_weights is not None:
|
124 |
for k in list(sd.keys()):
|
125 |
for remove_key in remove_keys_from_unet_weights:
|
|
|
190 |
|
191 |
def load_bad_model_weights(self, path: str) -> None:
|
192 |
print(f"Restoring bad model from {path}")
|
193 |
+
state_dict = torch.load(path, map_location="cpu", weights_only=False)
|
194 |
new_dict = {}
|
195 |
for k, v in state_dict["module"].items():
|
196 |
if "learned_mask" in k:
|
|
|
221 |
) -> None:
|
222 |
print(f"Restoring from {path}")
|
223 |
if path.endswith("ckpt"):
|
224 |
+
sd = torch.load(path, map_location="cpu", weights_only=False)["state_dict"]
|
225 |
elif path.endswith("pt"):
|
226 |
+
sd = torch.load(path, map_location="cpu", weights_only=False)["module"]
|
227 |
# Remove leading _forward_module from keys
|
228 |
sd = {k.replace("_forward_module.", ""): v for k, v in sd.items()}
|
229 |
elif path.endswith("bin"):
|
230 |
+
sd = torch.load(path, map_location="cpu", weights_only=False)
|
231 |
# Remove leading _forward_module from keys
|
232 |
sd = {k.replace("_forward_module.", ""): v for k, v in sd.items()}
|
233 |
elif path.endswith("safetensors"):
|