# neochar/utils.py
import os
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from diffusers.utils.torch_utils import randn_tensor
from imwatermark import WatermarkEncoder
from PIL import Image as PILImage
from transformers import MT5Tokenizer, MT5EncoderModel
# Select device and dtype: fp16 on GPU, fp32 on CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Load the MT5 tokenizer and encoder (swap in a private model id plus auth token if needed)
tokenizer = MT5Tokenizer.from_pretrained("google/mt5-small")
encoder_model = MT5EncoderModel.from_pretrained("google/mt5-small", use_safetensors=True).to(device=device, dtype=torch_dtype)
encoder_model.eval()
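
# Shape note (illustrative, not executed on import): encoder_model maps a batch
# of prompts to last_hidden_state of shape (batch, seq_len, d_model); d_model is
# 512 for mt5-small, so the UNet's cross-attention width must match it.
#     with torch.no_grad():
#         enc = tokenizer(["example"], return_tensors="pt").to(device)
#         print(encoder_model(**enc).last_hidden_state.shape)  # (1, seq_len, 512)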

class QPipeline(DiffusionPipeline):
    def __init__(self, unet, scheduler):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    def add_watermark(self, img: PILImage.Image) -> PILImage.Image:
        # Resize up to 256x256: 128x128 is too small to carry the DWT-DCT watermark
        img = img.resize((256, 256), resample=PILImage.BICUBIC)
        watermark_str = os.getenv("WATERMARK_URL", "hf.co/lqume/new-hanzi")
        encoder = WatermarkEncoder()
        encoder.set_watermark('bytes', watermark_str.encode('utf-8'))
        # imwatermark operates on NumPy arrays; ensure 3-channel RGB first
        img_np = np.asarray(img.convert("RGB"))
        watermarked_np = encoder.encode(img_np, 'dwtDct')
        # Back to PIL for the pipeline output
        return PILImage.fromarray(watermarked_np)
    @torch.no_grad()
    def __call__(
        self,
        texts: List[str],
        batch_size: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: int = 20,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple[List[PILImage.Image]]]:
        # The effective batch size follows the prompt list, overriding the batch_size argument
        batch_size = len(texts)

        # Tokenize the prompts to a fixed length of 48 tokens
        tokenized = tokenizer(
            texts,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=48,
        )
        input_ids = tokenized["input_ids"].to(device=device, dtype=torch.long)
        attention_mask = tokenized["attention_mask"].to(device=device, dtype=torch.long)

        # Encode the prompts; the hidden states condition the UNet via cross-attention
        encoded = encoder_model(input_ids=input_ids, attention_mask=attention_mask)

        # Sample the initial noise tensor
        if isinstance(self.unet.config.sample_size, int):
            image_shape = (
                batch_size,
                self.unet.config.in_channels,
                self.unet.config.sample_size,
                self.unet.config.sample_size,
            )
        else:
            image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)
        image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=torch_dtype)

        # Denoising loop
        self.scheduler.set_timesteps(num_inference_steps)
        for timestep in self.progress_bar(self.scheduler.timesteps):
            noise_pred = self.unet(
                image,
                timestep,
                encoder_hidden_states=encoded.last_hidden_state,
                encoder_attention_mask=attention_mask.bool(),
                return_dict=False,
            )[0]
            image = self.scheduler.step(noise_pred, timestep, image, generator=generator, return_dict=False)[0]

        # Post-process: clamp to [0, 1], move to CPU, reorder to NHWC
        image = image.clamp(0, 1).cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)
            image = [self.add_watermark(img) for img in image]
        if not return_dict:
            return (image,)
        return ImagePipelineOutput(images=image)
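
# ---------------------------------------------------------------------------
# Illustrative usage sketch, not a confirmed entry point: the repo id and
# subfolder layout below are assumptions (inferred from the watermark URL),
# so point them at wherever the trained UNet/scheduler weights actually live.
if __name__ == "__main__":
    from diffusers import UNet2DConditionModel, DDPMScheduler
    from imwatermark import WatermarkDecoder

    # Hypothetical checkpoint location and layout
    unet = UNet2DConditionModel.from_pretrained(
        "lqume/new-hanzi", subfolder="unet"
    ).to(device=device, dtype=torch_dtype)
    scheduler = DDPMScheduler.from_pretrained("lqume/new-hanzi", subfolder="scheduler")
    pipe = QPipeline(unet=unet, scheduler=scheduler)

    # Generate one image per prompt and save the first
    images = pipe(["永"], num_inference_steps=20).images
    images[0].save("sample.png")

    # Verify the embedded watermark; WatermarkDecoder takes the length in bits
    # (the 21-byte default string -> 168 bits)
    decoder = WatermarkDecoder('bytes', 168)
    recovered = decoder.decode(np.asarray(images[0]), 'dwtDct')
    print(recovered.decode('utf-8'))  # expected: hf.co/lqume/new-hanzi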