Files changed (1) hide show
  1. app.py +107 -63
app.py CHANGED
@@ -23,6 +23,32 @@ from typing import Dict, List, Optional
23
  from huggingface_hub import login, HfApi
24
  from types import SimpleNamespace
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  # ---------------------- UTILITY FUNCTIONS ----------------------
27
 
28
  def is_valid_url(url):
@@ -30,26 +56,34 @@ def is_valid_url(url):
30
  try:
31
  result = urlparse(url)
32
  return all([result.scheme, result.netloc])
33
- except:
 
34
  return False
35
 
36
  def get_filename(url):
37
- response = requests.get(url, stream=True)
38
- response.raise_for_status()
 
 
39
 
40
- if 'content-disposition' in response.headers:
41
- content_disposition = response.headers['content-disposition']
42
- filename = re.findall('filename="?([^"]+)"?', content_disposition)[0]
43
- else:
44
- url_path = urlparse(url).path
45
- filename = unquote(os.path.basename(url_path))
46
 
47
- return filename
 
 
 
48
 
49
  def get_supported_extensions():
 
50
  return tuple([".ckpt", ".safetensors", ".pt", ".pth"])
51
 
52
  def download_model(url, dst, output_widget):
 
53
  filename = get_filename(url)
54
  filepath = os.path.join(dst, filename)
55
  try:
@@ -60,32 +94,34 @@ def download_model(url, dst, output_widget):
60
  if "/blob/" in url:
61
  url = url.replace("/blob/", "/resolve/")
62
  subprocess.run(["aria2c","-x 16",url,"-d",dst,"-o",filename])
63
- with output_widget:
64
- return filepath
65
  except Exception as e:
66
- with output_widget:
67
- return None
68
 
69
  def determine_load_checkpoint(model_to_load):
70
  """Determines if the model to load is a checkpoint, Diffusers model, or URL."""
71
- if is_valid_url(model_to_load) and (model_to_load.endswith(get_supported_extensions())):
72
- return True
73
- elif model_to_load.endswith(get_supported_extensions()):
74
- return True
75
- elif os.path.isdir(model_to_load):
76
- required_folders = {"unet", "text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2", "scheduler", "vae"}
77
- if required_folders.issubset(set(os.listdir(model_to_load))) and os.path.isfile(os.path.join(model_to_load, "model_index.json")):
78
- return False
 
 
 
79
  return None # handle this case as required
80
 
81
  def create_model_repo(api, user, orgs_name, model_name, make_private=False):
82
  """Creates a Hugging Face model repository if it doesn't exist."""
83
- if orgs_name == "":
84
- repo_id = user["name"] + "/" + model_name.strip()
85
- else:
86
- repo_id = orgs_name + "/" + model_name.strip()
87
-
88
  try:
 
 
 
 
 
89
  validate_repo_id(repo_id)
90
  api.create_repo(repo_id=repo_id, repo_type="model", private=make_private)
91
  print(f"Model repo '{repo_id}' didn't exist, creating repo")
@@ -98,46 +134,54 @@ def create_model_repo(api, user, orgs_name, model_name, make_private=False):
98
 
99
  def is_diffusers_model(model_path):
100
  """Checks if a given path is a valid Diffusers model directory."""
101
- required_folders = {"unet", "text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2", "scheduler", "vae"}
102
- return required_folders.issubset(set(os.listdir(model_path))) and os.path.isfile(os.path.join(model_path, "model_index.json"))
 
 
 
 
103
 
104
  # ---------------------- MODEL UTIL (From library.sdxl_model_util) ----------------------
105
 
106
  def load_models_from_sdxl_checkpoint(sdxl_base_id, checkpoint_path, device):
107
  """Loads SDXL model components from a checkpoint file."""
108
- text_encoder1 = CLIPTextModel.from_pretrained(sdxl_base_id, subfolder="text_encoder").to(device)
109
- text_encoder2 = CLIPTextModel.from_pretrained(sdxl_base_id, subfolder="text_encoder_2").to(device)
110
- vae = AutoencoderKL.from_pretrained(sdxl_base_id, subfolder="vae").to(device)
111
- unet = UNet2DConditionModel.from_pretrained(sdxl_base_id, subfolder="unet").to(device)
112
- unet = unet
113
-
114
- ckpt_state_dict = torch.load(checkpoint_path, map_location=device)
115
-
116
- o = OrderedDict()
117
- for key in list(ckpt_state_dict.keys()):
118
- o[key.replace("module.", "")] = ckpt_state_dict[key]
119
- del ckpt_state_dict
120
-
121
- print("Applying weights to text encoder 1:")
122
- text_encoder1.load_state_dict({
123
- '.'.join(key.split('.')[1:]): o[key] for key in list(o.keys()) if key.startswith("first_stage_model.cond_stage_model.model.transformer")
124
- }, strict=False)
125
- print("Applying weights to text encoder 2:")
126
- text_encoder2.load_state_dict({
127
- '.'.join(key.split('.')[1:]): o[key] for key in list(o.keys()) if key.startswith("cond_stage_model.model.transformer")
128
- }, strict=False)
129
- print("Applying weights to VAE:")
130
- vae.load_state_dict({
131
- '.'.join(key.split('.')[2:]): o[key] for key in list(o.keys()) if key.startswith("first_stage_model.model")
132
- }, strict=False)
133
- print("Applying weights to UNet:")
134
- unet.load_state_dict({
135
- key: o[key] for key in list(o.keys()) if key.startswith("model.diffusion_model")
136
- }, strict=False)
137
-
138
- logit_scale = None #Not used here!
139
- global_step = None #Not used here!
140
- return text_encoder1, text_encoder2, vae, unet, logit_scale, global_step
 
 
 
 
141
 
142
  def save_stable_diffusion_checkpoint(save_path, text_encoder1, text_encoder2, unet, epoch, global_step, ckpt_info, vae, logit_scale, save_dtype):
143
  """Saves the stable diffusion checkpoint."""
@@ -665,7 +709,7 @@ def main(model_to_load, save_precision_as, epoch, global_step, reference_model,
665
 
666
  # Create tempdir, will only be there for the function
667
  with tempfile.TemporaryDirectory() as output_path:
668
- conversion_output = convert_model(model_to_load, save_precision_as, epoch, global_step, reference_model, fp16, use_xformers, output)
669
 
670
  upload_output = upload_to_huggingface(output_path, hf_token, orgs_name, model_name, make_private)
671
 
 
23
  from huggingface_hub import login, HfApi
24
  from types import SimpleNamespace
25
 
26
+ # Remove unused imports
27
+ # import os
28
+ # import gradio as gr
29
+ # import torch
30
+ # from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, AutoencoderKL
31
+ # from transformers import CLIPTextModel, CLIPTextConfig
32
+ # from safetensors.torch import load_file
33
+ # from collections import OrderedDict
34
+ # import re
35
+ # import json
36
+ # import gdown
37
+ # import requests
38
+ # import subprocess
39
+ # from urllib.parse import urlparse, unquote
40
+ # from pathlib import Path
41
+ # import tempfile
42
+ # from tqdm import tqdm
43
+ # import psutil
44
+ # import math
45
+ # import shutil
46
+ # import hashlib
47
+ # from datetime import datetime
48
+ # from typing import Dict, List, Optional
49
+ # from huggingface_hub import login, HfApi
50
+ # from types import SimpleNamespace
51
+
52
  # ---------------------- UTILITY FUNCTIONS ----------------------
53
 
54
  def is_valid_url(url):
 
56
  try:
57
  result = urlparse(url)
58
  return all([result.scheme, result.netloc])
59
+ except Exception as e:
60
+ print(f"Error checking URL validity: {e}")
61
  return False
62
 
63
  def get_filename(url):
64
+ """Extracts the filename from a URL."""
65
+ try:
66
+ response = requests.get(url, stream=True)
67
+ response.raise_for_status()
68
 
69
+ if 'content-disposition' in response.headers:
70
+ content_disposition = response.headers['content-disposition']
71
+ filename = re.findall('filename="?([^";]+)"?', content_disposition)[0]
72
+ else:
73
+ url_path = urlparse(url).path
74
+ filename = unquote(os.path.basename(url_path))
75
 
76
+ return filename
77
+ except Exception as e:
78
+ print(f"Error getting filename from URL: {e}")
79
+ return None
80
 
81
  def get_supported_extensions():
82
+ """Returns a tuple of supported model file extensions."""
83
  return tuple([".ckpt", ".safetensors", ".pt", ".pth"])
84
 
85
  def download_model(url, dst, output_widget):
86
+ """Downloads a model from a URL to the specified destination."""
87
  filename = get_filename(url)
88
  filepath = os.path.join(dst, filename)
89
  try:
 
94
  if "/blob/" in url:
95
  url = url.replace("/blob/", "/resolve/")
96
  subprocess.run(["aria2c","-x 16",url,"-d",dst,"-o",filename])
97
+ return filepath
 
98
  except Exception as e:
99
+ print(f"Error downloading model: {e}")
100
+ return None
101
 
102
  def determine_load_checkpoint(model_to_load):
103
  """Determines if the model to load is a checkpoint, Diffusers model, or URL."""
104
+ try:
105
+ if is_valid_url(model_to_load) and (model_to_load.endswith(get_supported_extensions())):
106
+ return True
107
+ elif model_to_load.endswith(get_supported_extensions()):
108
+ return True
109
+ elif os.path.isdir(model_to_load):
110
+ required_folders = {"unet", "text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2", "scheduler", "vae"}
111
+ if required_folders.issubset(set(os.listdir(model_to_load))) and os.path.isfile(os.path.join(model_to_load, "model_index.json")):
112
+ return False
113
+ except Exception as e:
114
+ print(f"Error determining load checkpoint: {e}")
115
  return None # handle this case as required
116
 
117
  def create_model_repo(api, user, orgs_name, model_name, make_private=False):
118
  """Creates a Hugging Face model repository if it doesn't exist."""
 
 
 
 
 
119
  try:
120
+ if orgs_name == "":
121
+ repo_id = user["name"] + "/" + model_name.strip()
122
+ else:
123
+ repo_id = orgs_name + "/" + model_name.strip()
124
+
125
  validate_repo_id(repo_id)
126
  api.create_repo(repo_id=repo_id, repo_type="model", private=make_private)
127
  print(f"Model repo '{repo_id}' didn't exist, creating repo")
 
134
 
135
  def is_diffusers_model(model_path):
136
  """Checks if a given path is a valid Diffusers model directory."""
137
+ try:
138
+ required_folders = {"unet", "text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2", "scheduler", "vae"}
139
+ return required_folders.issubset(set(os.listdir(model_path))) and os.path.isfile(os.path.join(model_path, "model_index.json"))
140
+ except Exception as e:
141
+ print(f"Error checking if model is a Diffusers model: {e}")
142
+ return False
143
 
144
  # ---------------------- MODEL UTIL (From library.sdxl_model_util) ----------------------
145
 
146
  def load_models_from_sdxl_checkpoint(sdxl_base_id, checkpoint_path, device):
147
  """Loads SDXL model components from a checkpoint file."""
148
+ try:
149
+ text_encoder1 = CLIPTextModel.from_pretrained(sdxl_base_id, subfolder="text_encoder").to(device)
150
+ text_encoder2 = CLIPTextModel.from_pretrained(sdxl_base_id, subfolder="text_encoder_2").to(device)
151
+ vae = AutoencoderKL.from_pretrained(sdxl_base_id, subfolder="vae").to(device)
152
+ unet = UNet2DConditionModel.from_pretrained(sdxl_base_id, subfolder="unet").to(device)
153
+ unet = unet
154
+
155
+ ckpt_state_dict = torch.load(checkpoint_path, map_location=device)
156
+
157
+ o = OrderedDict()
158
+ for key in list(ckpt_state_dict.keys()):
159
+ o[key.replace("module.", "")] = ckpt_state_dict[key]
160
+ del ckpt_state_dict
161
+
162
+ print("Applying weights to text encoder 1:")
163
+ text_encoder1.load_state_dict({
164
+ '.'.join(key.split('.')[1:]): o[key] for key in list(o.keys()) if key.startswith("first_stage_model.cond_stage_model.model.transformer")
165
+ }, strict=False)
166
+ print("Applying weights to text encoder 2:")
167
+ text_encoder2.load_state_dict({
168
+ '.'.join(key.split('.')[1:]): o[key] for key in list(o.keys()) if key.startswith("cond_stage_model.model.transformer")
169
+ }, strict=False)
170
+ print("Applying weights to VAE:")
171
+ vae.load_state_dict({
172
+ '.'.join(key.split('.')[2:]): o[key] for key in list(o.keys()) if key.startswith("first_stage_model.model")
173
+ }, strict=False)
174
+ print("Applying weights to UNet:")
175
+ unet.load_state_dict({
176
+ key: o[key] for key in list(o.keys()) if key.startswith("model.diffusion_model")
177
+ }, strict=False)
178
+
179
+ logit_scale = None #Not used here!
180
+ global_step = None #Not used here!
181
+ return text_encoder1, text_encoder2, vae, unet, logit_scale, global_step
182
+ except Exception as e:
183
+ print(f"Error loading models from checkpoint: {e}")
184
+ return None
185
 
186
  def save_stable_diffusion_checkpoint(save_path, text_encoder1, text_encoder2, unet, epoch, global_step, ckpt_info, vae, logit_scale, save_dtype):
187
  """Saves the stable diffusion checkpoint."""
 
709
 
710
  # Create tempdir, will only be there for the function
711
  with tempfile.TemporaryDirectory() as output_path:
712
+ conversion_output = convert_model(model_to_load, save_precision_as, epoch, global_step, reference_model, fp16, use_xformers, hf_token, orgs_name, model_name, make_private)
713
 
714
  upload_output = upload_to_huggingface(output_path, hf_token, orgs_name, model_name, make_private)
715