xllsd-alpha0 / create-xllsd.py
ppbrown's picture
Upload create-xllsd.py
1ef66da verified
#!/bin/env python
"""
This script pulls in the various standard components for
an SD1.5 architecture model from DIFFERENT places.
It takes the original SD1.5 base, but then pulls in the improved VAE
from SDXL, and then an improved "Long CLIP" text encoder from elsewhere.
It then writes out a combined model in "diffusers" format.
That is more or less the contents of
https://huggingface.co/opendiffusionai/xllsd-alpha0
Feel free to use it for your own model creation experiments.
Of note to most people is that it pulls in the "float32" versions.
However, people with smaller hardware may wish to specify
torch_dtype=torch.float16
if they are just going to train in float16 or bf16 anyway
"""
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import StableDiffusionPipeline, AutoencoderKL
import torch
print("Loading main model")
# Load the SD1.5 diffusers pipeline in FP32. Every component that is not
# explicitly replaced below is kept from this pipeline.
pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    torch_dtype=torch.float32,
)

print("Loading LONG CLIP")
# Load the LongCLIP text encoder and its matching tokenizer from the same
# repository. The dtype is passed explicitly so the encoder is FP32 like the
# other components, rather than depending on the library's default dtype.
clip_path = "zer0int/LongCLIP-GmP-ViT-L-14"
new_text_encoder = CLIPTextModel.from_pretrained(
    clip_path,
    torch_dtype=torch.float32,
)
new_tokenizer = CLIPTokenizer.from_pretrained(clip_path)

print("Loading SDXL VAE")
new_vae = AutoencoderKL.from_pretrained(
    "stabilityai/sdxl-vae",
    torch_dtype=torch.float32,
)

# Swap the replacement text encoder, tokenizer and VAE into the pipeline.
pipe.text_encoder = new_text_encoder
pipe.tokenizer = new_tokenizer
pipe.vae = new_vae

# Move the pipeline to the GPU as a sanity check that everything loads.
# The check is skipped on machines without CUDA so that the combined model
# can still be written out from the CPU instead of aborting before the save.
print("Combining...")
if torch.cuda.is_available():
    pipe.to("cuda")
else:
    print("CUDA not available; skipping GPU load check")

###############################################################
# Save the updated pipeline in Diffusers format
# IF you are going to convert to a single .safetensors, set safe_serialization False
# But if you are going to use in place, then set it to True
###############################################################
outname = "XLLsd_df"
pipe.save_pretrained(outname, safe_serialization=True)
print(f"Replaced text encoder and saved pipeline to {outname}")