|
|
|
|
|
import sys, os, random, numpy as np, torch |
|
sys.path.append("../") |
|
|
|
from PIL import Image |
|
import spaces |
|
import gradio as gr |
|
from gradio.themes import Soft |
|
from huggingface_hub import hf_hub_download |
|
from transformers import AutoModelForImageSegmentation |
|
from torchvision import transforms |
|
|
|
from pipeline import InstantCharacterFluxPipeline |
|
|
|
|
|
|
|
|
|
# Largest seed the UI will hand out: the maximum signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max

# Run on the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Half precision on GPU, full precision on CPU.
# NOTE(review): `dtype` is not referenced by the visible code -- the FLUX
# pipeline is loaded with torch.bfloat16 explicitly.
dtype = {"cuda": torch.float16}.get(device, torch.float32)
|
|
|
|
|
|
|
|
|
# InstantCharacter subject IP-adapter weights; hf_hub_download returns the
# local path of the downloaded file.
ip_adapter_path = hf_hub_download(
    repo_id="tencent/InstantCharacter",
    filename="instantcharacter_ip-adapter.bin",
)

# Hub repo ids of the backbone, the two image encoders and the matting model.
base_model = "black-forest-labs/FLUX.1-dev"
image_encoder_path = "google/siglip-so400m-patch14-384"
image_encoder2_path = "facebook/dinov2-giant"
birefnet_path = "ZhengPeng7/BiRefNet"

# Style LoRA files used by create_image for the optional style modes.
makoto_style_path = hf_hub_download(
    repo_id="InstantX/FLUX.1-dev-LoRA-Makoto-Shinkai",
    filename="Makoto_Shinkai_style.safetensors",
)
ghibli_style_path = hf_hub_download(
    repo_id="InstantX/FLUX.1-dev-LoRA-Ghibli",
    filename="ghibli_style.safetensors",
)
|
|
|
|
|
|
|
|
|
# FLUX.1-dev backbone wrapped in the InstantCharacter pipeline, in bfloat16.
pipe = InstantCharacterFluxPipeline.from_pretrained(
    base_model, torch_dtype=torch.bfloat16
)
pipe.to(device)

# Attach the SigLIP and DINOv2 image encoders plus the subject IP-adapter.
# Done after pipe.to(device), matching the original order.
subject_adapter_cfg = {
    "subject_ip_adapter_path": ip_adapter_path,
    "nb_token": 1024,
}
pipe.init_adapter(
    image_encoder_path=image_encoder_path,
    image_encoder_2_path=image_encoder2_path,
    subject_ipadapter_cfg=subject_adapter_cfg,
)
|
|
|
|
|
|
|
|
|
# BiRefNet salient-object segmentation model used for background removal.
# NOTE(review): trust_remote_code=True executes Python code shipped in the
# hub repo -- acceptable only because the repo id is pinned above.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    birefnet_path, trust_remote_code=True
)
birefnet.to(device)
birefnet.eval()

# Model preprocessing: fixed 1024x1024 resize, then ImageNet mean/std
# normalisation over three channels (so inputs must be RGB).
birefnet_tf = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
|
|
|
|
|
|
|
def randomize_seed_fn(seed: int, randomize: bool) -> int:
    """Pick the seed for the next generation.

    Returns *seed* untouched when *randomize* is falsy; otherwise a fresh
    uniform integer in ``[0, MAX_SEED]``.
    """
    if not randomize:
        return seed
    return random.randint(0, MAX_SEED)
|
|
|
def _infer_matting(img_pil):
    """Predict a foreground mask for *img_pil* with BiRefNet.

    The image is converted to RGB first because ``birefnet_tf`` normalises
    exactly three channels (RGBA or grayscale uploads would otherwise fail).
    The model runs on the 1024x1024 resized input. Its prediction is
    resized back to the source image's size so it lines up pixel-for-pixel
    with the original image.

    Args:
        img_pil: source PIL image of any mode and size.

    Returns:
        ``uint8`` array of shape ``(height, width)`` (matching
        ``img_pil.size``), with values in 0..255.
    """
    rgb = img_pil.convert("RGB")
    with torch.no_grad():
        inp = birefnet_tf(rgb).unsqueeze(0).to(device)
        pred = birefnet(inp)[-1].sigmoid().float().cpu()[0, 0].numpy()
    mask = Image.fromarray((pred * 255).astype(np.uint8))
    # Map the model-resolution prediction back onto the input's pixel grid.
    mask = mask.resize(img_pil.size, Image.BILINEAR)
    return np.array(mask)
|
|
|
def _bbox_from_mask(mask, th=128): |
|
ys, xs = np.where(mask >= th) |
|
if not len(xs): |
|
return [0, 0, mask.shape[1]-1, mask.shape[0]-1] |
|
return [xs.min(), ys.min(), xs.max(), ys.max()] |
|
|
|
def _pad_square(arr, pad_val=255): |
|
h, w = arr.shape[:2] |
|
if h == w: |
|
return arr |
|
diff = abs(h - w) |
|
pad_1 = diff // 2 |
|
pad_2 = diff - pad_1 |
|
if h > w: |
|
pad = ((0, 0), (pad_1, pad_2), (0, 0)) |
|
else: |
|
pad = ((pad_1, pad_2), (0, 0), (0, 0)) |
|
return np.pad(arr, pad, constant_values=pad_val) |
|
|
|
def remove_bkg(img_pil: Image.Image) -> Image.Image:
    """Cut the salient subject out of *img_pil* onto a white square canvas.

    Steps: predict a foreground mask, paint every pixel outside it
    (mask < 128) white, crop to the mask's bounding box, then pad the crop
    to a square with white.

    Args:
        img_pil: source image of any mode; it is converted to RGB first.

    Returns:
        Square RGB PIL image containing the isolated subject.
    """
    img_pil = img_pil.convert("RGB")
    img_np = np.array(img_pil)
    h, w = img_np.shape[:2]

    mask = _infer_matting(img_pil)
    if mask.shape[:2] != (h, w):
        # The matting model works at a fixed resolution; bring its mask back
        # to the image size before combining the two pixel-wise.
        mask = np.array(Image.fromarray(mask).resize((w, h), Image.BILINEAR))

    x1, y1, x2, y2 = _bbox_from_mask(mask)
    keep = (mask >= 128)[..., None]  # same threshold as _bbox_from_mask
    obj = np.where(keep, img_np, 255).astype(np.uint8)
    crop = obj[y1:y2 + 1, x1:x2 + 1]  # bbox corners are inclusive
    return Image.fromarray(_pad_square(crop).astype(np.uint8))
|
|
|
def get_example():
    """Build the rows for the "Quick Examples" table.

    Each row is ``[image_path, prompt, subject_scale, style]``, in the order
    of the ``gr.Examples`` inputs. A new list is built on every call.
    """
    style = "Makoto Shinkai style"
    subject_scale = 0.9
    cases = (
        ("./assets/girl.jpg", "A girl is playing a guitar in street"),
        ("./assets/boy.jpg", "A boy is riding a bike in snow"),
    )
    return [[path, text, subject_scale, style] for path, text in cases]
|
|
|
@spaces.GPU
def create_image(input_image, prompt, scale,
                 guidance_scale, num_inference_steps,
                 seed, style_mode):
    """Generate 1024x1024 images of the subject in *input_image*.

    Args:
        input_image: PIL image of the subject. Its background is removed
            (``remove_bkg``) before it is used as the subject condition.
        prompt: text prompt for the pipeline.
        scale: subject conditioning strength (``subject_scale``).
        guidance_scale: guidance scale passed to the pipeline.
        num_inference_steps: number of denoising steps (coerced to ``int``).
        seed: RNG seed (coerced to ``int``).
        style_mode: ``"Makoto Shinkai style"``, ``"Ghibli style"``, or
            ``"None"`` / ``None`` / ``""`` for no style LoRA.

    Returns:
        The ``.images`` list produced by the pipeline.

    Raises:
        gr.Error: if no source image was supplied.
    """
    if input_image is None:
        raise gr.Error("Please upload a source image first.")

    input_image = remove_bkg(input_image)
    gen = torch.manual_seed(int(seed))

    # Arguments shared by the plain and the style-LoRA code paths.
    common = dict(
        prompt=prompt,
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
        width=1024,
        height=1024,
        subject_image=input_image,
        subject_scale=scale,
        generator=gen,
    )

    # The "Style" dropdown sends the *string* "None" for "no style", not the
    # None object; previously that fell through and applied the Ghibli LoRA.
    if style_mode in (None, "", "None"):
        return pipe(**common).images

    if style_mode == "Makoto Shinkai style":
        lora_path, trigger = makoto_style_path, "Makoto Shinkai style"
    else:
        lora_path, trigger = ghibli_style_path, "ghibli style"
    return pipe.with_style_lora(
        lora_file_path=lora_path, trigger=trigger, **common
    ).images
|
|
|
def run_for_examples(src, p, s, st):
    """Run ``create_image`` for one example row with fixed generation settings.

    Args:
        src: source image. p: prompt. s: subject scale. st: style mode.
    """
    default_guidance = 3.5
    default_steps = 28
    fixed_seed = 123456  # fixed so cached example outputs are reproducible
    return create_image(src, p, s, default_guidance, default_steps, fixed_seed, st)
|
|
|
|
|
|
|
|
|
# Gradio "Soft" base theme tinted pink, using the Inter Google font.
inter_font = gr.themes.GoogleFont("Inter")
theme = Soft(primary_hue="pink", font=[inter_font])
|
|
|
# Custom stylesheet passed to gr.Blocks(css=...):
#   body     - dark page background: a solid colour first (fallback), then a
#              diagonal gradient that overrides it where supported;
#   #title   - centred, bold, white header (the first Markdown element);
#   .card    - rounded, nearly transparent white panels with a backdrop blur
#              (applied to the UI columns via elem_classes="card");
#   .gr-image / .gr-video - rounded corners, with a pink (#ec4899) ring on
#              image hover;
#   footer   - hidden.
css = """
body{
background:#141e30;
background:linear-gradient(135deg,#141e30,#243b55);
}
#title{
text-align:center;
font-size:2.2rem;
font-weight:700;
color:#ffffff;
padding:20px 0 6px;
}
.card{
border-radius:18px;
background:#ffffff0d;
padding:18px 22px;
backdrop-filter:blur(6px);
}
.gr-image,.gr-video{border-radius:14px}
.gr-image:hover{box-shadow:0 0 0 4px #ec4899}
footer{visibility:hidden}
"""
|
|
|
|
|
|
|
|
|
with gr.Blocks(css=css, theme=theme) as demo:
    # Header: title (styled by the #title CSS rule) and project link.
    gr.Markdown("<div id='title'>InstantCharacter PLUS</div>")
    gr.Markdown(
        "<b>Official 🤗 Gradio demo of "
        "<a href='https://instantcharacter.github.io/' target='_blank'>InstantCharacter</a></b>"
    )

    with gr.Tabs():
        with gr.TabItem("Generate"):
            with gr.Row(equal_height=True):
                # Left card: inputs and generation controls.
                with gr.Column(elem_classes="card"):
                    image_pil = gr.Image(
                        label="Source Image",
                        type="pil",
                        height=380,
                    )
                    prompt = gr.Textbox(
                        label="Prompt",
                        value="A character is riding a bike in snow",
                        lines=2,
                    )
                    scale = gr.Slider(
                        minimum=0,
                        maximum=1.5,
                        value=1.0,
                        step=0.01,
                        label="Scale",
                    )
                    style_mode = gr.Dropdown(
                        choices=["None", "Makoto Shinkai style", "Ghibli style"],
                        label="Style",
                        value="Makoto Shinkai style",
                    )

                    with gr.Accordion("⚙️ Advanced Options", open=False):
                        guidance_scale = gr.Slider(
                            minimum=1,
                            maximum=7,
                            value=3.5,
                            step=0.01,
                            label="Guidance scale",
                        )
                        num_inference_steps = gr.Slider(
                            minimum=5,
                            maximum=50,
                            value=28,
                            step=1,
                            label="# Inference steps",
                        )
                        seed = gr.Number(value=123456, label="Seed", precision=0)
                        randomize_seed = gr.Checkbox(
                            label="Randomize seed",
                            value=True,
                        )

                    generate_btn = gr.Button(
                        "🚀 Generate",
                        variant="primary",
                        size="lg",
                        elem_classes="contrast",
                    )

                # Right card: generated results.
                with gr.Column(elem_classes="card"):
                    generated_image = gr.Gallery(
                        label="Generated Image",
                        show_label=True,
                        height="auto",
                        columns=[1],
                    )

    # Generate: first (optionally) re-roll the seed without queueing, then run
    # the generation with the possibly-updated seed value.
    seed_step = generate_btn.click(
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
    )
    seed_step.then(
        fn=create_image,
        inputs=[
            image_pil,
            prompt,
            scale,
            guidance_scale,
            num_inference_steps,
            seed,
            style_mode,
        ],
        outputs=generated_image,
    )

    # Clickable examples; cache_examples=True precomputes their outputs.
    gr.Markdown("### 🔥 Quick Examples")
    gr.Examples(
        examples=get_example(),
        inputs=[image_pil, prompt, scale, style_mode],
        outputs=generated_image,
        fn=run_for_examples,
        cache_examples=True,
    )
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow at most 10 waiting requests; api_open=False keeps the backend's
    # REST routes from being used to bypass the queue (per Gradio docs).
    demo.queue(max_size=10, api_open=False)
    demo.launch()
|
|