Spaces:
Running
Running
float32
Browse files- app.py +2 -8
- src/e2_tts_pytorch/e2_tts_crossatt3.py +2 -2
app.py
CHANGED
@@ -201,13 +201,7 @@ def load(device):
|
|
201 |
return e2tts, stft
|
202 |
|
203 |
|
204 |
-
import copy
|
205 |
-
|
206 |
e2tts, stft = load(device)
|
207 |
-
video2roll_net = copy.deepcopy(e2tts.video2roll_net)
|
208 |
-
e2tts = e2tts.half()
|
209 |
-
e2tts.video2roll_net = video2roll_net
|
210 |
-
del video2roll_net
|
211 |
gc.collect()
|
212 |
|
213 |
|
@@ -262,7 +256,7 @@ def run(e2tts, stft, arg1, arg2, arg3, arg4, piano):
|
|
262 |
|
263 |
l = mel_lengths[0]
|
264 |
#cond = mel_spec.repeat(num, 1, 1)
|
265 |
-
cond = torch.randn(num, l, e2tts.num_channels)
|
266 |
duration = torch.tensor([l]*num, dtype=torch.int32)
|
267 |
lens = torch.tensor([l]*num, dtype=torch.int32)
|
268 |
print(datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3], "start")
|
@@ -280,7 +274,7 @@ def run(e2tts, stft, arg1, arg2, arg3, arg4, piano):
|
|
280 |
|
281 |
outputs = outputs.reshape(1, -1, outputs.shape[-1])
|
282 |
audio_final = e2tts.vocos.decode(outputs.transpose(-1,-2))
|
283 |
-
audio_final = audio_final.detach().cpu()
|
284 |
|
285 |
torchaudio.save(audio_path, audio_final, sample_rate = e2tts.sampling_rate)
|
286 |
|
|
|
201 |
return e2tts, stft
|
202 |
|
203 |
|
|
|
|
|
204 |
e2tts, stft = load(device)
|
|
|
|
|
|
|
|
|
205 |
gc.collect()
|
206 |
|
207 |
|
|
|
256 |
|
257 |
l = mel_lengths[0]
|
258 |
#cond = mel_spec.repeat(num, 1, 1)
|
259 |
+
cond = torch.randn(num, l, e2tts.num_channels)
|
260 |
duration = torch.tensor([l]*num, dtype=torch.int32)
|
261 |
lens = torch.tensor([l]*num, dtype=torch.int32)
|
262 |
print(datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3], "start")
|
|
|
274 |
|
275 |
outputs = outputs.reshape(1, -1, outputs.shape[-1])
|
276 |
audio_final = e2tts.vocos.decode(outputs.transpose(-1,-2))
|
277 |
+
audio_final = audio_final.detach().cpu()
|
278 |
|
279 |
torchaudio.save(audio_path, audio_final, sample_rate = e2tts.sampling_rate)
|
280 |
|
src/e2_tts_pytorch/e2_tts_crossatt3.py
CHANGED
@@ -2162,12 +2162,12 @@ class E2TTS(Module):
|
|
2162 |
batch, cond_seq_len, device = *cond.shape[:2], cond.device
|
2163 |
|
2164 |
if frames is None:
|
2165 |
-
frames_embed = torch.zeros(batch, cond_seq_len, NOTES, device=device)
|
2166 |
else:
|
2167 |
#### sampling settings
|
2168 |
train_video_encoder = True
|
2169 |
if train_video_encoder:
|
2170 |
-
frames_embed = self.encode_frames(frames, cond_seq_len)
|
2171 |
else:
|
2172 |
frames_embed = midis
|
2173 |
if frames_embed.shape[1] < cond_seq_len:
|
|
|
2162 |
batch, cond_seq_len, device = *cond.shape[:2], cond.device
|
2163 |
|
2164 |
if frames is None:
|
2165 |
+
frames_embed = torch.zeros(batch, cond_seq_len, NOTES, device=device)
|
2166 |
else:
|
2167 |
#### sampling settings
|
2168 |
train_video_encoder = True
|
2169 |
if train_video_encoder:
|
2170 |
+
frames_embed = self.encode_frames(frames, cond_seq_len)
|
2171 |
else:
|
2172 |
frames_embed = midis
|
2173 |
if frames_embed.shape[1] < cond_seq_len:
|