import random

import torch
import torch.nn.functional as F
from diffusers.models import AutoencoderKL
from einops import rearrange
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image

from data.imgproc import to_rgb_if_rgba
from models.sampling import prepare_modified
from models.util import load_clip, load_t5, load_flow_model
from transport import Sampler, create_transport


def center_crop(image, target_size):
    """Center-crop a PIL image to target_size = (width, height)."""
width, height = image.size
new_width, new_height = target_size
left = (width - new_width) // 2
top = (height - new_height) // 2
right = left + new_width
bottom = top + new_height
    return image.crop((left, top, right, bottom))


def resize_with_aspect_ratio(img, resolution, divisible=16, aspect_ratio=None):
"""Resize image while maintaining aspect ratio, ensuring area is close to resolution**2 and dimensions are divisible by 16
Args:
img: PIL Image or torch.Tensor (C,H,W)/(B,C,H,W)
resolution: target resolution
divisible: ensure output dimensions are divisible by this number
Returns:
Resized image of the same type as input
"""
# Check input type and get dimensions
is_tensor = isinstance(img, torch.Tensor)
if is_tensor:
if img.dim() == 3:
c, h, w = img.shape
batch_dim = False
else:
b, c, h, w = img.shape
batch_dim = True
else:
w, h = img.size
# Calculate new dimensions
if aspect_ratio is None:
aspect_ratio = w / h
target_area = resolution * resolution
new_h = int((target_area / aspect_ratio) ** 0.5)
new_w = int(new_h * aspect_ratio)
# Ensure divisible by divisible
new_w = max(new_w // divisible, 1) * divisible
new_h = max(new_h // divisible, 1) * divisible
# Adjust size based on input type
if is_tensor:
# Use torch interpolation method
mode = 'bilinear'
align_corners = False
if batch_dim:
return F.interpolate(img, size=(new_h, new_w),
mode=mode, align_corners=align_corners)
else:
return F.interpolate(img.unsqueeze(0), size=(new_h, new_w),
mode=mode, align_corners=align_corners).squeeze(0)
else:
# Use PIL LANCZOS resampling
        return img.resize((new_w, new_h), Image.LANCZOS)


class VisualClozeModel:
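    """Wraps the VisualCloze pipeline: a FLUX-based fill model, its VAE, and the T5/CLIP text
    encoders. process_images() runs in-context grid generation; upsampling() applies optional
    SDEdit-based refinement to the generated target image."""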
def __init__(
self, model_path, model_name="flux-dev-fill-lora", max_length=512, lora_rank=256,
atol=1e-6, rtol=1e-3, solver='euler', time_shifting_factor=1,
resolution=384, precision='bf16'):
self.atol = atol
self.rtol = rtol
self.solver = solver
self.time_shifting_factor = time_shifting_factor
self.resolution = resolution
self.precision = precision
self.max_length = max_length
self.lora_rank = lora_rank
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[self.precision]
# Initialize model
print("Initializing model...")
self.model = load_flow_model(model_name, device=self.device, lora_rank=self.lora_rank)
# Initialize VAE
print("Initializing VAE...")
        self.ae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=self.dtype).to(self.device)
self.ae.requires_grad_(False)
# Initialize text encoders
print("Initializing text encoders...")
self.t5 = load_t5(self.device, max_length=self.max_length)
self.clip = load_clip(self.device)
self.model.eval().to(self.device, dtype=self.dtype)
# Load model weights
ckpt = torch.load(model_path)
self.model.load_state_dict(ckpt, strict=False)
del ckpt
# Initialize sampler
transport = create_transport(
"Linear",
"velocity",
do_shift=True,
)
self.sampler = Sampler(transport)
self.sample_fn = self.sampler.sample_ode(
sampling_method=self.solver,
num_steps=30,
atol=self.atol,
rtol=self.rtol,
reverse=False,
do_shift=True,
time_shifting_factor=self.time_shifting_factor,
)
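        # Default ODE sampler: 30 Euler steps on a time-shifted schedule. process_images() and
        # upsampling() rebuild sample_fn with the step count requested for each call.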
# Image transformation
self.image_transform = transforms.Compose([
transforms.Lambda(lambda img: to_rgb_if_rgba(img)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
])
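        # ToTensor followed by Normalize maps pixel values to [-1, 1], the range the VAE expects.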
self.grid_h = None
self.grid_w = None
def set_grid_size(self, h, w):
"""Set grid size"""
self.grid_h = h
        self.grid_w = w

    @torch.no_grad()
    def upsampling(self, image, target_size, cfg, upsampling_steps, upsampling_noise, generator, content_prompt):
        """Refine a generated image at (up to) target_size with SDEdit; falls back to a plain PIL resize when upsampling_noise >= 1."""
content_instruction = [
"The content of the last image in the final row is: ",
"The last image of the last row depicts: ",
"In the final row, the last image shows: ",
"The last image in the bottom row illustrates: ",
"The content of the bottom-right image is: ",
"The final image in the last row portrays: ",
"The last image of the final row displays: ",
"In the last row, the final image captures: ",
"The bottom-right corner image presents: ",
"The content of the last image in the concluding row is: ",
"In the last row, ",
"The editing instruction in the last row is: ",
]
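        # Strip any grid-layout preamble from the content prompt so the upsampling pass is
        # conditioned on the plain image description.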
for c in content_instruction:
if content_prompt.startswith(c):
content_prompt = content_prompt.replace(c, '')
if target_size is None:
aspect_ratio = 1
target_area = 1024 * 1024
new_h = int((target_area / aspect_ratio) ** 0.5)
new_w = int(new_h * aspect_ratio)
target_size = (new_w, new_h)
if target_size[0] * target_size[1] > 1024 * 1024:
aspect_ratio = target_size[0] / target_size[1]
target_area = 1024 * 1024
new_h = int((target_area / aspect_ratio) ** 0.5)
new_w = int(new_h * aspect_ratio)
target_size = (new_w, new_h)
image = image.resize(((target_size[0] // 16) * 16, (target_size[1] // 16) * 16))
if upsampling_noise >= 1.0:
return image
self.sample_fn = self.sampler.sample_ode(
sampling_method=self.solver,
num_steps=upsampling_steps,
atol=self.atol,
rtol=self.rtol,
reverse=False,
do_shift=False,
time_shifting_factor=1.0,
strength=upsampling_noise
)
processed_image = self.image_transform(image)
processed_image = processed_image.to(self.device, non_blocking=True)
blank = torch.zeros_like(processed_image, device=self.device, dtype=self.dtype)
mask = torch.full((1, 1, processed_image.shape[1], processed_image.shape[2]), fill_value=1, device=self.device, dtype=self.dtype)
with torch.no_grad():
latent = self.ae.encode(processed_image[None].to(self.ae.dtype)).latent_dist.sample()
blank = self.ae.encode(blank[None].to(self.ae.dtype)).latent_dist.sample()
latent = (latent - self.ae.config.shift_factor) * self.ae.config.scaling_factor
blank = (blank - self.ae.config.shift_factor) * self.ae.config.scaling_factor
latent_h, latent_w = latent.shape[2:]
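        # Pack the latent and the mask into FLUX-style 2x2 patch tokens; the mask is first moved
        # from pixel space to the 8x-downsampled latent grid via the ph=8/pw=8 rearrange.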
mask = rearrange(mask, "b c (h ph) (w pw) -> b (c ph pw) h w", ph=8, pw=8)
mask = rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
latent = latent.to(self.dtype)
blank = blank.to(self.dtype)
latent = rearrange(latent, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
blank = rearrange(blank, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
img_cond = torch.cat((blank, mask), dim=-1)
# Generate noise
noise = torch.randn([1, 16, latent_h, latent_w], device=self.device, generator=generator).to(self.dtype)
x = [[noise]]
inp = prepare_modified(t5=self.t5, clip=self.clip, img=x, prompt=[content_prompt], proportion_empty_prompts=0.0)
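        # SDEdit: start sampling from a blend of fresh noise and the encoded input latent,
        # weighted by upsampling_noise (values >= 1.0 were short-circuited above).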
inp["img"] = inp["img"] * (1 - upsampling_noise) + latent * upsampling_noise
model_kwargs = dict(
txt=inp["txt"],
txt_ids=inp["txt_ids"],
txt_mask=inp["txt_mask"],
y=inp["vec"],
img_ids=inp["img_ids"],
img_mask=inp["img_mask"],
cond=img_cond,
guidance=torch.full((1,), cfg, device=self.device, dtype=self.dtype),
)
sample = self.sample_fn(
inp["img"], self.model.forward, model_kwargs
)[-1]
sample = sample[:1]
sample = rearrange(sample, "b (h w) (c ph pw) -> b c (h ph) (w pw)", ph=2, pw=2, h=latent_h // 2, w=latent_w // 2)
sample = self.ae.decode(sample / self.ae.config.scaling_factor + self.ae.config.shift_factor)[0]
sample = (sample + 1.0) / 2.0
sample.clamp_(0.0, 1.0)
sample = sample[0]
output_image = to_pil_image(sample.float())
        return output_image

    def process_images(
            self, images: list[list[Image.Image]],
            prompts: list[str],
            seed: int = 0,
            cfg: int = 30,
            steps: int = 30,
            upsampling_steps: int = 10,
            upsampling_noise: float = 0.4,
            is_upsampling: bool = True):
"""
Processes a list of images based on the provided text prompts and settings, with optional upsampling to enhance image resolution or detail.
Parameters:
images (list[list[Image.Image]]): A collection of images arranged in a grid layout, where each row represents an in-context example or the current query.
The current query should be placed in the last row. The target image may be None in the input, while all other images should be of the PIL Image type (Image.Image).
prompts (list[str]): A list containing three prompts: the layout prompt, task prompt, and content prompt, respectively.
seed (int): A fixed integer seed to ensure reproducibility of random elements during processing.
cfg (int): The strength of Classifier-Free Diffusion Guidance, which controls the degree of influence over the generated results.
steps (int): The number of sampling steps to be performed during processing.
upsampling_steps (int): The number of denoising steps to apply when performing upsampling.
upsampling_noise (float): The noise level used as a starting point when upsampling with SDEdit. A higher value reduces noise, and setting it to 1 disables SDEdit, causing the PIL resize function to be used instead.
is_upsampling (bool, optional): A flag indicating whether upsampling should be applied using SDEdit.
Returns:
Processed images resulting from the algorithm, with optional upsampling applied based on the `is_upsampling` flag.
"""
if seed == 0:
seed = random.randint(0, 2 ** 32 - 1)
self.sample_fn = self.sampler.sample_ode(
sampling_method=self.solver,
num_steps=steps,
atol=self.atol,
rtol=self.rtol,
reverse=False,
do_shift=True,
time_shifting_factor=self.time_shifting_factor,
)
# Use class grid size
grid_h, grid_w = self.grid_h, self.grid_w
# Ensure all images are RGB mode or None
for i in range(0, grid_h):
images[i] = [img.convert("RGB") if img is not None else None for img in images[i]]
# Adjust all image sizes
resolution = self.resolution
processed_images = []
mask_position = []
target_size = None
upsampling_size = None
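        # For every grid row, resize each image to the row's reference size (taken from its first
        # non-empty image); a missing image in the query row becomes a black placeholder and its
        # position is recorded in mask_position.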
for i in range(grid_h):
# Find the size of the first non-empty image in this row
reference_size = None
for j in range(0, grid_w):
if images[i][j] is not None:
if i == grid_h - 1 and upsampling_size is None:
upsampling_size = images[i][j].size
resized = resize_with_aspect_ratio(images[i][j], resolution, aspect_ratio=None)
reference_size = resized.size
if i == grid_h - 1 and target_size is None:
target_size = reference_size
break
# Process all images in this row
for j in range(0, grid_w):
if images[i][j] is not None:
target = resize_with_aspect_ratio(images[i][j], resolution, aspect_ratio=None)
if target.width <= target.height:
target = target.resize((reference_size[0], int(reference_size[0] / target.width * target.height)))
target = center_crop(target, reference_size)
elif target.width > target.height:
target = target.resize((int(reference_size[1] / target.height * target.width), reference_size[1]))
target = center_crop(target, reference_size)
processed_images.append(target)
if i == grid_h - 1:
mask_position.append(0)
else:
# If this row has a reference size, use it; otherwise use default size
if reference_size:
blank = Image.new('RGB', reference_size, (0, 0, 0))
else:
blank = Image.new('RGB', (resolution, resolution), (0, 0, 0))
processed_images.append(blank)
if i == grid_h - 1:
mask_position.append(1)
else:
raise ValueError('Please provide each image in the in-context example.')
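        # If more than one target position is masked, unify image widths across rows so the
        # horizontally concatenated rows line up, snapping every size to a multiple of 16.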
if len(mask_position) > 1 and sum(mask_position) > 1:
if target_size is None:
new_w = 384
else:
new_w = target_size[0]
for i in range(len(processed_images)):
if processed_images[i] is not None:
new_h = int(processed_images[i].height * (new_w / processed_images[i].width))
new_w = int(new_w / 16) * 16
new_h = int(new_h / 16) * 16
processed_images[i] = processed_images[i].resize((new_w, new_h))
# Build grid image and mask
with torch.autocast("cuda", self.dtype):
grid_image = []
fill_mask = []
for i in range(grid_h):
row_images = [self.image_transform(img) for img in processed_images[i * grid_w: (i + 1) * grid_w]]
if i == grid_h - 1:
row_masks = [torch.full((1, 1, row_images[0].shape[1], row_images[0].shape[2]), fill_value=m, device=self.device) for m in mask_position]
else:
row_masks = [torch.full((1, 1, row_images[0].shape[1], row_images[0].shape[2]), fill_value=0, device=self.device) for m in mask_position]
grid_image.append(torch.cat(row_images, dim=2).to(self.device, non_blocking=True))
fill_mask.append(torch.cat(row_masks, dim=3))
# Encode condition image
with torch.no_grad():
fill_cond = [self.ae.encode(img[None].to(self.ae.dtype)).latent_dist.sample()[0] for img in grid_image]
fill_cond = [(img - self.ae.config.shift_factor) * self.ae.config.scaling_factor for img in fill_cond]
# Rearrange mask
fill_mask = [rearrange(mask, "b c (h ph) (w pw) -> b (c ph pw) h w", ph=8, pw=8) for mask in fill_mask]
fill_mask = [rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) for mask in fill_mask]
fill_cond = [img.to(self.dtype) for img in fill_cond]
fill_cond = [rearrange(img.unsqueeze(0), "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) for img in fill_cond]
fill_cond = torch.cat(fill_cond, dim=1)
fill_mask = torch.cat(fill_mask, dim=1)
img_cond = torch.cat((fill_cond, fill_mask), dim=-1)
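            # Each packed token now carries the masked condition latent (16 channels x 2x2 patch)
            # concatenated with the flattened mask patch, i.e. the fill-style conditioning the
            # model consumes via the `cond` argument.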
# Generate sample
noise = []
sliced_subimage = []
rng = torch.Generator(device=self.device).manual_seed(int(seed))
for sub_img in grid_image:
h, w = sub_img.shape[-2:]
sliced_subimage.append((h, w))
                latent_h, latent_w = h // 8, w // 8
noise.append(torch.randn([1, 16, latent_h, latent_w], device=self.device, generator=rng).to(self.dtype))
x = [noise]
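            # prepare_modified packs the per-row noise into the image token sequence
            # (img/img_ids/img_mask) and encodes the joined layout/task/content prompt
            # with T5 and CLIP (txt, txt_ids, txt_mask, vec).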
with torch.no_grad():
inp = prepare_modified(t5=self.t5, clip=self.clip, img=x, prompt=[' '.join(prompts)], proportion_empty_prompts=0.0)
model_kwargs = dict(
txt=inp["txt"],
txt_ids=inp["txt_ids"],
txt_mask=inp["txt_mask"],
y=inp["vec"],
img_ids=inp["img_ids"],
img_mask=inp["img_mask"],
cond=img_cond,
guidance=torch.full((1,), cfg, device=self.device, dtype=self.dtype),
)
samples = self.sample_fn(
inp["img"], self.model.forward, model_kwargs
)[-1]
# Get query row
with torch.no_grad():
samples = samples[:1]
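            # The sampled token sequence concatenates every grid row; slice it back row by row
            # (each row contributes h/16 * w/16 tokens), unpack the 2x2 patches, and decode with the VAE.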
row_samples = []
start = 0
for size in sliced_subimage:
end = start + (size[0] * size[1] // 256)
latent_h = size[0] // 8
latent_w = size[1] // 8
row_sample = samples[:, start:end, :]
row_sample = rearrange(row_sample, "b (h w) (c ph pw) -> b c (h ph) (w pw)", ph=2, pw=2, h=latent_h//2, w=latent_w//2)
row_sample = self.ae.decode(row_sample / self.ae.config.scaling_factor + self.ae.config.shift_factor)[0]
row_sample = (row_sample + 1.0) / 2.0
row_sample.clamp_(0.0, 1.0)
row_samples.append(row_sample[0])
start = end
# Convert all samples to PIL images
output_images = []
for row_sample in row_samples:
output_image = to_pil_image(row_sample.float())
output_images.append(output_image)
torch.cuda.empty_cache()
ret = []
ret_w = output_images[-1].width
ret_h = output_images[-1].height
row_start = (grid_h - 1) * grid_w
row_end = grid_h * grid_w
for i in range(row_start, row_end):
            # Output only the images at masked (target) positions in the query row,
            # optionally refining them with SDEdit-based upsampling.
            if mask_position[i - row_start]:
                left = (i - row_start) * ret_w // self.grid_w
                right = (i - row_start + 1) * ret_w // self.grid_w
                cropped = output_images[-1].crop((left, 0, right, ret_h))
                if is_upsampling:
                    cropped = self.upsampling(
                        cropped,
                        upsampling_size,
                        cfg,
                        upsampling_steps=upsampling_steps,
                        upsampling_noise=upsampling_noise,
                        generator=rng,
                        content_prompt=prompts[2])
                ret.append(cropped)
        return ret
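# Illustrative usage sketch (assumption: the checkpoint path, image variables, and prompt strings
# below are placeholders, not values shipped with this repository). set_grid_size() must match the
# rows/columns of the grid passed to process_images(), and the target slot in the query row is None.
#
#     model = VisualClozeModel(model_path="path/to/visualcloze_checkpoint.pth")
#     model.set_grid_size(2, 3)  # one in-context example row + the query row
#     results = model.process_images(
#         images=[[ref_src, ref_aux, ref_target],   # in-context example (all PIL images)
#                 [query_src, query_aux, None]],    # query row; None marks the image to generate
#         prompts=[layout_prompt, task_prompt, content_prompt],
#     )
#     results[0].save("visualcloze_output.png")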