import os
from typing import List, Optional

import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from PIL import Image
from torchvision import transforms

from lbm.models.embedders import (
    ConditionerWrapper,
    LatentsConcatEmbedder,
    LatentsConcatEmbedderConfig,
)
from lbm.models.lbm import LBMConfig, LBMModel
from lbm.models.unets import DiffusersUNet2DCondWrapper
from lbm.models.vae import AutoencoderKLDiffusers, AutoencoderKLDiffusersConfig


def get_model_from_config(
    backbone_signature: str = "stabilityai/stable-diffusion-xl-base-1.0",
    vae_num_channels: int = 4,
    unet_input_channels: int = 4,
    timestep_sampling: str = "log_normal",
    selected_timesteps: Optional[List[float]] = None,
    prob: Optional[List[float]] = None,
    conditioning_images_keys: List[str] = [],
    conditioning_masks_keys: List[str] = ["mask"],
    source_key: str = "source_image",
    target_key: str = "source_image_paste",
    bridge_noise_sigma: float = 0.0,
):
    conditioners = []
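
    # Denoiser: a three-level cross-attention UNet built through the diffusers
    # wrapper and kept in bfloat16. `unet_input_channels` is separate from
    # `vae_num_channels` so that extra conditioning channels can be concatenated
    # to the latent input when needed.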
    denoiser = DiffusersUNet2DCondWrapper(
        in_channels=unet_input_channels,  # Add downsampled_image
        out_channels=vae_num_channels,
        center_input_sample=False,
        flip_sin_to_cos=True,
        freq_shift=0,
        down_block_types=[
            "DownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
        ],
        mid_block_type="UNetMidBlock2DCrossAttn",
        up_block_types=["CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"],
        only_cross_attention=False,
        block_out_channels=[320, 640, 1280],
        layers_per_block=2,
        downsample_padding=1,
        mid_block_scale_factor=1,
        dropout=0.0,
        act_fn="silu",
        norm_num_groups=32,
        norm_eps=1e-05,
        cross_attention_dim=[320, 640, 1280],
        transformer_layers_per_block=[1, 2, 10],
        reverse_transformer_layers_per_block=None,
        encoder_hid_dim=None,
        encoder_hid_dim_type=None,
        attention_head_dim=[5, 10, 20],
        num_attention_heads=None,
        dual_cross_attention=False,
        use_linear_projection=True,
        class_embed_type=None,
        addition_embed_type=None,
        addition_time_embed_dim=None,
        num_class_embeds=None,
        upcast_attention=None,
        resnet_time_scale_shift="default",
        resnet_skip_time_act=False,
        resnet_out_scale_factor=1.0,
        time_embedding_type="positional",
        time_embedding_dim=None,
        time_embedding_act_fn=None,
        timestep_post_act=None,
        time_cond_proj_dim=None,
        conv_in_kernel=3,
        conv_out_kernel=3,
        projection_class_embeddings_input_dim=None,
        attention_type="default",
        class_embeddings_concat=False,
        mid_block_only_cross_attention=None,
        cross_attention_norm=None,
        addition_embed_type_num_heads=64,
    ).to(torch.bfloat16)
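
    # Conditioning: if image and/or mask keys are given, add a LatentsConcatEmbedder
    # so those inputs are channel-concatenated onto the denoiser input (this is what
    # any extra UNet input channels are reserved for).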
    if conditioning_images_keys != [] or conditioning_masks_keys != []:
        latents_concat_embedder_config = LatentsConcatEmbedderConfig(
            image_keys=conditioning_images_keys,
            mask_keys=conditioning_masks_keys,
        )
        latent_concat_embedder = LatentsConcatEmbedder(latents_concat_embedder_config)
        latent_concat_embedder.freeze()
        conditioners.append(latent_concat_embedder)

    # Wrap conditioners and set to device
    conditioner = ConditionerWrapper(
        conditioners=conditioners,
    )

    ## VAE ##
    # Get VAE model
    vae_config = AutoencoderKLDiffusersConfig(
        version=backbone_signature,
        subfolder="vae",
        tiling_size=(128, 128),
    )
    vae = AutoencoderKLDiffusers(vae_config).to(torch.bfloat16)
    vae.freeze()
    vae.to(torch.bfloat16)

    ## Diffusion Model ##
    # Get diffusion model
    config = LBMConfig(
        source_key=source_key,
        target_key=target_key,
        timestep_sampling=timestep_sampling,
        selected_timesteps=selected_timesteps,
        prob=prob,
        bridge_noise_sigma=bridge_noise_sigma,
    )
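
    # Sampling scheduler: reuse the flow-matching Euler scheduler shipped with the
    # backbone checkpoint.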
    sampling_noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        backbone_signature,
        subfolder="scheduler",
    )

    model = LBMModel(
        config,
        denoiser=denoiser,
        sampling_noise_scheduler=sampling_noise_scheduler,
        vae=vae,
        conditioner=conditioner,
    ).to(torch.bfloat16)

    return model
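

# Example (illustrative sketch, not part of the original app): instantiate the
# bridge model with the default SD-XL backbone and move it to the GPU. Loading
# the trained LBM weights is omitted because it depends on how the checkpoint
# is distributed.
#
#   model = get_model_from_config()
#   model = model.to("cuda")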


def extract_object(birefnet, img):
    # Data settings
    image_size = (1024, 1024)
    transform_image = transforms.Compose(
        [
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
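    # The normalization constants above are the standard ImageNet mean/std, and the
    # 1024x1024 resize matches the resolution the segmentation model is fed here.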
    image = img
    input_images = transform_image(image).unsqueeze(0).cuda()

    # Prediction
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(image.size)
    image = Image.composite(image, Image.new("RGB", image.size, (127, 127, 127)), mask)
    return image, mask
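

# Example (illustrative sketch): `birefnet` is expected to be a BiRefNet-style
# background-segmentation model whose last output is a foreground logit map. One
# assumed way to obtain such a model (not shown in this file) is via transformers:
#
#   from transformers import AutoModelForImageSegmentation
#   birefnet = AutoModelForImageSegmentation.from_pretrained(
#       "ZhengPeng7/BiRefNet", trust_remote_code=True
#   ).cuda().eval()
#   composite, mask = extract_object(birefnet, Image.open("input.jpg").convert("RGB"))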


def resize_and_center_crop(image, target_width, target_height):
    original_width, original_height = image.size
    scale_factor = max(target_width / original_width, target_height / original_height)
    resized_width = int(round(original_width * scale_factor))
    resized_height = int(round(original_height * scale_factor))
    resized_image = image.resize((resized_width, resized_height), Image.LANCZOS)
    left = (resized_width - target_width) / 2
    top = (resized_height - target_height) / 2
    right = (resized_width + target_width) / 2
    bottom = (resized_height + target_height) / 2
    cropped_image = resized_image.crop((left, top, right, bottom))
    return cropped_image
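

# Example (illustrative sketch): fit an arbitrarily sized photo to a square working
# resolution before segmentation/editing.
#
#   img = Image.open("photo.jpg").convert("RGB")
#   img = resize_and_center_crop(img, 1024, 1024)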