ing0 committed on
Commit 69ef734 · Parent(s): 45a38e5
Files changed (3):
  1. app.py +19 -3
  2. diffrhythm/infer/infer.py +3 -2
  3. diffrhythm/model/cfm.py +4 -1
app.py CHANGED
@@ -31,7 +31,7 @@ cfm, tokenizer, muq, vae = prepare_model(device)
 cfm = torch.compile(cfm)
 
 @spaces.GPU(duration=20)
-def infer_music(lrc, ref_audio_path, text_prompt, current_prompt_type, seed=42, randomize_seed=False, steps=32, file_type='wav', max_frames=2048, device='cuda'):
+def infer_music(lrc, ref_audio_path, text_prompt, current_prompt_type, seed=42, randomize_seed=False, steps=32, cfg_strength=4.0, file_type='wav', odeint_method='euler', max_frames=2048, device='cuda'):
 
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
@@ -56,10 +56,12 @@ def infer_music(lrc, ref_audio_path, text_prompt, current_prompt_type, seed=42,
         style_prompt=style_prompt,
         negative_style_prompt=negative_style_prompt,
         steps=steps,
+        cfg_strength=cfg_strength,
         sway_sampling_coef=sway_sampling_coef,
         start_time=start_time,
         file_type=file_type,
-        vocal_flag=vocal_flag
+        vocal_flag=vocal_flag,
+        odeint_method=odeint_method,
     )
     return generated_song
@@ -223,6 +225,10 @@ with gr.Blocks(css=css) as demo:
     4. **Supported Languages**
        - **Chinese and English**
        - More languages coming soon
+
+    5. **Others**
+       - If the generated audio loads slowly, select mp3 as the Output Format in Advanced Settings.
+
     """)
 
     lyrics_btn = gr.Button("Generate", variant="primary")
@@ -246,6 +252,16 @@ with gr.Blocks(css=css) as demo:
                 interactive=True,
                 elem_id="step_slider"
             )
+            cfg_strength = gr.Slider(
+                minimum=1,
+                maximum=10,
+                value=4.0,
+                step=0.5,
+                label="CFG Strength",
+                interactive=True,
+                elem_id="cfg_slider"
+            )
+            odeint_method = gr.Radio(["euler", "midpoint", "rk4", "implicit_adams"], label="ODE Solver", value="euler")
             file_type = gr.Dropdown(["wav", "mp3", "ogg"], label="Output Format", value="wav")
 
@@ -387,7 +403,7 @@ with gr.Blocks(css=css) as demo:
 
     lyrics_btn.click(
         fn=infer_music,
-        inputs=[lrc, audio_prompt, text_prompt, current_prompt_type, seed, randomize_seed, steps, file_type],
+        inputs=[lrc, audio_prompt, text_prompt, current_prompt_type, seed, randomize_seed, steps, cfg_strength, file_type, odeint_method],
         outputs=audio_output
     )
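A note on the `app.py` wiring: Gradio binds the `inputs` list to the callback's parameters by position, so the new `cfg_strength` and `odeint_method` components must appear in `inputs=[...]` in the same order as in `infer_music`'s signature, which the hunks above keep consistent (`steps, cfg_strength, file_type, odeint_method`). A minimal, self-contained sketch of the same pattern, with hypothetical component and callback names, assuming only stock Gradio behavior:

```python
import gradio as gr

def generate(steps=32, cfg_strength=4.0, file_type="wav", odeint_method="euler"):
    # Placeholder callback: the components below are bound to these
    # parameters positionally, exactly as infer_music is above.
    return f"steps={steps}, cfg={cfg_strength}, fmt={file_type}, solver={odeint_method}"

with gr.Blocks() as demo:
    steps = gr.Slider(minimum=10, maximum=100, value=32, step=1, label="Steps")
    cfg = gr.Slider(minimum=1, maximum=10, value=4.0, step=0.5, label="CFG Strength")
    fmt = gr.Dropdown(["wav", "mp3", "ogg"], value="wav", label="Output Format")
    solver = gr.Radio(["euler", "midpoint", "rk4", "implicit_adams"], value="euler", label="ODE Solver")
    out = gr.Textbox(label="Result")
    # The order of `inputs` must track the callback signature; a mismatch
    # silently feeds values to the wrong parameters.
    gr.Button("Generate").click(fn=generate, inputs=[steps, cfg, fmt, solver], outputs=out)

if __name__ == "__main__":
    demo.launch()
```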
diffrhythm/infer/infer.py CHANGED
@@ -74,7 +74,7 @@ def decode_audio(latents, vae_model, chunked=False, overlap=32, chunk_size=128):
             y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
     return y_final
 
-def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative_style_prompt, steps, sway_sampling_coef, start_time, file_type, vocal_flag):
+def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative_style_prompt, steps, cfg_strength, sway_sampling_coef, start_time, file_type, vocal_flag, odeint_method):
 
     with torch.inference_mode():
         generated, _ = cfm_model.sample(
@@ -84,10 +84,11 @@ def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative
             style_prompt=style_prompt,
             negative_style_prompt=negative_style_prompt,
             steps=steps,
-            cfg_strength=4.0,
+            cfg_strength=cfg_strength,
             sway_sampling_coef=sway_sampling_coef,
             start_time=start_time,
             vocal_flag=vocal_flag,
+            odeint_method=odeint_method,
         )
 
     generated = generated.to(torch.float32)
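On `cfg_strength`: this hunk promotes a hard-coded guidance scale of 4.0 to a caller-supplied argument that `cfm_model.sample` uses for classifier-free guidance. The diff does not show how `CFM` blends the conditional and unconditional predictions; as a reference point, here is a minimal sketch of one standard CFG formulation (the helper name is hypothetical):

```python
import torch

def cfg_blend(cond_pred: torch.Tensor,
              uncond_pred: torch.Tensor,
              cfg_strength: float = 4.0) -> torch.Tensor:
    # Classic classifier-free guidance: extrapolate away from the
    # unconditional (here, negative-style-prompt) prediction.
    # cfg_strength = 1.0 recovers the purely conditional prediction;
    # larger values follow the conditioning harder, at some cost to
    # naturalness and diversity.
    return uncond_pred + cfg_strength * (cond_pred - uncond_pred)
```

Some codebases instead use the variant `pred + w * (pred - uncond_pred)`, where `w = 0` is unguided; which convention `CFM.sample` follows determines how the 1-10 slider added in `app.py` should be read.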
diffrhythm/model/cfm.py CHANGED
@@ -114,9 +114,12 @@ class CFM(nn.Module):
         start_time=None,
         latent_pred_start_frame=0,
         latent_pred_end_frame=2048,
-        vocal_flag=False
+        vocal_flag=False,
+        odeint_method="euler"
     ):
         self.eval()
+
+        self.odeint_kwargs = dict(method=odeint_method)
 
         if next(self.parameters()).dtype == torch.float16:
             cond = cond.half()
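The new `self.odeint_kwargs = dict(method=odeint_method)` line suggests the kwargs are forwarded to an `odeint` call later in `sample` (outside this hunk). Assuming the solver is `torchdiffeq`, whose method names match the radio options added in `app.py`, then `euler`, `midpoint`, and `rk4` are fixed-step integrators of increasing per-step cost and order, and `implicit_adams` is a multi-step scheme. A toy, self-contained example of how the same kwargs dict steers the integration:

```python
import torch
from torchdiffeq import odeint

def f(t, y):
    # Toy dynamics dy/dt = -y, with exact solution y(t) = y0 * exp(-t).
    return -y

y0 = torch.tensor([1.0])
t = torch.linspace(0.0, 1.0, 33)  # 32 integration steps, like the default step count

for method in ["euler", "midpoint", "rk4", "implicit_adams"]:
    # Same dict(method=...) shape as self.odeint_kwargs above.
    ys = odeint(f, y0, t, **dict(method=method))
    err = (ys[-1] - torch.exp(torch.tensor(-1.0))).abs().item()
    print(f"{method:>14}: y(1)={ys[-1].item():.6f}  abs err={err:.2e}")
```

Higher-order solvers buy accuracy per step, so with a fixed step count the trade-off exposed by the new ODE Solver control is output quality versus roughly 2x (`midpoint`) or 4x (`rk4`) the function evaluations of `euler`.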