hannibaking committed on
Commit
b12807b
·
verified ·
1 Parent(s): d7d3e4f

some vertical video here

Browse files
Files changed (1) hide show
  1. app.py +28 -7
app.py CHANGED
@@ -18,6 +18,10 @@ MODEL_REPO = "rain1011/pyramid-flow-sd3"
18
  MODEL_VARIANT = "diffusion_transformer_768p"
19
  MODEL_DTYPE = "bf16"
20
 
 
 
 
 
21
  def center_crop(image, target_width, target_height):
22
  width, height = image.size
23
  aspect_ratio_target = target_width / target_height
@@ -62,13 +66,24 @@ model = load_model()
62
 
63
  # Text-to-video generation function
64
  @spaces.GPU(duration=140)
65
- def generate_video(prompt, image=None, duration=3, guidance_scale=9, video_guidance_scale=5, frames_per_second=8, progress=gr.Progress(track_tqdm=True)):
 
 
 
 
 
 
 
 
66
  multiplier = 1.2 if is_canonical else 3.0
67
  temp = int(duration * multiplier) + 1
68
  torch_dtype = torch.bfloat16 if MODEL_DTYPE == "bf16" else torch.float32
69
- if(image):
70
- cropped_image = center_crop(image, 1280, 768)
71
- resized_image = cropped_image.resize((1280, 768))
 
 
 
72
  with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
73
  frames = model.generate_i2v(
74
  prompt=prompt,
@@ -86,14 +101,15 @@ def generate_video(prompt, image=None, duration=3, guidance_scale=9, video_guida
86
  prompt=prompt,
87
  num_inference_steps=[20, 20, 20],
88
  video_num_inference_steps=[10, 10, 10],
89
- height=768,
90
- width=1280,
91
  temp=temp,
92
  guidance_scale=guidance_scale,
93
  video_guidance_scale=video_guidance_scale,
94
  output_type="pil",
95
  save_memory=True,
96
  )
 
97
  output_path = f"{str(uuid.uuid4())}_output_video.mp4"
98
  export_to_video(frames, output_path, fps=frames_per_second)
99
  return output_path
@@ -110,6 +126,11 @@ with gr.Blocks() as demo:
110
  i2v_image = gr.Image(type="pil", label="Input Image")
111
  t2v_prompt = gr.Textbox(label="Prompt")
112
  with gr.Accordion("Advanced settings", open=False):
 
 
 
 
 
113
  t2v_duration = gr.Slider(minimum=1, maximum=3 if is_canonical else 10, value=3 if is_canonical else 5, step=1, label="Duration (seconds)", visible=not is_canonical)
114
  t2v_fps = gr.Slider(minimum=8, maximum=24, step=16, value=8 if is_canonical else 24, label="Frames per second", visible=is_canonical)
115
  t2v_guidance_scale = gr.Slider(minimum=1, maximum=15, value=9, step=0.1, label="Guidance Scale")
@@ -140,7 +161,7 @@ with gr.Blocks() as demo:
140
  )
141
  t2v_generate_btn.click(
142
  generate_video,
143
- inputs=[t2v_prompt, i2v_image, t2v_duration, t2v_guidance_scale, t2v_video_guidance_scale, t2v_fps],
144
  outputs=t2v_output
145
  )
146
 
 
18
  MODEL_VARIANT = "diffusion_transformer_768p"
19
  MODEL_DTYPE = "bf16"
20
 
21
+ # Define resolution presets
22
+ LANDSCAPE_RESOLUTION = {"width": 1280, "height": 768}
23
+ PORTRAIT_RESOLUTION = {"width": 768, "height": 1280}
24
+
25
  def center_crop(image, target_width, target_height):
26
  width, height = image.size
27
  aspect_ratio_target = target_width / target_height
 
66
 
67
  # Text-to-video generation function
68
  @spaces.GPU(duration=140)
69
+ def generate_video(prompt, image=None, orientation="landscape", duration=3, guidance_scale=9, video_guidance_scale=5, frames_per_second=8, progress=gr.Progress(track_tqdm=True)):
70
+ # Set width and height based on orientation
71
+ if orientation == "landscape":
72
+ width = LANDSCAPE_RESOLUTION["width"]
73
+ height = LANDSCAPE_RESOLUTION["height"]
74
+ else: # portrait
75
+ width = PORTRAIT_RESOLUTION["width"]
76
+ height = PORTRAIT_RESOLUTION["height"]
77
+
78
  multiplier = 1.2 if is_canonical else 3.0
79
  temp = int(duration * multiplier) + 1
80
  torch_dtype = torch.bfloat16 if MODEL_DTYPE == "bf16" else torch.float32
81
+
82
+ if image:
83
+ # Process the input image according to the selected orientation
84
+ cropped_image = center_crop(image, width, height)
85
+ resized_image = cropped_image.resize((width, height))
86
+
87
  with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
88
  frames = model.generate_i2v(
89
  prompt=prompt,
 
101
  prompt=prompt,
102
  num_inference_steps=[20, 20, 20],
103
  video_num_inference_steps=[10, 10, 10],
104
+ height=height,
105
+ width=width,
106
  temp=temp,
107
  guidance_scale=guidance_scale,
108
  video_guidance_scale=video_guidance_scale,
109
  output_type="pil",
110
  save_memory=True,
111
  )
112
+
113
  output_path = f"{str(uuid.uuid4())}_output_video.mp4"
114
  export_to_video(frames, output_path, fps=frames_per_second)
115
  return output_path
 
126
  i2v_image = gr.Image(type="pil", label="Input Image")
127
  t2v_prompt = gr.Textbox(label="Prompt")
128
  with gr.Accordion("Advanced settings", open=False):
129
+ t2v_orientation = gr.Radio(
130
+ choices=["landscape", "portrait"],
131
+ value="landscape",
132
+ label="Video Orientation"
133
+ )
134
  t2v_duration = gr.Slider(minimum=1, maximum=3 if is_canonical else 10, value=3 if is_canonical else 5, step=1, label="Duration (seconds)", visible=not is_canonical)
135
  t2v_fps = gr.Slider(minimum=8, maximum=24, step=16, value=8 if is_canonical else 24, label="Frames per second", visible=is_canonical)
136
  t2v_guidance_scale = gr.Slider(minimum=1, maximum=15, value=9, step=0.1, label="Guidance Scale")
 
161
  )
162
  t2v_generate_btn.click(
163
  generate_video,
164
+ inputs=[t2v_prompt, i2v_image, t2v_orientation, t2v_duration, t2v_guidance_scale, t2v_video_guidance_scale, t2v_fps],
165
  outputs=t2v_output
166
  )
167