# hma / app.py — Hugging Face Space source file
# Author: liruiw · last commit dded309 ("fix") · 2.47 kB
import gradio as gr
import spaces
import numpy as np
from PIL import Image
import cv2
from sim.simulator import GenieSimulator
# Side length, in pixels, that every displayed frame is resized to.
RES = 512

# Read the prompt frame's pixels and close the file right away. The `with`
# block releases the handle instead of leaving that to the garbage collector.
with Image.open("sim/assets/langtable_prompt/frame_06.png") as _prompt_frame:
    _prompt_pixels = np.array(_prompt_frame)

# Single, module-level simulator instance. It is built once at import time,
# and every UI interaction steps this same object.
genie = GenieSimulator(
    image_encoder_type='temporalvae',
    image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
    quantize=False,
    backbone_type='stmar',
    backbone_ckpt='data/mar_ckpt/langtable',
    prompt_horizon=3,
    action_stride=1,
    domain='language_table',
)

# Prompt context: `prompt_horizon` copies of the same frame stacked on a new
# leading axis, plus (prompt_horizon - 1) all-zero 2-D actions.
# NOTE(review): presumably one action per transition between prompt frames —
# confirm against GenieSimulator.set_initial_state.
prompt_image = np.tile(
    _prompt_pixels, (genie.prompt_horizon, 1, 1, 1)
).astype(np.uint8)
prompt_action = np.zeros(
    (genie.prompt_horizon - 1, genie.action_stride, 2), dtype=np.float32
)
genie.set_initial_state((prompt_image, prompt_action))

# First frame shown in the UI before any button is pressed (PIL image, RES x RES).
image = Image.fromarray(cv2.resize(genie.reset(), (RES, RES)))
# Unit 2-D action vector for each arrow direction. The signs reproduce the
# original hard-coded branches: 'down'/'up' move along the first component,
# 'right'/'left' along the second. `model` scales the vector by `step`.
_DIRECTION_OFFSETS = {
    'up': (-1.0, 0.0),
    'down': (1.0, 0.0),
    'left': (0.0, -1.0),
    'right': (0.0, 1.0),
}


def model(direction: str, genie=genie, step: float = 0.05):
    """Step the simulator one action in ``direction`` and return the predicted frame.

    Args:
        direction: One of ``'up'``, ``'down'``, ``'left'``, ``'right'``.
        genie: Simulator to step. It defaults to the module-level ``genie``,
            which is bound once, when this function is defined.
        step: Magnitude of the 2-D action. The default of 0.05 matches the
            value that was previously hard-coded for every direction.

    Returns:
        The simulator's ``'pred_next_frame'``, resized to ``RES x RES`` and
        converted to a PIL image.

    Raises:
        ValueError: If ``direction`` is not one of the four supported names.
    """
    try:
        offset = _DIRECTION_OFFSETS[direction]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs. The old if/elif chain rejected
        # those with ValueError too.
        raise ValueError(f"Invalid direction: {direction}") from None
    action = np.array(offset) * step
    next_image = genie.step(action)['pred_next_frame']
    next_image = cv2.resize(next_image, (RES, RES))
    return Image.fromarray(next_image)
# Click callback shared by all arrow buttons in the UI.
@spaces.GPU
def handle_input(direction):
    """Log the clicked direction and return the simulator's next frame.

    ``@spaces.GPU`` presumably requests GPU allocation for the duration of
    this call on Hugging Face Spaces. NOTE(review): see the ``spaces``
    package for the exact semantics.
    """
    print(f"User clicked: {direction}")
    return model(direction)
if __name__ == '__main__':
    def _make_handler(direction):
        """Return a zero-argument click callback that steps toward ``direction``.

        A factory binds ``direction`` per button. A lambda written directly
        inside the registration loop would late-bind to the loop's last value.
        """
        return lambda: handle_input(direction)

    with gr.Blocks() as demo:
        with gr.Row():
            image_display = gr.Image(value=image, type="pil", label="Generated Image")
        with gr.Row():
            up = gr.Button("↑ Up")
        with gr.Row():
            left = gr.Button("← Left")
            down = gr.Button("↓ Down")
            # Previously mojibake ("β†’"): the UTF-8 right arrow mis-decoded.
            right = gr.Button("→ Right")
        # Wire each button to its direction. Registration order is
        # up, down, left, right.
        for button, direction in (
            (up, "up"),
            (down, "down"),
            (left, "left"),
            (right, "right"),
        ):
            button.click(
                fn=_make_handler(direction),
                outputs=image_display,
                show_progress='hidden',
            )
    demo.launch()