# LBM_relighting/utils.py
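"""Utility helpers for the LBM relighting demo.

Covers building the Latent Bridge Matching (LBM) model from an SDXL
backbone, BiRefNet-based foreground extraction, and resize/center-crop
preprocessing.
"""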

from typing import List, Optional

import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from PIL import Image
from torchvision import transforms

from lbm.models.embedders import (
    ConditionerWrapper,
    LatentsConcatEmbedder,
    LatentsConcatEmbedderConfig,
)
from lbm.models.lbm import LBMConfig, LBMModel
from lbm.models.unets import DiffusersUNet2DCondWrapper
from lbm.models.vae import AutoencoderKLDiffusers, AutoencoderKLDiffusersConfig


def get_model_from_config(
    backbone_signature: str = "stabilityai/stable-diffusion-xl-base-1.0",
    vae_num_channels: int = 4,
    unet_input_channels: int = 4,
    timestep_sampling: str = "log_normal",
    selected_timesteps: Optional[List[float]] = None,
    prob: Optional[List[float]] = None,
    conditioning_images_keys: List[str] = [],
    conditioning_masks_keys: List[str] = ["mask"],
    source_key: str = "source_image",
    target_key: str = "source_image_paste",
    bridge_noise_sigma: float = 0.0,
) -> LBMModel:
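    """Build the LBM relighting model from its components.

    Assembles the SDXL-backbone denoiser UNet, an optional latents-concat
    conditioner, a frozen SDXL VAE, and a flow-matching Euler scheduler
    into a single LBMModel, all cast to bfloat16.
    """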
    conditioners = []

    # Denoiser: an SDXL-style UNet (three stages, [320, 640, 1280] channels)
    # without text or added-time embeddings
    denoiser = DiffusersUNet2DCondWrapper(
        in_channels=unet_input_channels,  # latents (+ concatenated conditioning latents/masks)
        out_channels=vae_num_channels,
        center_input_sample=False,
        flip_sin_to_cos=True,
        freq_shift=0,
        down_block_types=[
            "DownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
        ],
        mid_block_type="UNetMidBlock2DCrossAttn",
        up_block_types=["CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"],
        only_cross_attention=False,
        block_out_channels=[320, 640, 1280],
        layers_per_block=2,
        downsample_padding=1,
        mid_block_scale_factor=1,
        dropout=0.0,
        act_fn="silu",
        norm_num_groups=32,
        norm_eps=1e-05,
        cross_attention_dim=[320, 640, 1280],
        transformer_layers_per_block=[1, 2, 10],
        reverse_transformer_layers_per_block=None,
        encoder_hid_dim=None,
        encoder_hid_dim_type=None,
        attention_head_dim=[5, 10, 20],
        num_attention_heads=None,
        dual_cross_attention=False,
        use_linear_projection=True,
        class_embed_type=None,
        addition_embed_type=None,
        addition_time_embed_dim=None,
        num_class_embeds=None,
        upcast_attention=None,
        resnet_time_scale_shift="default",
        resnet_skip_time_act=False,
        resnet_out_scale_factor=1.0,
        time_embedding_type="positional",
        time_embedding_dim=None,
        time_embedding_act_fn=None,
        timestep_post_act=None,
        time_cond_proj_dim=None,
        conv_in_kernel=3,
        conv_out_kernel=3,
        projection_class_embeddings_input_dim=None,
        attention_type="default",
        class_embeddings_concat=False,
        mid_block_only_cross_attention=None,
        cross_attention_norm=None,
        addition_embed_type_num_heads=64,
    ).to(torch.bfloat16)

    if conditioning_images_keys != [] or conditioning_masks_keys != []:
        latents_concat_embedder_config = LatentsConcatEmbedderConfig(
            image_keys=conditioning_images_keys,
            mask_keys=conditioning_masks_keys,
        )
        latent_concat_embedder = LatentsConcatEmbedder(latents_concat_embedder_config)
        latent_concat_embedder.freeze()
        conditioners.append(latent_concat_embedder)

    # Wrap conditioners and set to device
    conditioner = ConditionerWrapper(
        conditioners=conditioners,
    )

    ## VAE ##
    # Frozen SDXL VAE used to encode/decode latents
    vae_config = AutoencoderKLDiffusersConfig(
        version=backbone_signature,
        subfolder="vae",
        tiling_size=(128, 128),
    )
    vae = AutoencoderKLDiffusers(vae_config).to(torch.bfloat16)
    vae.freeze()

    ## Diffusion Model ##
    # Assemble the Latent Bridge Matching model and its sampling scheduler
    config = LBMConfig(
        source_key=source_key,
        target_key=target_key,
        timestep_sampling=timestep_sampling,
        selected_timesteps=selected_timesteps,
        prob=prob,
        bridge_noise_sigma=bridge_noise_sigma,
    )

    sampling_noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        backbone_signature,
        subfolder="scheduler",
    )

    model = LBMModel(
        config,
        denoiser=denoiser,
        sampling_noise_scheduler=sampling_noise_scheduler,
        vae=vae,
        conditioner=conditioner,
    ).to(torch.bfloat16)

    return model


def extract_object(birefnet, img):
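    """Segment the foreground object of `img` with BiRefNet.

    Returns the image composited over a neutral gray background together
    with the predicted soft mask, both at the original resolution.
    """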
    # Preprocessing expected by BiRefNet: fixed 1024x1024 input with
    # ImageNet mean/std normalization
    image_size = (1024, 1024)
    transform_image = transforms.Compose(
        [
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image = img
    input_images = transform_image(image).unsqueeze(0).cuda()

    # Predict the foreground matte and resize it back to the input resolution
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(image.size)

    # Composite the object over a neutral gray background using the mask
    image = Image.composite(image, Image.new("RGB", image.size, (127, 127, 127)), mask)
    return image, mask


def resize_and_center_crop(image, target_width, target_height):
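    """Resize `image` to cover (target_width, target_height), then center-crop to that size."""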
    # Scale so the image fully covers the target box, then crop the center
    original_width, original_height = image.size
    scale_factor = max(target_width / original_width, target_height / original_height)
    resized_width = int(round(original_width * scale_factor))
    resized_height = int(round(original_height * scale_factor))
    resized_image = image.resize((resized_width, resized_height), Image.LANCZOS)

    # Use integer coordinates so the crop is exactly target_width x target_height
    left = (resized_width - target_width) // 2
    top = (resized_height - target_height) // 2
    right = left + target_width
    bottom = top + target_height
    cropped_image = resized_image.crop((left, top, right, bottom))
    return cropped_image
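

if __name__ == "__main__":
    # Hypothetical usage sketch (not part of the original file): chains the
    # helpers above end to end. Assumes BiRefNet is loaded from the
    # "ZhengPeng7/BiRefNet" checkpoint via transformers and that a CUDA
    # device is available (extract_object moves its input to .cuda()).
    from transformers import AutoModelForImageSegmentation

    birefnet = (
        AutoModelForImageSegmentation.from_pretrained(
            "ZhengPeng7/BiRefNet", trust_remote_code=True
        )
        .cuda()
        .eval()
    )

    model = get_model_from_config()  # LBM model with SDXL defaults, in bfloat16

    img = Image.open("input.png").convert("RGB")  # placeholder input path
    img = resize_and_center_crop(img, 1024, 1024)
    foreground, mask = extract_object(birefnet, img)
    foreground.save("foreground.png")
    mask.save("mask.png")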