Commit: cfg

This commit exposes the classifier-free guidance strength (cfg_strength) and the ODE solver method (odeint_method) as user-facing options, threading both from the Gradio UI through inference() into CFM.sample.

Files changed:
- app.py (+19 -3)
- diffrhythm/infer/infer.py (+3 -2)
- diffrhythm/model/cfm.py (+4 -1)
app.py CHANGED

```diff
@@ -31,7 +31,7 @@ cfm, tokenizer, muq, vae = prepare_model(device)
 cfm = torch.compile(cfm)
 
 @spaces.GPU(duration=20)
-def infer_music(lrc, ref_audio_path, text_prompt, current_prompt_type, seed=42, randomize_seed=False, steps=32, file_type='wav', max_frames=2048, device='cuda'):
+def infer_music(lrc, ref_audio_path, text_prompt, current_prompt_type, seed=42, randomize_seed=False, steps=32, cfg_strength=4.0, file_type='wav', odeint_method='euler', max_frames=2048, device='cuda'):
 
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
@@ -56,10 +56,12 @@ def infer_music(lrc, ref_audio_path, text_prompt, current_prompt_type, seed=42,
         style_prompt=style_prompt,
         negative_style_prompt=negative_style_prompt,
         steps=steps,
+        cfg_strength=cfg_strength,
         sway_sampling_coef=sway_sampling_coef,
         start_time=start_time,
         file_type=file_type,
-        vocal_flag=vocal_flag
+        vocal_flag=vocal_flag,
+        odeint_method=odeint_method,
     )
     return generated_song
 
@@ -223,6 +225,10 @@ with gr.Blocks(css=css) as demo:
     4. **Supported Languages**
        - **Chinese and English**
        - More languages coming soon
+
+    5. **Others**
+       - If the generated audio loads slowly, select mp3 as the Output Format in Advanced Settings.
+
     """)
 
     lyrics_btn = gr.Button("Generate", variant="primary")
@@ -246,6 +252,16 @@ with gr.Blocks(css=css) as demo:
                 interactive=True,
                 elem_id="step_slider"
             )
+            cfg_strength = gr.Slider(
+                minimum=1,
+                maximum=10,
+                value=4.0,
+                step=0.5,
+                label="CFG Strength",
+                interactive=True,
+                elem_id="step_slider"
+            )
+            odeint_method = gr.Radio(["euler", "midpoint", "rk4", "implicit_adams"], label="ODE Solver", value="euler")
             file_type = gr.Dropdown(["wav", "mp3", "ogg"], label="Output Format", value="wav")
 
 
```
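For reference, here is a minimal, self-contained sketch of the two new controls wired to a stub callback; the component arguments mirror the hunk above, while `fake_infer` and the layout are illustrative stand-ins. One caveat visible in the diff: the new slider reuses `elem_id="step_slider"` from the Steps slider, so a unique id (e.g. `"cfg_slider"`) would avoid duplicate element ids on the page.

```python
# Standalone sketch of the new CFG Strength slider and ODE Solver radio.
# `fake_infer` is a hypothetical stub, not part of the repo.
import gradio as gr

def fake_infer(steps, cfg_strength, odeint_method):
    return f"steps={steps}, cfg={cfg_strength}, solver={odeint_method}"

with gr.Blocks() as demo:
    steps = gr.Slider(minimum=10, maximum=100, value=32, step=1, label="Steps", interactive=True)
    cfg_strength = gr.Slider(minimum=1, maximum=10, value=4.0, step=0.5, label="CFG Strength", interactive=True)
    odeint_method = gr.Radio(["euler", "midpoint", "rk4", "implicit_adams"], label="ODE Solver", value="euler")
    out = gr.Textbox(label="Result")
    gr.Button("Generate").click(fn=fake_infer, inputs=[steps, cfg_strength, odeint_method], outputs=out)

if __name__ == "__main__":
    demo.launch()
```

The final hunk wires the real components into the existing click handler: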
```diff
@@ -387,7 +403,7 @@ with gr.Blocks(css=css) as demo:
 
     lyrics_btn.click(
         fn=infer_music,
-        inputs=[lrc, audio_prompt, text_prompt, current_prompt_type, seed, randomize_seed, steps, file_type],
+        inputs=[lrc, audio_prompt, text_prompt, current_prompt_type, seed, randomize_seed, steps, cfg_strength, file_type, odeint_method],
         outputs=audio_output
     )
 
```
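What the new slider actually controls: classifier-free guidance samples the model both with and without the style conditioning and extrapolates between the two predictions. The blend below is the textbook formulation, stated as an assumption; the exact combination inside `CFM.sample` is not shown in this diff.

```python
# Textbook classifier-free guidance blend (assumed formulation; the exact
# combination inside CFM.sample may differ).
import torch

def cfg_blend(cond_pred: torch.Tensor, uncond_pred: torch.Tensor, cfg_strength: float) -> torch.Tensor:
    # 0 -> purely unconditional, 1 -> purely conditional, >1 -> extrapolate
    # past the conditional prediction: stronger prompt adherence, though high
    # values can introduce artifacts.
    return uncond_pred + cfg_strength * (cond_pred - uncond_pred)

print(cfg_blend(torch.tensor([1.0]), torch.tensor([0.0]), 4.0))  # tensor([4.]) at the new UI default
```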
diffrhythm/infer/infer.py CHANGED

```diff
@@ -74,7 +74,7 @@ def decode_audio(latents, vae_model, chunked=False, overlap=32, chunk_size=128):
     y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
     return y_final
 
-def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative_style_prompt, steps, sway_sampling_coef, start_time, file_type, vocal_flag):
+def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative_style_prompt, steps, cfg_strength, sway_sampling_coef, start_time, file_type, vocal_flag, odeint_method):
 
     with torch.inference_mode():
         generated, _ = cfm_model.sample(
@@ -84,10 +84,11 @@ def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative
             style_prompt=style_prompt,
             negative_style_prompt=negative_style_prompt,
             steps=steps,
-            cfg_strength=4.0,
+            cfg_strength=cfg_strength,
             sway_sampling_coef=sway_sampling_coef,
             start_time=start_time,
             vocal_flag=vocal_flag,
+            odeint_method=odeint_method,
         )
 
     generated = generated.to(torch.float32)
```
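Aside on the unchanged `decode_audio` context lines: the signature (`chunked=False, overlap=32, chunk_size=128`) and the `y_final[...] = y_chunk[...]` assignment indicate overlapped chunked decoding, where each window is decoded with extra context frames and only its interior is kept to avoid seam artifacts. Below is a simplified sketch under the assumption of a decoder that is 1:1 in time (the real VAE upsamples, so its index arithmetic differs).

```python
# Overlapped chunked decoding, simplified: decode (B, C, T) latents window
# by window and copy only each window's interior into the output.
import torch

def chunked_decode(latents, decode, chunk_size=128, overlap=32):
    B, C, T = latents.shape
    out = torch.zeros_like(latents)
    step = chunk_size - 2 * overlap          # interior frames kept per window
    start = 0
    while start < T:
        c0 = max(start - overlap, 0)         # window with `overlap` context on each side
        c1 = min(start + step + overlap, T)
        y = decode(latents[:, :, c0:c1])
        k0 = start - c0                      # skip the left context in the window's output
        k1 = k0 + min(step, T - start)
        out[:, :, start:start + (k1 - k0)] = y[:, :, k0:k1]
        start += step
    return out

x = torch.randn(1, 4, 500)
assert torch.allclose(chunked_decode(x, decode=lambda z: z), x)  # identity decoder round-trips
```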
diffrhythm/model/cfm.py CHANGED

```diff
@@ -114,9 +114,12 @@ class CFM(nn.Module):
         start_time=None,
         latent_pred_start_frame=0,
         latent_pred_end_frame=2048,
-        vocal_flag=False
+        vocal_flag=False,
+        odeint_method="euler"
     ):
         self.eval()
+
+        self.odeint_kwargs = dict(method=odeint_method)
 
         if next(self.parameters()).dtype == torch.float16:
             cond = cond.half()
```
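The stored `self.odeint_kwargs` presumably feeds torchdiffeq's `odeint`, whose `method` keyword selects the solver; the four options in the new radio match torchdiffeq's fixed-grid solver names. A self-contained toy run:

```python
# Toy demonstration of torchdiffeq's `method` kwarg: solve dy/dt = -y on
# [0, 1] with each solver offered by the new ODE Solver radio.
import torch
from torchdiffeq import odeint

def f(t, y):
    return -y

y0 = torch.tensor([1.0])
t = torch.linspace(0.0, 1.0, 33)  # 32 steps, matching the app's default step count

for method in ["euler", "midpoint", "rk4", "implicit_adams"]:
    y = odeint(f, y0, t, method=method)
    print(f"{method:>14}: {y[-1].item():.6f}")  # exact answer: exp(-1) ~ 0.367879
```

At a fixed step count these solvers trade cost for accuracy: euler makes one model evaluation per step, midpoint two, and rk4 four, so the solver choice changes generation time as well as output quality.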