lshzhm committed on
Commit
7d26b43
·
1 Parent(s): 53593c1
Files changed (2) hide show
  1. app.py +2 -8
  2. src/e2_tts_pytorch/e2_tts_crossatt3.py +2 -2
app.py CHANGED
@@ -201,13 +201,7 @@ def load(device):
201
  return e2tts, stft
202
 
203
 
204
- import copy
205
-
206
  e2tts, stft = load(device)
207
- video2roll_net = copy.deepcopy(e2tts.video2roll_net)
208
- e2tts = e2tts.half()
209
- e2tts.video2roll_net = video2roll_net
210
- del video2roll_net
211
  gc.collect()
212
 
213
 
@@ -262,7 +256,7 @@ def run(e2tts, stft, arg1, arg2, arg3, arg4, piano):
262
 
263
  l = mel_lengths[0]
264
  #cond = mel_spec.repeat(num, 1, 1)
265
- cond = torch.randn(num, l, e2tts.num_channels).half()
266
  duration = torch.tensor([l]*num, dtype=torch.int32)
267
  lens = torch.tensor([l]*num, dtype=torch.int32)
268
  print(datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3], "start")
@@ -280,7 +274,7 @@ def run(e2tts, stft, arg1, arg2, arg3, arg4, piano):
280
 
281
  outputs = outputs.reshape(1, -1, outputs.shape[-1])
282
  audio_final = e2tts.vocos.decode(outputs.transpose(-1,-2))
283
- audio_final = audio_final.detach().cpu().float()
284
 
285
  torchaudio.save(audio_path, audio_final, sample_rate = e2tts.sampling_rate)
286
 
 
201
  return e2tts, stft
202
 
203
 
 
 
204
  e2tts, stft = load(device)
 
 
 
 
205
  gc.collect()
206
 
207
 
 
256
 
257
  l = mel_lengths[0]
258
  #cond = mel_spec.repeat(num, 1, 1)
259
+ cond = torch.randn(num, l, e2tts.num_channels)
260
  duration = torch.tensor([l]*num, dtype=torch.int32)
261
  lens = torch.tensor([l]*num, dtype=torch.int32)
262
  print(datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3], "start")
 
274
 
275
  outputs = outputs.reshape(1, -1, outputs.shape[-1])
276
  audio_final = e2tts.vocos.decode(outputs.transpose(-1,-2))
277
+ audio_final = audio_final.detach().cpu()
278
 
279
  torchaudio.save(audio_path, audio_final, sample_rate = e2tts.sampling_rate)
280
 
src/e2_tts_pytorch/e2_tts_crossatt3.py CHANGED
@@ -2162,12 +2162,12 @@ class E2TTS(Module):
2162
  batch, cond_seq_len, device = *cond.shape[:2], cond.device
2163
 
2164
  if frames is None:
2165
- frames_embed = torch.zeros(batch, cond_seq_len, NOTES, device=device).half()
2166
  else:
2167
  #### sampling settings
2168
  train_video_encoder = True
2169
  if train_video_encoder:
2170
- frames_embed = self.encode_frames(frames, cond_seq_len).half()
2171
  else:
2172
  frames_embed = midis
2173
  if frames_embed.shape[1] < cond_seq_len:
 
2162
  batch, cond_seq_len, device = *cond.shape[:2], cond.device
2163
 
2164
  if frames is None:
2165
+ frames_embed = torch.zeros(batch, cond_seq_len, NOTES, device=device)
2166
  else:
2167
  #### sampling settings
2168
  train_video_encoder = True
2169
  if train_video_encoder:
2170
+ frames_embed = self.encode_frames(frames, cond_seq_len)
2171
  else:
2172
  frames_embed = midis
2173
  if frames_embed.shape[1] < cond_seq_len: