Files changed (1) hide show
  1. app.py +53 -119
app.py CHANGED
@@ -2,77 +2,21 @@ import os
2
  import gradio as gr
3
  import torch
4
  from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, AutoencoderKL
5
- from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTextConfig, CLIPTokenizer
6
  from safetensors.torch import load_file
7
  from collections import OrderedDict
8
- import re
9
- import json
10
  import requests
11
- import subprocess
12
  from urllib.parse import urlparse, unquote
13
  from pathlib import Path
14
  import hashlib
15
- from datetime import datetime
16
- from typing import Dict, List, Optional
17
  from huggingface_hub import login, HfApi, hf_hub_download
18
  from huggingface_hub.utils import validate_repo_id, HFValidationError
19
  from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
20
  from huggingface_hub.utils import HfHubHTTPError
21
 
22
-
23
- # ---------------------- DEPENDENCIES ----------------------
24
def install_dependencies_gradio():
    """Install or upgrade the Python packages this app needs at runtime.

    Runs ``pip install -U`` for a fixed package list. Failures are reported
    on stdout rather than raised, so the function always returns ``None``.
    """
    import sys  # local import: only this helper needs the interpreter path

    packages = [
        "torch",
        "diffusers",
        "transformers",
        "accelerate",
        "safetensors",
        "huggingface_hub",
        "xformers",
    ]
    # Invoke pip through the running interpreter so the packages land in the
    # active environment rather than whichever "pip" happens to be on PATH.
    command = [sys.executable, "-m", "pip", "install", "-U", *packages]
    try:
        # check=True turns a non-zero pip exit status into CalledProcessError;
        # without it a failed install was still reported as a success.
        subprocess.run(command, check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        print(f"Error installing dependencies: {e}")
    else:
        print("Dependencies installed successfully.")
44
-
45
-
46
  # ---------------------- UTILITY FUNCTIONS ----------------------
 
47
 
48
-
49
def increment_filename(filename):
    """Return a path that does not exist yet, derived from *filename*.

    If *filename* is free it is returned unchanged. Otherwise a counter in
    parentheses is inserted before the extension (``name(1).ext``,
    ``name(2).ext``, ...) until an unused path is found.
    """
    if not os.path.exists(filename):
        return filename

    stem, suffix = os.path.splitext(filename)
    attempt = 1
    candidate = f"{stem}({attempt}){suffix}"
    while os.path.exists(candidate):
        attempt += 1
        candidate = f"{stem}({attempt}){suffix}"
    return candidate
57
-
58
-
59
- # ---------------------- UPLOAD FUNCTION ----------------------
60
def create_model_repo(api, user, orgs_name, model_name, make_private=False):
    """Create a Hugging Face model repository and return its repo id.

    Args:
        api: Object exposing ``create_repo(repo_id=..., repo_type=..., private=...)``.
        user: Mapping with a ``"name"`` key; used as the namespace when no
            organization is given.
        orgs_name: Organization namespace, or an empty value to use the user.
        model_name: Repository name; surrounding whitespace is stripped.
        make_private: Whether the repository is created as private.

    Returns:
        The ``"<namespace>/<model_name>"`` repo id.

    Raises:
        ValueError: If ``model_name`` is empty or only whitespace.
        HfHubHTTPError: For any Hub failure other than "already exists".
    """
    name = (model_name or "").strip()
    if not name:
        raise ValueError("Model name must not be empty.")

    # A whitespace-only organization falls back to the user's namespace.
    namespace = (orgs_name or "").strip() or user["name"]
    repo_id = f"{namespace}/{name}"

    try:
        api.create_repo(repo_id=repo_id, repo_type="model", private=make_private)
    except HfHubHTTPError as err:
        # Only a 409 Conflict means the repo already exists. Auth, permission
        # and server errors must surface instead of being reported as success.
        response = getattr(err, "response", None)
        if getattr(response, "status_code", None) != 409:
            raise
        print(f"Model repo '{repo_id}' already exists.")
    else:
        print(f"Model repo '{repo_id}' created.")
    return repo_id
73
-
74
-
75
- # ---------------------- MODEL LOADING AND CONVERSION ----------------------
76
  def download_model(model_path_or_url):
77
  """Downloads a model, handling URLs, HF repos, and local paths."""
78
  try:
@@ -125,10 +69,21 @@ def download_model(model_path_or_url):
125
  raise ValueError(f"Error downloading or accessing model: {e}")
126
 
127
 
128
-
 
 
 
 
 
 
 
 
 
 
 
 
129
  def load_sdxl_checkpoint(checkpoint_path):
130
- """Loads checkpoint and extracts state dicts, handling Illustrious-xl."""
131
-
132
  if checkpoint_path.endswith(".safetensors"):
133
  state_dict = load_file(checkpoint_path, device="cpu")
134
  elif checkpoint_path.endswith(".ckpt"):
@@ -142,44 +97,34 @@ def load_sdxl_checkpoint(checkpoint_path):
142
  unet_state = OrderedDict()
143
 
144
  for key, value in state_dict.items():
145
- if key.startswith("first_stage_model."): # VAE
146
  vae_state[key.replace("first_stage_model.", "")] = value.to(torch.float16)
147
- elif key.startswith("condition_model.model.text_encoder."): # First Text Encoder
148
  text_encoder1_state[key.replace("condition_model.model.text_encoder.", "")] = value.to(torch.float16)
149
- elif key.startswith("condition_model.model.text_encoder_2."): # Second Text Encoder
150
  text_encoder2_state[key.replace("condition_model.model.text_encoder_2.", "")] = value.to(torch.float16)
151
- elif key.startswith("model.diffusion_model."): # UNet
152
  unet_state[key.replace("model.diffusion_model.", "")] = value.to(torch.float16)
153
 
154
  return text_encoder1_state, text_encoder2_state, vae_state, unet_state
155
 
156
 
157
 
158
- def build_diffusers_model(
159
- text_encoder1_state, text_encoder2_state, vae_state, unet_state, reference_model_path=None
160
- ):
161
  """Builds Diffusers components, loading state dicts with strict=False."""
162
-
163
  if not reference_model_path:
164
  reference_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
165
 
166
- # Load configurations from the reference model
167
- config_text_encoder1 = CLIPTextConfig.from_pretrained(
168
- reference_model_path, subfolder="text_encoder"
169
- )
170
- config_text_encoder2 = CLIPTextConfig.from_pretrained(
171
- reference_model_path, subfolder="text_encoder_2"
172
- )
173
  config_vae = AutoencoderKL.from_pretrained(reference_model_path, subfolder="vae").config
174
  config_unet = UNet2DConditionModel.from_pretrained(reference_model_path, subfolder="unet").config
175
 
176
- # Create instances using the configurations
177
  text_encoder1 = CLIPTextModel(config_text_encoder1)
178
- text_encoder2 = CLIPTextModelWithProjection(config_text_encoder2) # Use CLIPTextModelWithProjection
179
  vae = AutoencoderKL(config=config_vae)
180
  unet = UNet2DConditionModel(config=config_unet)
181
 
182
- # Load state dicts with strict=False
183
  text_encoder1.load_state_dict(text_encoder1_state, strict=False)
184
  text_encoder2.load_state_dict(text_encoder2_state, strict=False)
185
  vae.load_state_dict(vae_state, strict=False)
@@ -190,29 +135,16 @@ def build_diffusers_model(
190
  vae.to(torch.float16).to("cpu")
191
  unet.to(torch.float16).to("cpu")
192
 
193
-
194
  return text_encoder1, text_encoder2, vae, unet
195
 
196
-
197
- def convert_and_save_sdxl_to_diffusers(
198
- checkpoint_path_or_url, output_path, reference_model_path
199
- ):
200
- """Converts and saves the Illustrious-xl checkpoint to Diffusers format."""
201
-
202
  checkpoint_path = download_model(checkpoint_path_or_url)
203
-
204
- text_encoder1_state, text_encoder2_state, vae_state, unet_state = (
205
- load_sdxl_checkpoint(checkpoint_path)
206
- )
207
  text_encoder1, text_encoder2, vae, unet = build_diffusers_model(
208
- text_encoder1_state,
209
- text_encoder2_state,
210
- vae_state,
211
- unet_state,
212
- reference_model_path,
213
  )
214
 
215
- # Load tokenizer and scheduler from the reference model
216
  pipeline = StableDiffusionXLPipeline.from_pretrained(
217
  reference_model_path,
218
  text_encoder=text_encoder1,
@@ -225,9 +157,6 @@ def convert_and_save_sdxl_to_diffusers(
225
  pipeline.save_pretrained(output_path)
226
  print(f"Model saved as Diffusers format: {output_path}")
227
 
228
-
229
-
230
- # ---------------------- UPLOAD FUNCTION ----------------------
231
  def upload_to_huggingface(model_path, hf_token, orgs_name, model_name, make_private):
232
  """Uploads a model to the Hugging Face Hub."""
233
  login(token=hf_token, add_to_git_credential=True)
@@ -237,8 +166,8 @@ def upload_to_huggingface(model_path, hf_token, orgs_name, model_name, make_priv
237
  api.upload_folder(folder_path=model_path, repo_id=model_repo)
238
  print(f"Model uploaded to: https://huggingface.co/{model_repo}")
239
 
 
240
 
241
- # ---------------------- GRADIO INTERFACE ----------------------
242
  def main(
243
  model_to_load,
244
  reference_model,
@@ -248,7 +177,16 @@ def main(
248
  model_name,
249
  make_private,
250
  ):
251
- """Main function: SDXL checkpoint to Diffusers, always fp16."""
 
 
 
 
 
 
 
 
 
252
 
253
  try:
254
  convert_and_save_sdxl_to_diffusers(
@@ -257,10 +195,15 @@ def main(
257
  upload_to_huggingface(
258
  output_path, hf_token, orgs_name, model_name, make_private
259
  )
260
- return "Conversion and upload completed successfully!"
 
 
261
  except Exception as e:
262
- return f"An error occurred: {e}" # Return the error message
 
 
263
 
 
264
 
265
  css = """
266
  #main-container {
@@ -271,7 +214,7 @@ css = """
271
  color: #333;
272
  }
273
  #convert-button {
274
- margin-top: 1em; /* Adds some space above the button */
275
  }
276
  """
277
 
@@ -306,7 +249,6 @@ with gr.Blocks(css=css) as demo:
306
 
307
  with gr.Row():
308
  with gr.Column():
309
-
310
  model_to_load = gr.Textbox(
311
  label="SDXL Checkpoint (Path, URL, or HF Repo)",
312
  placeholder="Path, URL, or Hugging Face Repo ID (e.g., my-org/my-model or my-org/my-model/file.safetensors)",
@@ -315,25 +257,17 @@ with gr.Blocks(css=css) as demo:
315
  label="Reference Diffusers Model (Optional)",
316
  placeholder="e.g., stabilityai/stable-diffusion-xl-base-1.0 (Leave blank for default)",
317
  )
318
- output_path = gr.Textbox(
319
- label="Output Path (Diffusers Format)", value="output"
320
- )
321
- hf_token = gr.Textbox(
322
- label="Hugging Face Token", placeholder="Your Hugging Face write token", type="password"
323
- )
324
- orgs_name = gr.Textbox(
325
- label="Organization Name (Optional)", placeholder="Your organization name"
326
- )
327
- model_name = gr.Textbox(
328
- label="Model Name", placeholder="The name of your model on Hugging Face"
329
- )
330
  make_private = gr.Checkbox(label="Make Repository Private", value=False)
331
-
332
  convert_button = gr.Button("Convert and Upload")
333
 
334
- with gr.Column(variant="panel"): # Use variant="panel"
335
  output = gr.Markdown(container=False)
336
 
 
337
  convert_button.click(
338
  fn=main,
339
  inputs=[
 
2
  import gradio as gr
3
  import torch
4
  from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, AutoencoderKL
5
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTextConfig
6
  from safetensors.torch import load_file
7
  from collections import OrderedDict
 
 
8
  import requests
 
9
  from urllib.parse import urlparse, unquote
10
  from pathlib import Path
11
  import hashlib
 
 
12
  from huggingface_hub import login, HfApi, hf_hub_download
13
  from huggingface_hub.utils import validate_repo_id, HFValidationError
14
  from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
15
  from huggingface_hub.utils import HfHubHTTPError
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  # ---------------------- UTILITY FUNCTIONS ----------------------
18
+ # (download_model, create_model_repo, etc. - All unchanged, but included for completeness)
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  def download_model(model_path_or_url):
21
  """Downloads a model, handling URLs, HF repos, and local paths."""
22
  try:
 
69
  raise ValueError(f"Error downloading or accessing model: {e}")
70
 
71
 
72
def create_model_repo(api, user, orgs_name, model_name, make_private=False):
    """Create a Hugging Face model repository and return its repo id.

    Args:
        api: Object exposing ``create_repo(repo_id=..., repo_type=..., private=...)``.
        user: Mapping with a ``"name"`` key; used as the namespace when no
            organization is given.
        orgs_name: Organization namespace, or an empty value to use the user.
        model_name: Repository name; surrounding whitespace is stripped.
        make_private: Whether the repository is created as private.

    Returns:
        The ``"<namespace>/<model_name>"`` repo id.

    Raises:
        ValueError: If ``model_name`` is empty or only whitespace.
        HfHubHTTPError: For any Hub failure other than "already exists".
    """
    name = (model_name or "").strip()
    if not name:
        raise ValueError("Model name must not be empty.")

    # A whitespace-only organization falls back to the user's namespace.
    namespace = (orgs_name or "").strip() or user["name"]
    repo_id = f"{namespace}/{name}"

    try:
        api.create_repo(repo_id=repo_id, repo_type="model", private=make_private)
    except HfHubHTTPError as err:
        # Only a 409 Conflict means the repo already exists. Auth, permission
        # and server errors must surface instead of being reported as success.
        response = getattr(err, "response", None)
        if getattr(response, "status_code", None) != 409:
            raise
        print(f"Model repo '{repo_id}' already exists.")
    else:
        print(f"Model repo '{repo_id}' created.")
    return repo_id
85
  def load_sdxl_checkpoint(checkpoint_path):
86
+ """Loads checkpoint and extracts state dicts."""
 
87
  if checkpoint_path.endswith(".safetensors"):
88
  state_dict = load_file(checkpoint_path, device="cpu")
89
  elif checkpoint_path.endswith(".ckpt"):
 
97
  unet_state = OrderedDict()
98
 
99
  for key, value in state_dict.items():
100
+ if key.startswith("first_stage_model."):
101
  vae_state[key.replace("first_stage_model.", "")] = value.to(torch.float16)
102
+ elif key.startswith("condition_model.model.text_encoder."):
103
  text_encoder1_state[key.replace("condition_model.model.text_encoder.", "")] = value.to(torch.float16)
104
+ elif key.startswith("condition_model.model.text_encoder_2."):
105
  text_encoder2_state[key.replace("condition_model.model.text_encoder_2.", "")] = value.to(torch.float16)
106
+ elif key.startswith("model.diffusion_model."):
107
  unet_state[key.replace("model.diffusion_model.", "")] = value.to(torch.float16)
108
 
109
  return text_encoder1_state, text_encoder2_state, vae_state, unet_state
110
 
111
 
112
 
113
+ def build_diffusers_model(text_encoder1_state, text_encoder2_state, vae_state, unet_state, reference_model_path=None):
 
 
114
  """Builds Diffusers components, loading state dicts with strict=False."""
 
115
  if not reference_model_path:
116
  reference_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
117
 
118
+ config_text_encoder1 = CLIPTextConfig.from_pretrained(reference_model_path, subfolder="text_encoder")
119
+ config_text_encoder2 = CLIPTextConfig.from_pretrained(reference_model_path, subfolder="text_encoder_2")
 
 
 
 
 
120
  config_vae = AutoencoderKL.from_pretrained(reference_model_path, subfolder="vae").config
121
  config_unet = UNet2DConditionModel.from_pretrained(reference_model_path, subfolder="unet").config
122
 
 
123
  text_encoder1 = CLIPTextModel(config_text_encoder1)
124
+ text_encoder2 = CLIPTextModelWithProjection(config_text_encoder2) # Correct class
125
  vae = AutoencoderKL(config=config_vae)
126
  unet = UNet2DConditionModel(config=config_unet)
127
 
 
128
  text_encoder1.load_state_dict(text_encoder1_state, strict=False)
129
  text_encoder2.load_state_dict(text_encoder2_state, strict=False)
130
  vae.load_state_dict(vae_state, strict=False)
 
135
  vae.to(torch.float16).to("cpu")
136
  unet.to(torch.float16).to("cpu")
137
 
 
138
  return text_encoder1, text_encoder2, vae, unet
139
 
140
+ def convert_and_save_sdxl_to_diffusers(checkpoint_path_or_url, output_path, reference_model_path):
141
+ """Converts and saves the checkpoint to Diffusers format."""
 
 
 
 
142
  checkpoint_path = download_model(checkpoint_path_or_url)
143
+ text_encoder1_state, text_encoder2_state, vae_state, unet_state = load_sdxl_checkpoint(checkpoint_path)
 
 
 
144
  text_encoder1, text_encoder2, vae, unet = build_diffusers_model(
145
+ text_encoder1_state, text_encoder2_state, vae_state, unet_state, reference_model_path
 
 
 
 
146
  )
147
 
 
148
  pipeline = StableDiffusionXLPipeline.from_pretrained(
149
  reference_model_path,
150
  text_encoder=text_encoder1,
 
157
  pipeline.save_pretrained(output_path)
158
  print(f"Model saved as Diffusers format: {output_path}")
159
 
 
 
 
160
  def upload_to_huggingface(model_path, hf_token, orgs_name, model_name, make_private):
161
  """Uploads a model to the Hugging Face Hub."""
162
  login(token=hf_token, add_to_git_credential=True)
 
166
  api.upload_folder(folder_path=model_path, repo_id=model_repo)
167
  print(f"Model uploaded to: https://huggingface.co/{model_repo}")
168
 
169
+ # ---------------------- MAIN FUNCTION (with Debugging Prints) ----------------------
170
 
 
171
  def main(
172
  model_to_load,
173
  reference_model,
 
177
  model_name,
178
  make_private,
179
  ):
180
+ """Main function: SDXL checkpoint to Diffusers, with debugging prints."""
181
+
182
+ print("---- Main Function Called ----") # Debug Print
183
+ print(f" model_to_load: {model_to_load}") # Debug Print
184
+ print(f" reference_model: {reference_model}") # Debug Print
185
+ print(f" output_path: {output_path}") # Debug Print
186
+ print(f" hf_token: {'<provided>' if hf_token else '<missing>'}") # Debug Print (never echo the secret token into logs)
187
+ print(f" orgs_name: {orgs_name}") # Debug Print
188
+ print(f" model_name: {model_name}") # Debug Print
189
+ print(f" make_private: {make_private}") # Debug Print
190
 
191
  try:
192
  convert_and_save_sdxl_to_diffusers(
 
195
  upload_to_huggingface(
196
  output_path, hf_token, orgs_name, model_name, make_private
197
  )
198
+ result = "Conversion and upload completed successfully!"
199
+ print(f"---- Main Function Successful: {result} ----") # Debug Print
200
+ return result
201
  except Exception as e:
202
+ error_message = f"An error occurred: {e}"
203
+ print(f"---- Main Function Error: {error_message} ----") # Debug Print
204
+ return error_message
205
 
206
+ # ---------------------- GRADIO INTERFACE (Corrected Button Placement) ----------------------
207
 
208
  css = """
209
  #main-container {
 
214
  color: #333;
215
  }
216
  #convert-button {
217
+ margin-top: 1em;
218
  }
219
  """
220
 
 
249
 
250
  with gr.Row():
251
  with gr.Column():
 
252
  model_to_load = gr.Textbox(
253
  label="SDXL Checkpoint (Path, URL, or HF Repo)",
254
  placeholder="Path, URL, or Hugging Face Repo ID (e.g., my-org/my-model or my-org/my-model/file.safetensors)",
 
257
  label="Reference Diffusers Model (Optional)",
258
  placeholder="e.g., stabilityai/stable-diffusion-xl-base-1.0 (Leave blank for default)",
259
  )
260
+ output_path = gr.Textbox(label="Output Path (Diffusers Format)", value="output")
261
+ hf_token = gr.Textbox(label="Hugging Face Token", placeholder="Your Hugging Face write token", type="password")
262
+ orgs_name = gr.Textbox(label="Organization Name (Optional)", placeholder="Your organization name")
263
+ model_name = gr.Textbox(label="Model Name", placeholder="The name of your model on Hugging Face")
 
 
 
 
 
 
 
 
264
  make_private = gr.Checkbox(label="Make Repository Private", value=False)
 
265
  convert_button = gr.Button("Convert and Upload")
266
 
267
+ with gr.Column(variant="panel"):
268
  output = gr.Markdown(container=False)
269
 
270
+ # --- CORRECT BUTTON CLICK PLACEMENT ---
271
  convert_button.click(
272
  fn=main,
273
  inputs=[