Duskfallcrew committed on
Commit
5840753
·
verified ·
1 Parent(s): 96bf0f5

Update app.py

Browse files

Key Changes and Explanations:

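load_sdxl_checkpoint (Corrected): This function now extracts the state dictionaries for both text encoders (text_encoder1_state and text_encoder2_state), the VAE (vae_state), and the UNet (unet_state) by key prefix: condition_model.model.text_encoder., condition_model.model.text_encoder_2., first_stage_model., and model.diffusion_model., respectively. It still assumes the Illustrious-xl checkpoint stores its components under exactly these prefixes. Keys that match none of them are skipped, so a prefix mismatch produces empty state dictionaries rather than an error (see "Key Prefixes" below).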

build_diffusers_model (Corrected):

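Loads the configurations from the reference model (or the default SDXL base, stabilityai/stable-diffusion-xl-base-1.0) for all components: CLIPTextConfig.from_pretrained for text_encoder and text_encoder_2, and the .config attribute of AutoencoderKL.from_pretrained and UNet2DConditionModel.from_pretrained for vae and unet. Note that the last two load the full reference VAE and UNet just to read their configurations.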

Creates fresh model instances from these loaded configurations: CLIPTextModel for text_encoder1, CLIPTextModelWithProjection (instead of CLIPTextModel) for text_encoder2, AutoencoderKL for the VAE, and UNet2DConditionModel for the UNet. This is crucial for getting the correct model architecture.

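Loads the extracted state dictionaries into the corresponding model instances using strict=False. This tolerates key mismatches or extra keys in the Illustrious-xl checkpoint instead of raising an error. Be aware that unexpected keys are silently ignored and any parameters missing from the checkpoint keep their freshly initialized values, so a successful load does not by itself prove the weights were transferred.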

Converts all four components to float16 and moves them to the CPU.

convert_and_save_sdxl_to_diffusers: Remains mostly the same, but now correctly uses the two text encoders.

Other Functions: The rest of the code (downloading, uploading, Gradio interface) remains largely unchanged.

Testing and Further Steps

Test Thoroughly: Test this revised code with the Illustrious-xl model. It should now load the checkpoint correctly and create a Diffusers pipeline.

Verify Functionality: After converting, test the generated Diffusers model. Generate some images and compare them to the expected output from the Illustrious-xl model. This is crucial to ensure the conversion was successful and the model is working as intended.

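Key Prefixes (If Still Errors): If you still encounter errors, or the conversion completes but the generated images are wrong, it's possible that the Illustrious-xl model uses different key prefixes than the ones load_sdxl_checkpoint looks for. Because the state dictionaries are loaded with strict=False, a prefix mismatch may not raise an error at all. In this case, you'll need to inspect the checkpoint's state dictionary keys directly (using a simplified loading script) to determine the correct prefixes and adjust load_sdxl_checkpoint accordingly.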

Files changed (1) hide show
  1. app.py +41 -69
app.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  import gradio as gr
3
  import torch
4
  from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, AutoencoderKL
5
- from transformers import CLIPTextModel, CLIPTextConfig
6
  from safetensors.torch import load_file
7
  from collections import OrderedDict
8
  import re
@@ -67,42 +67,35 @@ def create_model_repo(api, user, orgs_name, model_name, make_private=False):
67
  try:
68
  api.create_repo(repo_id=repo_id, repo_type="model", private=make_private)
69
  print(f"Model repo '{repo_id}' created.")
70
- except HfHubHTTPError: # Corrected the exception name
71
  print(f"Model repo '{repo_id}' already exists.")
72
  return repo_id
73
 
74
 
75
  # ---------------------- MODEL LOADING AND CONVERSION ----------------------
76
  def download_model(model_path_or_url):
77
- """Downloads a model, handling URLs, HF repos, and local paths, caching appropriately."""
78
  try:
79
- # 1. Check if it's a valid Hugging Face repo ID (and potentially a file within)
80
  try:
81
  validate_repo_id(model_path_or_url)
82
- # It's a valid repo ID; use hf_hub_download (it handles caching)
83
  local_path = hf_hub_download(repo_id=model_path_or_url)
84
  return local_path
85
  except HFValidationError:
86
- pass # Not a simple repo ID. Might be repo ID + filename, or a URL.
87
 
88
  # 2. Check if it's a URL
89
- if model_path_or_url.startswith("http://") or model_path_or_url.startswith(
90
- "https://"
91
- ):
92
- # It's a URL : download and put into HF cache
93
-
94
  response = requests.get(model_path_or_url, stream=True)
95
- response.raise_for_status() # Raise HTTPError for bad requests (4xx or 5xx)
96
 
97
- # Get filename from URL, or use a hash if we can't determine it
98
  parsed_url = urlparse(model_path_or_url)
99
  filename = os.path.basename(unquote(parsed_url.path))
100
  if not filename:
101
  filename = hashlib.sha256(model_path_or_url.encode()).hexdigest()
102
 
103
- # Construct the cache path (using HF_HUB_CACHE + "downloads")
104
  cache_dir = os.path.join(HUGGINGFACE_HUB_CACHE, "downloads")
105
- os.makedirs(cache_dir, exist_ok=True) # Ensure cache directory exists
106
  local_path = os.path.join(cache_dir, filename)
107
 
108
  with open(local_path, "wb") as f:
@@ -125,7 +118,6 @@ def download_model(model_path_or_url):
125
  return local_path
126
  else:
127
  raise ValueError("Invalid input format.")
128
-
129
  except HFValidationError:
130
  raise ValueError(f"Invalid model path or URL: {model_path_or_url}")
131
 
@@ -133,15 +125,14 @@ def download_model(model_path_or_url):
133
  raise ValueError(f"Error downloading or accessing model: {e}")
134
 
135
 
 
136
  def load_sdxl_checkpoint(checkpoint_path):
137
- """Loads an SDXL checkpoint (.ckpt or .safetensors) and returns components."""
138
 
139
  if checkpoint_path.endswith(".safetensors"):
140
- state_dict = load_file(checkpoint_path, device="cpu") # Load to CPU
141
  elif checkpoint_path.endswith(".ckpt"):
142
- state_dict = torch.load(checkpoint_path, map_location="cpu")[
143
- "state_dict"
144
- ] # Load to CPU, access ["state_dict"]
145
  else:
146
  raise ValueError("Unsupported checkpoint format. Must be .safetensors or .ckpt")
147
 
@@ -152,82 +143,62 @@ def load_sdxl_checkpoint(checkpoint_path):
152
 
153
  for key, value in state_dict.items():
154
  if key.startswith("first_stage_model."): # VAE
155
- vae_state[key.replace("first_stage_model.", "")] = value.to(
156
- torch.float16
157
- ) # FP16 conversion
158
- elif key.startswith("condition_model.model.text_encoder."): # Text Encoder 1
159
- text_encoder1_state[
160
- key.replace("condition_model.model.text_encoder.", "")
161
- ] = value.to(
162
- torch.float16
163
- ) # FP16
164
- elif key.startswith(
165
- "condition_model.model.text_encoder_2."
166
- ): # Text Encoder 2
167
- text_encoder2_state[
168
- key.replace("condition_model.model.text_encoder_2.", "")
169
- ] = value.to(
170
- torch.float16
171
- ) # FP16
172
  elif key.startswith("model.diffusion_model."): # UNet
173
- unet_state[key.replace("model.diffusion_model.", "")] = value.to(
174
- torch.float16
175
- ) # FP16
176
 
177
  return text_encoder1_state, text_encoder2_state, vae_state, unet_state
178
 
179
 
 
180
  def build_diffusers_model(
181
- text_encoder1_state,
182
- text_encoder2_state,
183
- vae_state,
184
- unet_state,
185
- reference_model_path=None,
186
  ):
187
- """Builds the Diffusers pipeline components from the loaded state dicts."""
188
 
189
- # Default to SDXL base 1.0 if no reference model is provided
190
  if not reference_model_path:
191
  reference_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
192
 
193
- # 1. Text Encoders
194
  config_text_encoder1 = CLIPTextConfig.from_pretrained(
195
  reference_model_path, subfolder="text_encoder"
196
  )
197
  config_text_encoder2 = CLIPTextConfig.from_pretrained(
198
- reference_model_path, subfolder="text_encoder_2"
199
  )
 
 
200
 
 
201
  text_encoder1 = CLIPTextModel(config_text_encoder1)
202
- text_encoder2 = CLIPTextModel(config_text_encoder2)
203
- text_encoder1.load_state_dict(text_encoder1_state)
204
- text_encoder2.load_state_dict(text_encoder2_state)
205
- text_encoder1.to(torch.float16).to("cpu") # Ensure fp16 and CPU
206
- text_encoder2.to(torch.float16).to("cpu")
207
 
208
- # 2. VAE
209
- vae = AutoencoderKL.from_pretrained(reference_model_path, subfolder="vae")
210
- vae.load_state_dict(vae_state)
211
- vae.to(torch.float16).to("cpu")
 
212
 
213
- # 3. UNet
214
- unet = UNet2DConditionModel.from_pretrained(reference_model_path, subfolder="unet")
215
- unet.load_state_dict(unet_state)
216
  unet.to(torch.float16).to("cpu")
217
 
 
218
  return text_encoder1, text_encoder2, vae, unet
219
 
220
 
221
  def convert_and_save_sdxl_to_diffusers(
222
  checkpoint_path_or_url, output_path, reference_model_path
223
  ):
224
- """Converts an SDXL checkpoint to Diffusers format and saves it.
225
 
226
- Args:
227
- checkpoint_path_or_url: The path/URL/repo ID of the checkpoint.
228
- """
229
-
230
- # Download the model if necessary (handles URLs, repo IDs, and local paths)
231
  checkpoint_path = download_model(checkpoint_path_or_url)
232
 
233
  text_encoder1_state, text_encoder2_state, vae_state, unet_state = (
@@ -255,6 +226,7 @@ def convert_and_save_sdxl_to_diffusers(
255
  print(f"Model saved as Diffusers format: {output_path}")
256
 
257
 
 
258
  # ---------------------- UPLOAD FUNCTION ----------------------
259
  def upload_to_huggingface(model_path, hf_token, orgs_name, model_name, make_private):
260
  """Uploads a model to the Hugging Face Hub."""
@@ -362,7 +334,7 @@ with gr.Blocks(css=css) as demo:
362
  with gr.Column():
363
  output = gr.Markdown() #Output is in its own column
364
 
365
- convert_button.click( #CORRECT AREA
366
  fn=main,
367
  inputs=[
368
  model_to_load,
 
2
  import gradio as gr
3
  import torch
4
  from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, AutoencoderKL
5
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTextConfig, CLIPTokenizer
6
  from safetensors.torch import load_file
7
  from collections import OrderedDict
8
  import re
 
67
  try:
68
  api.create_repo(repo_id=repo_id, repo_type="model", private=make_private)
69
  print(f"Model repo '{repo_id}' created.")
70
+ except HfHubHTTPError:
71
  print(f"Model repo '{repo_id}' already exists.")
72
  return repo_id
73
 
74
 
75
  # ---------------------- MODEL LOADING AND CONVERSION ----------------------
76
  def download_model(model_path_or_url):
77
+ """Downloads a model, handling URLs, HF repos, and local paths."""
78
  try:
79
+ # 1. Check if it's a valid Hugging Face repo ID
80
  try:
81
  validate_repo_id(model_path_or_url)
 
82
  local_path = hf_hub_download(repo_id=model_path_or_url)
83
  return local_path
84
  except HFValidationError:
85
+ pass
86
 
87
  # 2. Check if it's a URL
88
+ if model_path_or_url.startswith("http://") or model_path_or_url.startswith("https://"):
 
 
 
 
89
  response = requests.get(model_path_or_url, stream=True)
90
+ response.raise_for_status()
91
 
 
92
  parsed_url = urlparse(model_path_or_url)
93
  filename = os.path.basename(unquote(parsed_url.path))
94
  if not filename:
95
  filename = hashlib.sha256(model_path_or_url.encode()).hexdigest()
96
 
 
97
  cache_dir = os.path.join(HUGGINGFACE_HUB_CACHE, "downloads")
98
+ os.makedirs(cache_dir, exist_ok=True)
99
  local_path = os.path.join(cache_dir, filename)
100
 
101
  with open(local_path, "wb") as f:
 
118
  return local_path
119
  else:
120
  raise ValueError("Invalid input format.")
 
121
  except HFValidationError:
122
  raise ValueError(f"Invalid model path or URL: {model_path_or_url}")
123
 
 
125
  raise ValueError(f"Error downloading or accessing model: {e}")
126
 
127
 
128
+
129
  def load_sdxl_checkpoint(checkpoint_path):
130
+ """Loads checkpoint and extracts state dicts, handling Illustrious-xl."""
131
 
132
  if checkpoint_path.endswith(".safetensors"):
133
+ state_dict = load_file(checkpoint_path, device="cpu")
134
  elif checkpoint_path.endswith(".ckpt"):
135
+ state_dict = torch.load(checkpoint_path, map_location="cpu")["state_dict"]
 
 
136
  else:
137
  raise ValueError("Unsupported checkpoint format. Must be .safetensors or .ckpt")
138
 
 
143
 
144
  for key, value in state_dict.items():
145
  if key.startswith("first_stage_model."): # VAE
146
+ vae_state[key.replace("first_stage_model.", "")] = value.to(torch.float16)
147
+ elif key.startswith("condition_model.model.text_encoder."): # First Text Encoder
148
+ text_encoder1_state[key.replace("condition_model.model.text_encoder.", "")] = value.to(torch.float16)
149
+ elif key.startswith("condition_model.model.text_encoder_2."): # Second Text Encoder
150
+ text_encoder2_state[key.replace("condition_model.model.text_encoder_2.", "")] = value.to(torch.float16)
 
 
 
 
 
 
 
 
 
 
 
 
151
  elif key.startswith("model.diffusion_model."): # UNet
152
+ unet_state[key.replace("model.diffusion_model.", "")] = value.to(torch.float16)
 
 
153
 
154
  return text_encoder1_state, text_encoder2_state, vae_state, unet_state
155
 
156
 
157
+
158
  def build_diffusers_model(
159
+ text_encoder1_state, text_encoder2_state, vae_state, unet_state, reference_model_path=None
 
 
 
 
160
  ):
161
+ """Builds Diffusers components, loading state dicts with strict=False."""
162
 
 
163
  if not reference_model_path:
164
  reference_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
165
 
166
+ # Load configurations from the reference model
167
  config_text_encoder1 = CLIPTextConfig.from_pretrained(
168
  reference_model_path, subfolder="text_encoder"
169
  )
170
  config_text_encoder2 = CLIPTextConfig.from_pretrained(
171
+ reference_model_path, subfolder="text_encoder_2"
172
  )
173
+ config_vae = AutoencoderKL.from_pretrained(reference_model_path, subfolder="vae").config
174
+ config_unet = UNet2DConditionModel.from_pretrained(reference_model_path, subfolder="unet").config
175
 
176
+ # Create instances using the configurations
177
  text_encoder1 = CLIPTextModel(config_text_encoder1)
178
+ text_encoder2 = CLIPTextModelWithProjection(config_text_encoder2) # Use CLIPTextModelWithProjection
179
+ vae = AutoencoderKL(config=config_vae)
180
+ unet = UNet2DConditionModel(config=config_unet)
 
 
181
 
182
+ # Load state dicts with strict=False
183
+ text_encoder1.load_state_dict(text_encoder1_state, strict=False)
184
+ text_encoder2.load_state_dict(text_encoder2_state, strict=False)
185
+ vae.load_state_dict(vae_state, strict=False)
186
+ unet.load_state_dict(unet_state, strict=False)
187
 
188
+ text_encoder1.to(torch.float16).to("cpu")
189
+ text_encoder2.to(torch.float16).to("cpu")
190
+ vae.to(torch.float16).to("cpu")
191
  unet.to(torch.float16).to("cpu")
192
 
193
+
194
  return text_encoder1, text_encoder2, vae, unet
195
 
196
 
197
  def convert_and_save_sdxl_to_diffusers(
198
  checkpoint_path_or_url, output_path, reference_model_path
199
  ):
200
+ """Converts and saves the Illustrious-xl checkpoint to Diffusers format."""
201
 
 
 
 
 
 
202
  checkpoint_path = download_model(checkpoint_path_or_url)
203
 
204
  text_encoder1_state, text_encoder2_state, vae_state, unet_state = (
 
226
  print(f"Model saved as Diffusers format: {output_path}")
227
 
228
 
229
+
230
  # ---------------------- UPLOAD FUNCTION ----------------------
231
  def upload_to_huggingface(model_path, hf_token, orgs_name, model_name, make_private):
232
  """Uploads a model to the Hugging Face Hub."""
 
334
  with gr.Column():
335
  output = gr.Markdown() #Output is in its own column
336
 
337
+ convert_button.click(
338
  fn=main,
339
  inputs=[
340
  model_to_load,