import sys, os, random, numpy as np, torch
sys.path.append("../")
from PIL import Image
import spaces
import gradio as gr
from gradio.themes import Soft # ★ NEW
from huggingface_hub import hf_hub_download
from transformers import AutoModelForImageSegmentation
from torchvision import transforms
from pipeline import InstantCharacterFluxPipeline
# ─────────────────────────────
# 1 · Runtime / device
# ─────────────────────────────
MAX_SEED = np.iinfo(np.int32).max
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if device == "cuda" else torch.float32  # used by the FLUX pipeline below
# ─────────────────────────────
# 2 · Pre-trained weights
# ─────────────────────────────
ip_adapter_path = hf_hub_download("tencent/InstantCharacter",
                                  "instantcharacter_ip-adapter.bin")
base_model = "black-forest-labs/FLUX.1-dev"
image_encoder_path = "google/siglip-so400m-patch14-384"
image_encoder2_path = "facebook/dinov2-giant"
birefnet_path = "ZhengPeng7/BiRefNet"
makoto_style_path = hf_hub_download("InstantX/FLUX.1-dev-LoRA-Makoto-Shinkai",
                                    "Makoto_Shinkai_style.safetensors")
ghibli_style_path = hf_hub_download("InstantX/FLUX.1-dev-LoRA-Ghibli",
                                    "ghibli_style.safetensors")
# ─────────────────────────────
# 3 · Pipeline init
# ─────────────────────────────
pipe = InstantCharacterFluxPipeline.from_pretrained(base_model,
                                                    torch_dtype=dtype)
pipe.to(device)
pipe.init_adapter(
    image_encoder_path=image_encoder_path,
    image_encoder_2_path=image_encoder2_path,
    subject_ipadapter_cfg=dict(subject_ip_adapter_path=ip_adapter_path,
                               nb_token=1024),
)
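# The adapter pairs two vision backbones (SigLIP plus DINOv2-giant, as in
# InstantCharacter's dual-encoder design); nb_token is presumably the
# number of subject tokens the adapter emits.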
# ─────────────────────────────
# 4 · BiRefNet (matting)
# ─────────────────────────────
birefnet = AutoModelForImageSegmentation.from_pretrained(birefnet_path,
                                                         trust_remote_code=True)
birefnet.to(device).eval()
birefnet_tf = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225]),
])
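# Mean/std above are the standard ImageNet statistics that BiRefNet's
# preprocessing expects.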
# ─────────────────────────────
# 5 · Helper utils
# ─────────────────────────────
def randomize_seed_fn(seed: int, randomize: bool) -> int:
    return random.randint(0, MAX_SEED) if randomize else seed
def _infer_matting(img_pil):
    with torch.no_grad():
        inp = birefnet_tf(img_pil).unsqueeze(0).to(device)
        mask = birefnet(inp)[-1].sigmoid().cpu()[0, 0].numpy()
    # BiRefNet predicts at 1024x1024; resize the mask back to the source
    # size so it lines up with the original pixels in remove_bkg().
    mask_img = Image.fromarray((mask * 255).astype(np.uint8))
    return np.array(mask_img.resize(img_pil.size, Image.BILINEAR))
def _bbox_from_mask(mask, th=128):
    ys, xs = np.where(mask >= th)
    if not len(xs):
        return [0, 0, mask.shape[1] - 1, mask.shape[0] - 1]
    return [xs.min(), ys.min(), xs.max(), ys.max()]
def _pad_square(arr, pad_val=255):
    h, w = arr.shape[:2]
    if h == w:
        return arr
    diff = abs(h - w)
    pad_1 = diff // 2
    pad_2 = diff - pad_1
    if h > w:
        pad = ((0, 0), (pad_1, pad_2), (0, 0))
    else:
        pad = ((pad_1, pad_2), (0, 0), (0, 0))
    return np.pad(arr, pad, constant_values=pad_val)
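# Example: a 300x200 (h, w) crop comes back 300x300, padded with 50 white
# columns on each side; taller-than-wide decides which axis gets padded.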
def remove_bkg(img_pil: Image.Image) -> Image.Image:
    mask = _infer_matting(img_pil)
    x1, y1, x2, y2 = _bbox_from_mask(mask)
    mask_bin = (mask >= 128).astype(np.uint8)[..., None]
    img_np = np.array(img_pil)
    obj = mask_bin * img_np + (1 - mask_bin) * 255
    crop = obj[y1:y2 + 1, x1:x2 + 1]
    return Image.fromarray(_pad_square(crop).astype(np.uint8))
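# Standalone usage sketch (illustrative path):
#   subject = remove_bkg(Image.open("./assets/girl.jpg").convert("RGB"))
#   subject.save("subject_square.png")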
def get_example():
    return [
        ["./assets/girl.jpg",
         "A girl is playing a guitar in the street", 0.9, "Makoto Shinkai style"],
        ["./assets/boy.jpg",
         "A boy is riding a bike in the snow", 0.9, "Makoto Shinkai style"],
    ]
@spaces.GPU
def create_image(input_image, prompt, scale,
                 guidance_scale, num_inference_steps,
                 seed, style_mode):
    input_image = remove_bkg(input_image)
    gen = torch.manual_seed(seed)
    # The dropdown sends the literal string "None", so check both forms.
    if style_mode in (None, "None"):
        imgs = pipe(prompt=prompt,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    width=1024, height=1024,
                    subject_image=input_image, subject_scale=scale,
                    generator=gen).images
    else:
        lora_path, trigger = (
            (makoto_style_path, "Makoto Shinkai style")
            if style_mode == "Makoto Shinkai style"
            else (ghibli_style_path, "ghibli style")
        )
        imgs = pipe.with_style_lora(
            lora_file_path=lora_path, trigger=trigger,
            prompt=prompt, num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            width=1024, height=1024,
            subject_image=input_image, subject_scale=scale,
            generator=gen).images
    return imgs
def run_for_examples(src, p, s, st):
    return create_image(src, p, s, 3.5, 28, 123456, st)
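# Examples pin guidance_scale=3.5, 28 steps and seed 123456 so the
# cached example outputs are deterministic.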
# ─────────────────────────────
# 6 · Theme & CSS
# ─────────────────────────────
theme = Soft(primary_hue="pink",
             font=[gr.themes.GoogleFont("Inter")])
css = """
body{
background:#141e30;
background:linear-gradient(135deg,#141e30,#243b55);
}
#title{
text-align:center;
font-size:2.2rem;
font-weight:700;
color:#ffffff;
padding:20px 0 6px;
}
.card{
border-radius:18px;
background:#ffffff0d;
padding:18px 22px;
backdrop-filter:blur(6px);
}
.gr-image,.gr-video{border-radius:14px}
.gr-image:hover{box-shadow:0 0 0 4px #ec4899}
footer{visibility:hidden}
"""
# ─────────────────────────────
# 7 · Gradio UI
# ─────────────────────────────
with gr.Blocks(css=css, theme=theme) as demo:
    # Header
    gr.Markdown("<div id='title'>InstantCharacter&nbsp;PLUS</div>")
    gr.Markdown(
        "<b>Official 🤗 Gradio demo of "
        "<a href='https://instantcharacter.github.io/' target='_blank'>InstantCharacter</a></b>"
    )
    with gr.Tabs():
        with gr.TabItem("Generate"):
            with gr.Row(equal_height=True):
                # ── Inputs
                with gr.Column(elem_classes="card"):
                    image_pil = gr.Image(label="Source Image",
                                         type="pil", height=380)
                    prompt = gr.Textbox(
                        label="Prompt",
                        value="A character is riding a bike in the snow",
                        lines=2,
                    )
                    scale = gr.Slider(0, 1.5, 1.0, step=0.01, label="Scale")
                    style_mode = gr.Dropdown(
                        ["None", "Makoto Shinkai style", "Ghibli style"],
                        label="Style",
                        value="Makoto Shinkai style",
                    )
                    with gr.Accordion("⚙️ Advanced Options", open=False):
                        guidance_scale = gr.Slider(
                            1, 7, 3.5, step=0.01, label="Guidance scale"
                        )
                        num_inference_steps = gr.Slider(
                            5, 50, 28, step=1, label="# Inference steps"
                        )
                        seed = gr.Number(123456, label="Seed", precision=0)
                        randomize_seed = gr.Checkbox(
                            label="Randomize seed", value=True
                        )
                    generate_btn = gr.Button(
                        "🚀 Generate",
                        variant="primary",
                        size="lg",
                        elem_classes="contrast",
                    )
                # ── Outputs
                with gr.Column(elem_classes="card"):
                    generated_image = gr.Gallery(
                        label="Generated Image",
                        show_label=True,
                        height="auto",
                        columns=[1],
                    )
            # Connect button: refresh the seed first, then generate
            generate_btn.click(
                randomize_seed_fn,
                [seed, randomize_seed],
                seed,
                queue=False,
            ).then(
                create_image,
                [
                    image_pil,
                    prompt,
                    scale,
                    guidance_scale,
                    num_inference_steps,
                    seed,
                    style_mode,
                ],
                generated_image,
            )
    # Examples gallery
    gr.Markdown("### 🔥 Quick Examples")
    gr.Examples(
        examples=get_example(),
        inputs=[image_pil, prompt, scale, style_mode],
        outputs=generated_image,
        fn=run_for_examples,
        cache_examples=True,
    )
# ─────────────────────────────
# 8 · Launch
# ─────────────────────────────
if __name__ == "__main__":
    demo.queue(max_size=10, api_open=False).launch()