xizaoqu commited on
Commit
cd8c42f
·
1 Parent(s): 4170d69
Files changed (1) hide show
  1. app.py +20 -10
app.py CHANGED
@@ -22,6 +22,7 @@ import cv2
22
  import subprocess
23
  from PIL import Image
24
  from datetime import datetime
 
25
 
26
  ACTION_KEYS = [
27
  "inventory",
@@ -141,7 +142,7 @@ SUNFLOWERS_IMAGE = "assets/sunflower_plains.png"
141
  SUNFLOWERS_RAIN_IMAGE = "assets/rain_sunflower_plains.png"
142
 
143
  DEFAULT_IMAGE = ICE_PLAINS_IMAGE
144
- device = "cuda:0"
145
 
146
  def save_video(frames, path="output.mp4", fps=10):
147
  h, w, _ = frames[0].shape
@@ -171,11 +172,14 @@ def run(cfg: DictConfig):
171
 
172
  memory_frames.append(load_image_as_tensor(DEFAULT_IMAGE))
173
 
174
- _ = algo.interactive(memory_frames[0],
175
- actions[0],
176
- poses[0],
177
- memory_curr_frame,
178
- device="cuda:0")
 
 
 
179
 
180
  def set_denoising_steps(denoising_steps, sampling_timesteps_state):
181
  algo.sampling_timesteps = denoising_steps
@@ -190,11 +194,11 @@ def run(cfg: DictConfig):
190
  global memory_curr_frame
191
  for i in range(len(actions)):
192
  memory_curr_frame += 1
193
- new_frame = algo.interactive(memory_frames[0],
194
  actions[i],
195
  None,
196
  memory_curr_frame,
197
- device="cuda:0")
198
 
199
  memory_frames.append(new_frame)
200
 
@@ -222,11 +226,11 @@ def run(cfg: DictConfig):
222
  memory_curr_frame = 0
223
  input_history = ""
224
 
225
- _ = algo.interactive(memory_frames[0],
226
  actions[0],
227
  poses[0],
228
  memory_curr_frame,
229
- device="cuda:0")
230
  return input_history, DEFAULT_IMAGE
231
 
232
  def on_image_click(SELECTED_IMAGE):
@@ -235,6 +239,12 @@ def run(cfg: DictConfig):
235
  reset()
236
  return SELECTED_IMAGE
237
 
 
 
 
 
 
 
238
  css = """
239
  h1 {
240
  text-align: center;
 
22
  import subprocess
23
  from PIL import Image
24
  from datetime import datetime
25
+ import spaces
26
 
27
  ACTION_KEYS = [
28
  "inventory",
 
142
  SUNFLOWERS_RAIN_IMAGE = "assets/rain_sunflower_plains.png"
143
 
144
  DEFAULT_IMAGE = ICE_PLAINS_IMAGE
145
+ device = torch.device('cuda')
146
 
147
  def save_video(frames, path="output.mp4", fps=10):
148
  h, w, _ = frames[0].shape
 
172
 
173
  memory_frames.append(load_image_as_tensor(DEFAULT_IMAGE))
174
 
175
+ @spaces.GPU()
176
+ def run_interactive(first_frame, action, first_pose, curr_frame, device):
177
+ new_frame = algo.interactive(first_frame,
178
+ action,
179
+ first_pose,
180
+ curr_frame,
181
+ device=device)
182
+ return new_frame
183
 
184
  def set_denoising_steps(denoising_steps, sampling_timesteps_state):
185
  algo.sampling_timesteps = denoising_steps
 
194
  global memory_curr_frame
195
  for i in range(len(actions)):
196
  memory_curr_frame += 1
197
+ new_frame = run_interactive(memory_frames[0],
198
  actions[i],
199
  None,
200
  memory_curr_frame,
201
+ device=device)
202
 
203
  memory_frames.append(new_frame)
204
 
 
226
  memory_curr_frame = 0
227
  input_history = ""
228
 
229
+ _ = run_interactive(memory_frames[0],
230
  actions[0],
231
  poses[0],
232
  memory_curr_frame,
233
+ device=device)
234
  return input_history, DEFAULT_IMAGE
235
 
236
  def on_image_click(SELECTED_IMAGE):
 
239
  reset()
240
  return SELECTED_IMAGE
241
 
242
+ _ = run_interactive(memory_frames[0],
243
+ actions[0],
244
+ poses[0],
245
+ memory_curr_frame,
246
+ device)
247
+
248
  css = """
249
  h1 {
250
  text-align: center;