import random
from einops import rearrange
from diffusers.models import AutoencoderKL
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image
from models.sampling import prepare_modified
from models.util import load_clip, load_t5, load_flow_model
from transport import Sampler, create_transport
from data.imgproc import to_rgb_if_rgba


# Instruction phrasings that may lead the content prompt. When one is found at the
# start of the prompt it is stripped before the prompt is reused for upsampling a
# single image.
_CONTENT_INSTRUCTION_PREFIXES = (
    "The content of the last image in the final row is: ",
    "The last image of the last row depicts: ",
    "In the final row, the last image shows: ",
    "The last image in the bottom row illustrates: ",
    "The content of the bottom-right image is: ",
    "The final image in the last row portrays: ",
    "The last image of the final row displays: ",
    "In the last row, the final image captures: ",
    "The bottom-right corner image presents: ",
    "The content of the last image in the concluding row is: ",
    "In the last row, ",
    "The editing instruction in the last row is: ",
)


def center_crop(image, target_size):
    """Crop the centred ``target_size`` = (width, height) region from a PIL image.

    If ``target_size`` exceeds the image along an axis, PIL fills the
    out-of-bounds area with zeros (black).
    """
    width, height = image.size
    new_width, new_height = target_size
    left = (width - new_width) // 2
    top = (height - new_height) // 2
    return image.crop((left, top, left + new_width, top + new_height))


def resize_with_aspect_ratio(img, resolution, divisible=16, aspect_ratio=None):
    """Resize an image so that its area is close to ``resolution ** 2``.

    Both output dimensions are rounded down to a multiple of ``divisible``
    (but never below ``divisible``).

    Args:
        img: PIL Image, or torch.Tensor shaped (C, H, W) or (B, C, H, W).
        resolution: Target side length; the output area approximates
            ``resolution * resolution``.
        divisible: Both output dimensions are multiples of this number.
        aspect_ratio: Optional width / height ratio to use instead of the
            input's own ratio.

    Returns:
        The resized image, of the same type (and batch-ness) as ``img``.
        Tensors are resized with bilinear interpolation, PIL images with LANCZOS.
    """
    is_tensor = isinstance(img, torch.Tensor)
    if is_tensor:
        if img.dim() == 3:
            c, h, w = img.shape
            batch_dim = False
        else:
            b, c, h, w = img.shape
            batch_dim = True
    else:
        w, h = img.size

    if aspect_ratio is None:
        aspect_ratio = w / h
    target_area = resolution * resolution
    new_h = int((target_area / aspect_ratio) ** 0.5)
    new_w = int(new_h * aspect_ratio)

    new_w = max(new_w // divisible, 1) * divisible
    new_h = max(new_h // divisible, 1) * divisible

    if is_tensor:
        if batch_dim:
            return F.interpolate(img, size=(new_h, new_w), mode='bilinear', align_corners=False)
        return F.interpolate(
            img.unsqueeze(0), size=(new_h, new_w), mode='bilinear', align_corners=False
        ).squeeze(0)
    return img.resize((new_w, new_h), Image.LANCZOS)


def _resize_to_cover(image, target_size):
    """Resize a PIL image, keeping its aspect ratio, so it covers ``target_size``.

    The binding axis is matched exactly and the other one is >= the target, so a
    following ``center_crop(..., target_size)`` never has to pad.
    """
    target_w, target_h = target_size
    width, height = image.size
    # Compare target_w / width against target_h / height without float division.
    if target_w * height >= target_h * width:
        return image.resize((target_w, target_w * height // width))
    return image.resize((target_h * width // height, target_h))


class VisualClozeModel:
    """In-context image generation over a grid of images with a flow model.

    The last row of the grid is the query; its missing cells are generated,
    conditioned on the other rows and on the text prompts. Call
    ``set_grid_size`` before ``process_images``.
    """

    def __init__(
        self, model_path, model_name="flux-dev-fill-lora", max_length=512, lora_rank=256,
        atol=1e-6, rtol=1e-3, solver='euler', time_shifting_factor=1,
        resolution=384, precision='bf16'):
        """Load the flow model, VAE, text encoders and ODE sampler.

        Args:
            model_path: Path to a checkpoint loaded with ``torch.load`` and applied
                with ``load_state_dict(strict=False)``.
            model_name: Name passed to ``load_flow_model``.
            max_length: Maximum token length passed to ``load_t5``.
            lora_rank: LoRA rank passed to ``load_flow_model``.
            atol, rtol: ODE solver tolerances.
            solver: Sampling method name passed to ``Sampler.sample_ode``.
            time_shifting_factor: Passed to ``Sampler.sample_ode``.
            resolution: Side length used by ``resize_with_aspect_ratio`` for grid cells.
            precision: One of ``'bf16'``, ``'fp16'``, ``'fp32'``.
        """
        self.atol = atol
        self.rtol = rtol
        self.solver = solver
        self.time_shifting_factor = time_shifting_factor
        self.resolution = resolution
        self.precision = precision
        self.max_length = max_length
        self.lora_rank = lora_rank

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[self.precision]

        print("Initializing model...")
        self.model = load_flow_model(model_name, device=self.device, lora_rank=self.lora_rank)

        print("Initializing VAE...")
        self.ae = AutoencoderKL.from_pretrained(
            "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=self.dtype
        ).to(self.device)
        self.ae.requires_grad_(False)

        print("Initializing text encoders...")
        self.t5 = load_t5(self.device, max_length=self.max_length)
        self.clip = load_clip(self.device)

        self.model.eval().to(self.device, dtype=self.dtype)

        # map_location="cpu": works on CPU-only hosts for GPU-saved checkpoints and
        # avoids a second device copy; load_state_dict copies onto the model's device.
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        ckpt = torch.load(model_path, map_location="cpu")
        self.model.load_state_dict(ckpt, strict=False)
        del ckpt

        transport = create_transport(
            "Linear",
            "velocity",
            do_shift=True,
        )
        self.sampler = Sampler(transport)
        self.sample_fn = self.sampler.sample_ode(
            sampling_method=self.solver,
            num_steps=30,
            atol=self.atol,
            rtol=self.rtol,
            reverse=False,
            do_shift=True,
            time_shifting_factor=self.time_shifting_factor,
        )

        # RGBA -> RGB, to tensor, then scale [0, 1] -> [-1, 1].
        self.image_transform = transforms.Compose([
            transforms.Lambda(to_rgb_if_rgba),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

        self.grid_h = None
        self.grid_w = None

    def set_grid_size(self, h, w):
        """Set the grid layout (rows ``h``, columns ``w``) used by ``process_images``."""
        self.grid_h = h
        self.grid_w = w

    @torch.no_grad()
    def upsampling(self, image, target_size, cfg, upsampling_steps, upsampling_noise, generator, content_prompt):
        """Enlarge one generated cell and refine it with an SDEdit-style pass.

        Args:
            image: PIL image of the generated cell.
            target_size: (width, height) to resize to, or None for 1024x1024.
                Sizes above 1024*1024 pixels are scaled down to that area,
                keeping the aspect ratio. Both sides are then floored to
                multiples of 16.
            cfg: Guidance value passed to the model.
            upsampling_steps: Number of ODE sampling steps.
            upsampling_noise: Weight of the resized image's latent in the
                starting point; the noise gets ``1 - upsampling_noise``.
                Values >= 1.0 skip refinement and return the plain resize.
            generator: torch.Generator used to draw the starting noise.
            content_prompt: Content prompt; a known leading instruction phrase
                is stripped first.

        Returns:
            The refined (or just resized) PIL image.
        """
        # Strip only a leading phrase; str.replace would also delete matching
        # text inside the content description.
        for prefix in _CONTENT_INSTRUCTION_PREFIXES:
            if content_prompt.startswith(prefix):
                content_prompt = content_prompt[len(prefix):]

        max_area = 1024 * 1024
        if target_size is None:
            target_size = (1024, 1024)
        if target_size[0] * target_size[1] > max_area:
            aspect_ratio = target_size[0] / target_size[1]
            new_h = int((max_area / aspect_ratio) ** 0.5)
            new_w = int(new_h * aspect_ratio)
            target_size = (new_w, new_h)

        image = image.resize(((target_size[0] // 16) * 16, (target_size[1] // 16) * 16))
        if upsampling_noise >= 1.0:
            return image

        # Local sampler: do not overwrite the shared self.sample_fn.
        # NOTE(review): `strength` semantics are defined by transport.Sampler; assumed
        # to trim the schedule to match the partial-noise start below -- confirm.
        sample_fn = self.sampler.sample_ode(
            sampling_method=self.solver,
            num_steps=upsampling_steps,
            atol=self.atol,
            rtol=self.rtol,
            reverse=False,
            do_shift=False,
            time_shifting_factor=1.0,
            strength=upsampling_noise,
        )

        processed_image = self.image_transform(image).to(self.device, non_blocking=True)
        blank = torch.zeros_like(processed_image, device=self.device, dtype=self.dtype)
        # The whole image is marked as the region to (re)generate.
        mask = torch.full(
            (1, 1, processed_image.shape[1], processed_image.shape[2]),
            fill_value=1, device=self.device, dtype=self.dtype,
        )

        shift = self.ae.config.shift_factor
        scale = self.ae.config.scaling_factor
        latent = self.ae.encode(processed_image[None].to(self.ae.dtype)).latent_dist.sample()
        blank = self.ae.encode(blank[None].to(self.ae.dtype)).latent_dist.sample()
        latent = (latent - shift) * scale
        blank = (blank - shift) * scale
        latent_h, latent_w = latent.shape[2:]

        # Pixel mask -> per-latent-pixel channels (8x8), then 2x2 patch tokens.
        mask = rearrange(mask, "b c (h ph) (w pw) -> b (c ph pw) h w", ph=8, pw=8)
        mask = rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)

        latent = rearrange(latent.to(self.dtype), "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
        blank = rearrange(blank.to(self.dtype), "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
        img_cond = torch.cat((blank, mask), dim=-1)

        noise = torch.randn(
            [1, 16, latent_h, latent_w], device=self.device, generator=generator
        ).to(self.dtype)
        inp = prepare_modified(
            t5=self.t5, clip=self.clip, img=[[noise]], prompt=[content_prompt], proportion_empty_prompts=0.0
        )
        # SDEdit start point: blend of noise and the resized image's latent.
        inp["img"] = inp["img"] * (1 - upsampling_noise) + latent * upsampling_noise
        model_kwargs = dict(
            txt=inp["txt"], txt_ids=inp["txt_ids"], txt_mask=inp["txt_mask"],
            y=inp["vec"], img_ids=inp["img_ids"], img_mask=inp["img_mask"],
            cond=img_cond,
            guidance=torch.full((1,), cfg, device=self.device, dtype=self.dtype),
        )
        sample = sample_fn(inp["img"], self.model.forward, model_kwargs)[-1]

        sample = sample[:1]
        sample = rearrange(
            sample, "b (h w) (c ph pw) -> b c (h ph) (w pw)",
            ph=2, pw=2, h=latent_h // 2, w=latent_w // 2,
        )
        sample = self.ae.decode(sample / scale + shift)[0]
        sample = (sample + 1.0) / 2.0
        sample.clamp_(0.0, 1.0)
        return to_pil_image(sample[0].float())

    @torch.no_grad()  # inference only; keeps the ODE solve from building an autograd graph
    def process_images(
        self,
        images: list[list[Image.Image]],
        prompts: list[str],
        seed: int = 0,
        cfg: float = 30,
        steps: int = 30,
        upsampling_steps: int = 10,
        upsampling_noise: float = 0.4,
        is_upsampling: bool = True):
        """
        Processes a list of images based on the provided text prompts and settings, with optional upsampling to enhance image resolution or detail.

        Parameters:
            images (list[list[Image.Image]]): A collection of images arranged in a grid layout, where each row represents an in-context example or the current query.
                                            The current query should be placed in the last row. The target image may be None in the input, while all other images should be of the PIL Image type (Image.Image).
                                            The caller's lists are not modified.
            prompts (list[str]): A list containing three prompts: the layout prompt, task prompt, and content prompt, respectively.
            seed (int): A fixed integer seed to ensure reproducibility of random elements during processing. 0 selects a random seed.
            cfg (float): The guidance strength passed to the model, which controls the degree of influence over the generated results.
            steps (int): The number of sampling steps to be performed during processing.
            upsampling_steps (int): The number of denoising steps to apply when performing upsampling.
            upsampling_noise (float): The weight of the resized image in the SDEdit starting point used for upsampling; a higher value keeps more of the image and adds less noise.
                                      A value of 1 or more disables SDEdit, so only the PIL resize is applied.
            is_upsampling (bool, optional): A flag indicating whether upsampling should be applied using SDEdit.

        Returns:
            A list of PIL images, one per missing (None) cell of the query row, in column order, optionally upsampled based on the `is_upsampling` flag.

        Raises:
            ValueError: If the grid size was not set, or an in-context (non-query) row has a missing image.
        """
        grid_h, grid_w = self.grid_h, self.grid_w
        if grid_h is None or grid_w is None:
            raise ValueError("Grid size is not set; call set_grid_size(h, w) first.")

        if seed == 0:
            seed = random.randint(0, 2 ** 32 - 1)

        self.sample_fn = self.sampler.sample_ode(
            sampling_method=self.solver,
            num_steps=steps,
            atol=self.atol,
            rtol=self.rtol,
            reverse=False,
            do_shift=True,
            time_shifting_factor=self.time_shifting_factor,
        )

        # RGB copies in a new nested list, so the caller's `images` is left untouched.
        rows = [
            [img.convert("RGB") if img is not None else None for img in images[i]]
            for i in range(grid_h)
        ]

        resolution = self.resolution
        processed_images = []
        mask_position = []  # query row only: 1 = cell to generate, 0 = given image
        target_size = None
        upsampling_size = None  # original size of the first given query-row image
        for i in range(grid_h):
            row = rows[i]
            is_query_row = i == grid_h - 1

            # The first available image of the row fixes the cell size for the row.
            reference_size = None
            for j in range(grid_w):
                if row[j] is not None:
                    reference_size = resize_with_aspect_ratio(row[j], resolution, aspect_ratio=None).size
                    if is_query_row:
                        upsampling_size = row[j].size
                        target_size = reference_size
                    break

            for j in range(grid_w):
                img = row[j]
                if img is not None:
                    target = resize_with_aspect_ratio(img, resolution, aspect_ratio=None)
                    target = center_crop(_resize_to_cover(target, reference_size), reference_size)
                    processed_images.append(target)
                    if is_query_row:
                        mask_position.append(0)
                elif is_query_row:
                    # Placeholder for a cell the model should generate.
                    blank_size = reference_size if reference_size else (resolution, resolution)
                    processed_images.append(Image.new('RGB', blank_size, (0, 0, 0)))
                    mask_position.append(1)
                else:
                    raise ValueError('Please provide each image in the in-context example.')

        # Several cells to generate: give every cell the same width (multiple of 16).
        if len(mask_position) > 1 and sum(mask_position) > 1:
            base_w = resolution if target_size is None else target_size[0]
            new_w = base_w // 16 * 16
            for k, img in enumerate(processed_images):
                new_h = int(img.height * (new_w / img.width)) // 16 * 16
                processed_images[k] = img.resize((new_w, new_h))

        with torch.autocast("cuda", self.dtype):
            # One concatenated image and one mask per grid row.
            grid_image = []
            fill_mask = []
            for i in range(grid_h):
                row_images = [
                    self.image_transform(img) for img in processed_images[i * grid_w: (i + 1) * grid_w]
                ]
                cell_h, cell_w = row_images[0].shape[1], row_images[0].shape[2]
                fill_values = mask_position if i == grid_h - 1 else [0] * len(mask_position)
                row_masks = [
                    torch.full((1, 1, cell_h, cell_w), fill_value=m, device=self.device)
                    for m in fill_values
                ]
                grid_image.append(torch.cat(row_images, dim=2).to(self.device, non_blocking=True))
                fill_mask.append(torch.cat(row_masks, dim=3))

            shift = self.ae.config.shift_factor
            scale = self.ae.config.scaling_factor
            fill_cond = [
                self.ae.encode(img[None].to(self.ae.dtype)).latent_dist.sample()[0] for img in grid_image
            ]
            fill_cond = [(img - shift) * scale for img in fill_cond]

            # Pixel masks -> per-latent-pixel channels (8x8), then 2x2 patch tokens.
            fill_mask = [rearrange(m, "b c (h ph) (w pw) -> b (c ph pw) h w", ph=8, pw=8) for m in fill_mask]
            fill_mask = [rearrange(m, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) for m in fill_mask]

            fill_cond = [
                rearrange(img.to(self.dtype).unsqueeze(0), "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
                for img in fill_cond
            ]
            fill_cond = torch.cat(fill_cond, dim=1)
            fill_mask = torch.cat(fill_mask, dim=1)
            img_cond = torch.cat((fill_cond, fill_mask), dim=-1)

            # One noise tensor per row image; pixel sizes are kept to split tokens later.
            noise = []
            sliced_subimage = []
            rng = torch.Generator(device=self.device).manual_seed(int(seed))
            for sub_img in grid_image:
                h, w = sub_img.shape[-2:]
                sliced_subimage.append((h, w))
                noise.append(
                    torch.randn([1, 16, h // 8, w // 8], device=self.device, generator=rng).to(self.dtype)
                )

            inp = prepare_modified(
                t5=self.t5, clip=self.clip, img=[noise],
                prompt=[' '.join(prompts)], proportion_empty_prompts=0.0,
            )
            model_kwargs = dict(
                txt=inp["txt"], txt_ids=inp["txt_ids"], txt_mask=inp["txt_mask"],
                y=inp["vec"], img_ids=inp["img_ids"], img_mask=inp["img_mask"],
                cond=img_cond,
                guidance=torch.full((1,), cfg, device=self.device, dtype=self.dtype),
            )
            samples = self.sample_fn(inp["img"], self.model.forward, model_kwargs)[-1]

            # Only the query (last) row is returned, so only its tokens are decoded.
            # Each row contributes (h / 16) * (w / 16) tokens (8x VAE, 2x2 patches).
            samples = samples[:1]
            query_h, query_w = sliced_subimage[-1]
            start = sum(h * w // 256 for h, w in sliced_subimage[:-1])
            end = start + query_h * query_w // 256
            query_sample = rearrange(
                samples[:, start:end, :], "b (h w) (c ph pw) -> b c (h ph) (w pw)",
                ph=2, pw=2, h=query_h // 8 // 2, w=query_w // 8 // 2,
            )
            query_sample = self.ae.decode(query_sample / scale + shift)[0]
            query_sample = (query_sample + 1.0) / 2.0
            query_sample.clamp_(0.0, 1.0)
            query_image = to_pil_image(query_sample[0].float())

            torch.cuda.empty_cache()

            # Cut the generated (masked) cells out of the query row, column by column.
            ret = []
            ret_w, ret_h = query_image.width, query_image.height
            for col, masked in enumerate(mask_position):
                if not masked:
                    continue
                cropped = query_image.crop((col * ret_w // grid_w, 0, (col + 1) * ret_w // grid_w, ret_h))
                if is_upsampling:
                    cropped = self.upsampling(
                        cropped,
                        upsampling_size,
                        cfg,
                        upsampling_steps=upsampling_steps,
                        upsampling_noise=upsampling_noise,
                        generator=rng,
                        content_prompt=prompts[2],
                    )
                ret.append(cropped)
            return ret