|
import streamlit as st |
|
import torch |
|
import os |
|
import uuid |
|
from diffusers import AnimateDiffPipeline, EulerDiscreteScheduler |
|
from diffusers.utils import export_to_video |
|
from huggingface_hub import hf_hub_download |
|
from safetensors.torch import load_file |
|
import time |
|
|
|
|
|
# Hugging Face repo ids of the Stable Diffusion checkpoints whose UNet weights
# can be swapped into the pipeline, keyed by the label shown in the UI.
bases = {
    "Cartoon": "frankjoshua/toonyou_beta6",
    "Realistic": "emilianJR/epiCRealism",
    "3d": "Lykon/DreamShaper",
    "Anime": "Yntec/mistoonAnime2",
}

# Weights believed to be loaded into the pipeline. NOTE: Streamlit re-executes
# this script on every interaction, so these globals are reset to the values
# below on each rerun even though the cached pipeline keeps its weights.
step_loaded = None            # AnimateDiff-Lightning step count (None = not loaded yet)
base_loaded = "Realistic"     # init_pipeline builds the pipeline from bases[base_loaded]
motion_loaded = None          # motion LoRA repo id, "" for none, None = unknown

# Prefer a CUDA GPU with half precision when one is present; otherwise run on
# the CPU in full precision.
if torch.cuda.is_available():
    device = "cuda"
    dtype = torch.float16
else:
    device = "cpu"
    dtype = torch.float32

st.set_page_config(page_title="Instant⚡ Text to Video", layout="centered")
|
|
|
|
|
@st.cache_resource
def init_pipeline():
    """Build the AnimateDiff pipeline once and reuse it across Streamlit reruns.

    The pipeline is created from the ``bases[base_loaded]`` checkpoint, moved
    to ``device`` with ``dtype`` weights, and given an Euler discrete scheduler
    that uses trailing timestep spacing and a linear beta schedule.

    Returns:
        The configured AnimateDiffPipeline instance.
    """
    checkpoint = bases[base_loaded]
    animate_pipe = AnimateDiffPipeline.from_pretrained(checkpoint, torch_dtype=dtype)
    animate_pipe = animate_pipe.to(device)

    # Replace the default scheduler, keeping its config but overriding spacing/betas.
    scheduler_config = animate_pipe.scheduler.config
    animate_pipe.scheduler = EulerDiscreteScheduler.from_config(
        scheduler_config,
        timestep_spacing="trailing",
        beta_schedule="linear",
    )

    # Clear the pipeline's safety_checker attribute.
    animate_pipe.safety_checker = None
    return animate_pipe


# Module-level handle to the cached pipeline; generate_image mutates its weights.
pipe = init_pipeline()
|
|
|
|
|
def generate_image(prompt, base="Realistic", motion="", step=1, guidance_scale=1.2, fps=10):
    """Generate a short video for *prompt* and return the path of the MP4 file.

    Weights on the shared, cached pipeline are swapped only when the requested
    configuration differs from what is currently loaded into it.

    Args:
        prompt: Text prompt passed to the pipeline.
        base: Key into ``bases`` selecting the base UNet weights.
        motion: Hugging Face repo id of a motion LoRA, or "" for none.
        step: Number of inference steps; selects the matching
            AnimateDiff-Lightning checkpoint (1, 2, 4 or 8).
        guidance_scale: Guidance scale passed to the pipeline.
        fps: Frame rate of the exported video.

    Returns:
        Path of the exported ``.mp4`` file in the system temp directory.

    Raises:
        ValueError: if ``step`` has no Lightning checkpoint or ``base`` is not
            a key of ``bases``. Validation happens before the pipeline is touched.
    """
    global step_loaded, base_loaded, motion_loaded
    import tempfile  # stdlib; local so this function needs no new file-level import

    step = int(step)
    if step not in (1, 2, 4, 8):
        raise ValueError(
            f"step must be one of 1, 2, 4, 8 (AnimateDiff-Lightning checkpoints); got {step}"
        )
    if base not in bases:
        raise ValueError(f"Unknown base model {base!r}; expected one of {list(bases)}")

    st.write(f"Generating video with prompt: {prompt}, base: {base}, steps: {step}")

    # Streamlit re-executes this module on every interaction, which resets the
    # module-level *_loaded globals, while `pipe` comes from st.cache_resource
    # and keeps the weights loaded in earlier runs. Track the loaded state on
    # the cached pipeline itself so it always matches the actual weights.
    state = getattr(pipe, "_loaded_state", None)
    if state is None:
        # First use of a freshly built pipeline: the globals still hold the
        # initial values that describe it.
        state = {"step": step_loaded, "base": base_loaded, "motion": motion_loaded}
        pipe._loaded_state = state

    # A PEFT-backed LoRA wraps the layers it targets, which changes their
    # state-dict keys; a strict=False load_state_dict would then silently skip
    # those layers. Drop any LoRA before swapping UNet weights; it is re-applied
    # by the motion section below (state "None" forces the reload).
    if (state["step"] != step or state["base"] != base) and state["motion"] != "":
        pipe.unload_lora_weights()
        state["motion"] = None

    if state["step"] != step:
        repo = "ByteDance/AnimateDiff-Lightning"
        ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
        pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False)
        state["step"] = step

    if state["base"] != base:
        # torch.load unpickles the file; weights_only=True refuses arbitrary
        # Python objects from the remote repo and accepts only tensor containers.
        base_weights = torch.load(
            hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin"),
            map_location=device,
            weights_only=True,
        )
        pipe.unet.load_state_dict(base_weights, strict=False)
        state["base"] = base

    if state["motion"] != motion:
        pipe.unload_lora_weights()
        if motion:
            pipe.load_lora_weights(motion, adapter_name="motion")
            pipe.set_adapters(["motion"], [0.7])
        state["motion"] = motion

    # Keep the legacy module globals in sync for any code that still reads them.
    step_loaded, base_loaded, motion_loaded = state["step"], state["base"], state["motion"]

    progress_bar = st.progress(0)

    def progress_callback(i, t, z):
        # Called once per denoising step (callback_steps=1); i is 0-based.
        progress_bar.progress((i + 1) / step)

    # NOTE(review): `callback`/`callback_steps` are deprecated in recent diffusers
    # releases in favour of `callback_on_step_end`; kept for older versions.
    with torch.no_grad():
        output = pipe(
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=step,
            callback=progress_callback,
            callback_steps=1,
        )

    # Unique file name in the platform temp directory (not hard-coded /tmp).
    path = os.path.join(tempfile.gettempdir(), f"{uuid.uuid4().hex}.mp4")
    export_to_video(output.frames[0], path, fps=fps)
    return path
|
|
|
|
|
st.title("Instant⚡ Text to Video")

st.markdown("""
<style>
body {font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; background-color: #f4f4f9; color: #333;}
.stApp {max-width: 800px; margin: auto; padding: 20px; background: #fff; box-shadow: 0px 0px 20px rgba(0,0,0,0.1); border-radius: 10px;}
.stButton>button {width: 100%; background-color: #4CAF50; color: white; border: none; padding: 10px 20px; border-radius: 5px; cursor: pointer;}
.stButton>button:hover {background-color: #45a049;}
.stVideo {margin-top: 20px;}
</style>
""", unsafe_allow_html=True)


def _generate_and_show(text, base, motion, step):
    """Generate a video for *text*, then show the elapsed time and the video.

    Errors raised by generate_image are displayed in the page instead of
    aborting the script, so widgets rendered after this call still appear.
    """
    with st.spinner("Generating video..."):
        start_time = time.perf_counter()  # monotonic clock for measuring durations
        try:
            video_path = generate_image(text, base, motion, step)
        except Exception as exc:  # UI boundary: report the failure, keep the page alive
            st.error("Video generation failed.")
            st.exception(exc)
            return
        elapsed = time.perf_counter() - start_time
    st.success(f"Video generated in {elapsed:.2f} seconds!")
    st.video(video_path)


prompt = st.text_input("Prompt", placeholder="Enter text to generate video...")

# Offer exactly the bases that generate_image can load, defaulting to "Realistic".
base_options = list(bases)
base = st.selectbox("Base model", base_options, index=base_options.index("Realistic"))

# Each option is (label, LoRA repo id); only the repo id is kept ("" = no LoRA).
motion = st.selectbox(
    "Motion",
    [
        ("Default", ""),
        ("Zoom in", "guoyww/animatediff-motion-lora-zoom-in"),
        ("Zoom out", "guoyww/animatediff-motion-lora-zoom-out"),
        ("Tilt up", "guoyww/animatediff-motion-lora-tilt-up"),
        ("Tilt down", "guoyww/animatediff-motion-lora-tilt-down"),
        ("Pan left", "guoyww/animatediff-motion-lora-pan-left"),
        ("Pan right", "guoyww/animatediff-motion-lora-pan-right"),
        ("Roll left", "guoyww/animatediff-motion-lora-rolling-anticlockwise"),
        ("Roll right", "guoyww/animatediff-motion-lora-rolling-clockwise"),
    ],
    format_func=lambda x: x[0],
    index=1,
)[1]

step = st.selectbox("Inference steps", [1, 2, 4, 8], index=0)

if st.button("Generate Video"):
    # Reject whitespace-only prompts as well as empty ones.
    if prompt.strip():
        _generate_and_show(prompt, base, motion, step)
    else:
        st.error("Please enter a prompt!")

st.subheader("Examples")

examples = [
    "Focus: Eiffel Tower (Animate: Clouds moving)",
    "Focus: Trees In forest (Animate: Lion running)",
    "Focus: Astronaut in Space",
    "Focus: Group of Birds in sky (Animate: Birds Moving) (Shot From distance)",
    "Focus: Statue of liberty (Shot from Drone) (Animate: Drone coming toward statue)",
    "Focus: Panda in Forest (Animate: Drinking Tea)",
    "Focus: Kids Playing (Season: Winter)",
    "Focus: Cars in Street (Season: Rain, Daytime) (Shot from Distance) (Movement: Cars running)",
]

# Render every example button first, then show the result below the list
# rather than in the middle of it.
selected_example = None
for example in examples:
    if st.button(example, key=example):
        selected_example = example

if selected_example is not None:
    _generate_and_show(selected_example, base, motion, step)