Old-Fisherman committed
Commit b922c99 · verified · 1 parent: c9472c6

Update convert_mod.py

Files changed (1):
  1. convert_mod.py +261 -261
convert_mod.py CHANGED
@@ -1,261 +1,261 @@
- # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
- import os
- import torch
- import gguf # This needs to be the llama.cpp one specifically!
- import argparse
- from tqdm import tqdm
-
- from safetensors.torch import load_file
-
- QUANTIZATION_THRESHOLD = 1024
- REARRANGE_THRESHOLD = 512
- MAX_TENSOR_NAME_LENGTH = 127
-
- class ModelTemplate:
-     arch = "invalid" # string describing architecture
-     shape_fix = False # whether to reshape tensors
-     keys_detect = [] # list of lists to match in state dict
-     keys_banned = [] # list of keys that should mark model as invalid for conversion
-
- class ModelFlux(ModelTemplate):
-     arch = "flux"
-     keys_detect = [
-         ("transformer_blocks.0.attn.norm_added_k.weight",),
-         ("double_blocks.0.img_attn.proj.weight",),
-     ]
-     keys_banned = ["transformer_blocks.0.attn.norm_added_k.weight",]
-
- class ModelSD3(ModelTemplate):
-     arch = "sd3"
-     keys_detect = [
-         ("transformer_blocks.0.attn.add_q_proj.weight",),
-         ("joint_blocks.0.x_block.attn.qkv.weight",),
-     ]
-     keys_banned = ["transformer_blocks.0.attn.add_q_proj.weight",]
-
- class ModelAura(ModelTemplate):
-     arch = "aura"
-     keys_detect = [
-         ("double_layers.3.modX.1.weight",),
-         ("joint_transformer_blocks.3.ff_context.out_projection.weight",),
-     ]
-     keys_banned = ["joint_transformer_blocks.3.ff_context.out_projection.weight",]
-
- class ModelLTXV(ModelTemplate):
-     arch = "ltxv"
-     keys_detect = [
-         (
-             "adaln_single.emb.timestep_embedder.linear_2.weight",
-             "transformer_blocks.27.scale_shift_table",
-             "caption_projection.linear_2.weight",
-         )
-     ]
-
- class ModelSDXL(ModelTemplate):
-     arch = "sdxl"
-     shape_fix = True
-     keys_detect = [
-         ("down_blocks.0.downsamplers.0.conv.weight", "add_embedding.linear_1.weight",),
-         (
-             "input_blocks.3.0.op.weight", "input_blocks.6.0.op.weight",
-             "output_blocks.2.2.conv.weight", "output_blocks.5.2.conv.weight",
-         ), # Non-diffusers
-         ("label_emb.0.0.weight",),
-     ]
-
- class ModelSD1(ModelTemplate):
-     arch = "sd1"
-     shape_fix = True
-     keys_detect = [
-         ("down_blocks.0.downsamplers.0.conv.weight",),
-         (
-             "input_blocks.3.0.op.weight", "input_blocks.6.0.op.weight", "input_blocks.9.0.op.weight",
-             "output_blocks.2.1.conv.weight", "output_blocks.5.2.conv.weight", "output_blocks.8.2.conv.weight"
-         ), # Non-diffusers
-     ]
-
- class ModelClipG(ModelTemplate):
-     arch = "clip_g"
-     keys_detect = [
-         ("conditioner.embedders.1.model.ln_final.bias",), # Final layer normalization bias
-         (
-             "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight", # Attention input projection weight
-             "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight", # Attention input projection weight for another block
-         ),
-     ]
-     keys_banned = [] # Add any banned keys if necessary
-
- # The architectures are checked in order and the first successful match terminates the search.
- arch_list = [ModelFlux, ModelSD3, ModelAura, ModelLTXV, ModelSDXL, ModelSD1, ModelClipG]
-
- def is_model_arch(model, state_dict):
-     # check if model is correct
-     matched = False
-     invalid = False
-     for match_list in model.keys_detect:
-         print(f"Checking match list: {match_list}")
-         if all(key in state_dict for key in match_list):
-             print(f"Match found for {match_list}")
-             matched = True
-             invalid = any(key in state_dict for key in model.keys_banned)
-             break
-     assert not invalid, "Model architecture not allowed for conversion! (i.e. reference VS diffusers format)"
-     return matched
-
- def detect_arch(state_dict):
-     model_arch = None
-     for arch in arch_list:
-         if is_model_arch(arch, state_dict):
-             model_arch = arch
-             break
-     assert model_arch is not None, "Unknown model architecture!"
-     return model_arch
-
- def parse_args():
-     parser = argparse.ArgumentParser(description="Generate F16 GGUF files from single UNET")
-     parser.add_argument("--src", required=True, help="Source model ckpt file.")
-     parser.add_argument("--dst", help="Output unet gguf file.")
-     args = parser.parse_args()
-
-     if not os.path.isfile(args.src):
-         parser.error("No input provided!")
-
-     return args
-
- def load_state_dict(path):
-     if any(path.endswith(x) for x in [".ckpt", ".pt", ".bin", ".pth"]):
-         state_dict = torch.load(path, map_location="cpu", weights_only=True)
-         state_dict = state_dict.get("model", state_dict)
-     else:
-         state_dict = load_file(path)
-
-     # only keep unet with no prefix!
-     prefix = None
-     for pfx in ["model.diffusion_model.", "model."]:
-         if any([x.startswith(pfx) for x in state_dict.keys()]):
-             prefix = pfx
-             break
-
-     sd = {}
-     for k, v in state_dict.items():
-         if prefix and prefix not in k:
-             continue
-         if prefix:
-             k = k.replace(prefix, "")
-         sd[k] = v
-
-     return sd
-
- def load_model(path):
-     state_dict = load_state_dict(path)
-     model_arch = detect_arch(state_dict)
-     print(f"* Architecture detected from input: {model_arch.arch}")
-     writer = gguf.GGUFWriter(path=None, arch=model_arch.arch)
-     return (writer, state_dict, model_arch)
-
- def handle_tensors(args, writer, state_dict, model_arch):
-     name_lengths = tuple(sorted(
-         ((key, len(key)) for key in state_dict.keys()),
-         key=lambda item: item[1],
-         reverse=True,
-     ))
-     if not name_lengths:
-         return
-     max_name_len = name_lengths[0][1]
-     if max_name_len > MAX_TENSOR_NAME_LENGTH:
-         bad_list = ", ".join(f"{key!r} ({namelen})" for key, namelen in name_lengths if namelen > MAX_TENSOR_NAME_LENGTH)
-         raise ValueError(f"Can only handle tensor names up to {MAX_TENSOR_NAME_LENGTH} characters. Tensors exceeding the limit: {bad_list}")
-     for key, data in tqdm(state_dict.items()):
-         old_dtype = data.dtype
-
-         if data.dtype == torch.bfloat16:
-             data = data.to(torch.float32).numpy()
-         # this is so we don't break torch 2.0.X
-         elif data.dtype in [getattr(torch, "float8_e4m3fn", "_invalid"), getattr(torch, "float8_e5m2", "_invalid")]:
-             data = data.to(torch.float16).numpy()
-         else:
-             data = data.numpy()
-
-         n_dims = len(data.shape)
-         data_shape = data.shape
-         data_qtype = getattr(
-             gguf.GGMLQuantizationType,
-             "BF16" if old_dtype == torch.bfloat16 else "F16"
-         )
-
-         # get number of parameters (AKA elements) in this tensor
-         n_params = 1
-         for dim_size in data_shape:
-             n_params *= dim_size
-
-         # keys to keep as max precision
-         blacklist = {
-             "time_embedding.",
-             "add_embedding.",
-             "time_in.",
-             "txt_in.",
-             "vector_in.",
-             "img_in.",
-             "guidance_in.",
-             "final_layer.",
-         }
-
-         if old_dtype in (torch.float32, torch.bfloat16):
-             if n_dims == 1:
-                 # one-dimensional tensors should be kept in F32
-                 # also speeds up inference due to not dequantizing
-                 data_qtype = gguf.GGMLQuantizationType.F32
-
-             elif n_params <= QUANTIZATION_THRESHOLD:
-                 # very small tensors
-                 data_qtype = gguf.GGMLQuantizationType.F32
-
-             elif ".weight" in key and any(x in key for x in blacklist):
-                 data_qtype = gguf.GGMLQuantizationType.F32
-
-         if (model_arch.shape_fix # NEVER reshape for models such as flux
-             and n_dims > 1 # Skip one-dimensional tensors
-             and n_params >= REARRANGE_THRESHOLD # Only rearrange tensors meeting the size requirement
-             and (n_params / 256).is_integer() # Rearranging only makes sense if total elements is divisible by 256
-             and not (data.shape[-1] / 256).is_integer() # Only need to rearrange if the last dimension is not divisible by 256
-         ):
-             orig_shape = data.shape
-             data = data.reshape(n_params // 256, 256)
-             writer.add_array(f"comfy.gguf.orig_shape.{key}", tuple(int(dim) for dim in orig_shape))
-
-         try:
-             data = gguf.quants.quantize(data, data_qtype)
-         except (AttributeError, gguf.QuantError) as e:
-             tqdm.write(f"falling back to F16: {e}")
-             data_qtype = gguf.GGMLQuantizationType.F16
-             data = gguf.quants.quantize(data, data_qtype)
-
-         new_name = key # do we need to rename?
-
-         shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}"
-         tqdm.write(f"{f'%-{max_name_len + 4}s' % f'{new_name}'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
-
-         writer.add_tensor(new_name, data, raw_dtype=data_qtype)
-
- if __name__ == "__main__":
-     args = parse_args()
-     path = args.src
-     writer, state_dict, model_arch = load_model(path)
-
-     writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
-     if next(iter(state_dict.values())).dtype == torch.bfloat16:
-         out_path = f"{os.path.splitext(path)[0]}-BF16.gguf"
-         writer.add_file_type(gguf.LlamaFileType.MOSTLY_BF16)
-     else:
-         out_path = f"{os.path.splitext(path)[0]}-F16.gguf"
-         writer.add_file_type(gguf.LlamaFileType.MOSTLY_F16)
-
-     out_path = args.dst or out_path
-     if os.path.isfile(out_path):
-         input("Output exists enter to continue or ctrl+c to abort!")
-
-     handle_tensors(path, writer, state_dict, model_arch)
-     writer.write_header_to_file(path=out_path)
-     writer.write_kv_data_to_file()
-     writer.write_tensors_to_file(progress=True)
-     writer.close()
 
+ # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
+ import os
+ import torch
+ import gguf # This needs to be the llama.cpp one specifically!
+ import argparse
+ from tqdm import tqdm
+
+ from safetensors.torch import load_file
+
+ QUANTIZATION_THRESHOLD = 1024
+ REARRANGE_THRESHOLD = 512
+ MAX_TENSOR_NAME_LENGTH = 127
+
+ class ModelTemplate:
+     arch = "invalid" # string describing architecture
+     shape_fix = False # whether to reshape tensors
+     keys_detect = [] # list of lists to match in state dict
+     keys_banned = [] # list of keys that should mark model as invalid for conversion
+
+ class ModelFlux(ModelTemplate):
+     arch = "flux"
+     keys_detect = [
+         ("transformer_blocks.0.attn.norm_added_k.weight",),
+         ("double_blocks.0.img_attn.proj.weight",),
+     ]
+     keys_banned = ["transformer_blocks.0.attn.norm_added_k.weight",]
+
+ class ModelSD3(ModelTemplate):
+     arch = "sd3"
+     keys_detect = [
+         ("transformer_blocks.0.attn.add_q_proj.weight",),
+         ("joint_blocks.0.x_block.attn.qkv.weight",),
+     ]
+     keys_banned = ["transformer_blocks.0.attn.add_q_proj.weight",]
+
+ class ModelAura(ModelTemplate):
+     arch = "aura"
+     keys_detect = [
+         ("double_layers.3.modX.1.weight",),
+         ("joint_transformer_blocks.3.ff_context.out_projection.weight",),
+     ]
+     keys_banned = ["joint_transformer_blocks.3.ff_context.out_projection.weight",]
+
+ class ModelLTXV(ModelTemplate):
+     arch = "ltxv"
+     keys_detect = [
+         (
+             "adaln_single.emb.timestep_embedder.linear_2.weight",
+             "transformer_blocks.27.scale_shift_table",
+             "caption_projection.linear_2.weight",
+         )
+     ]
+
+ class ModelSDXL(ModelTemplate):
+     arch = "sdxl"
+     shape_fix = True
+     keys_detect = [
+         ("down_blocks.0.downsamplers.0.conv.weight", "add_embedding.linear_1.weight",),
+         (
+             "input_blocks.3.0.op.weight", "input_blocks.6.0.op.weight",
+             "output_blocks.2.2.conv.weight", "output_blocks.5.2.conv.weight",
+         ), # Non-diffusers
+         ("label_emb.0.0.weight",),
+     ]
+
+ class ModelSD1(ModelTemplate):
+     arch = "sd1"
+     shape_fix = True
+     keys_detect = [
+         ("down_blocks.0.downsamplers.0.conv.weight",),
+         (
+             "input_blocks.3.0.op.weight", "input_blocks.6.0.op.weight", "input_blocks.9.0.op.weight",
+             "output_blocks.2.1.conv.weight", "output_blocks.5.2.conv.weight", "output_blocks.8.2.conv.weight"
+         ), # Non-diffusers
+     ]
+
+ class ModelClipG(ModelTemplate):
+     arch = "clip_g"
+     keys_detect = [
+         ("text_model.encoder.layers.22.ln_final.bias",), # Example key, adjust as needed
+         (
+             "text_model.encoder.layers.0.attn.in_proj_weight", # Example key, adjust as needed
+             "text_model.encoder.layers.1.attn.in_proj_weight", # Example key, adjust as needed
+         ),
+     ]
+     keys_banned = [] # Add any banned keys if necessary
+
+ # The architectures are checked in order and the first successful match terminates the search.
+ arch_list = [ModelFlux, ModelSD3, ModelAura, ModelLTXV, ModelSDXL, ModelSD1, ModelClipG]
+
+ def is_model_arch(model, state_dict):
+     # check if model is correct
+     matched = False
+     invalid = False
+     for match_list in model.keys_detect:
+         print(f"Checking match list: {match_list}")
+         if all(key in state_dict for key in match_list):
+             print(f"Match found for {match_list}")
+             matched = True
+             invalid = any(key in state_dict for key in model.keys_banned)
+             break
+     assert not invalid, "Model architecture not allowed for conversion! (i.e. reference VS diffusers format)"
+     return matched
+
+ def detect_arch(state_dict):
+     model_arch = None
+     for arch in arch_list:
+         if is_model_arch(arch, state_dict):
+             model_arch = arch
+             break
+     assert model_arch is not None, "Unknown model architecture!"
+     return model_arch
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Generate F16 GGUF files from single UNET")
+     parser.add_argument("--src", required=True, help="Source model ckpt file.")
+     parser.add_argument("--dst", help="Output unet gguf file.")
+     args = parser.parse_args()
+
+     if not os.path.isfile(args.src):
+         parser.error("No input provided!")
+
+     return args
+
+ def load_state_dict(path):
+     if any(path.endswith(x) for x in [".ckpt", ".pt", ".bin", ".pth"]):
+         state_dict = torch.load(path, map_location="cpu", weights_only=True)
+         state_dict = state_dict.get("model", state_dict)
+     else:
+         state_dict = load_file(path)
+
+     # only keep unet with no prefix!
+     prefix = None
+     for pfx in ["model.diffusion_model.", "model."]:
+         if any([x.startswith(pfx) for x in state_dict.keys()]):
+             prefix = pfx
+             break
+
+     sd = {}
+     for k, v in state_dict.items():
+         if prefix and prefix not in k:
+             continue
+         if prefix:
+             k = k.replace(prefix, "")
+         sd[k] = v
+
+     return sd
+
+ def load_model(path):
+     state_dict = load_state_dict(path)
+     model_arch = detect_arch(state_dict)
+     print(f"* Architecture detected from input: {model_arch.arch}")
+     writer = gguf.GGUFWriter(path=None, arch=model_arch.arch)
+     return (writer, state_dict, model_arch)
+
+ def handle_tensors(args, writer, state_dict, model_arch):
+     name_lengths = tuple(sorted(
+         ((key, len(key)) for key in state_dict.keys()),
+         key=lambda item: item[1],
+         reverse=True,
+     ))
+     if not name_lengths:
+         return
+     max_name_len = name_lengths[0][1]
+     if max_name_len > MAX_TENSOR_NAME_LENGTH:
+         bad_list = ", ".join(f"{key!r} ({namelen})" for key, namelen in name_lengths if namelen > MAX_TENSOR_NAME_LENGTH)
+         raise ValueError(f"Can only handle tensor names up to {MAX_TENSOR_NAME_LENGTH} characters. Tensors exceeding the limit: {bad_list}")
+     for key, data in tqdm(state_dict.items()):
+         old_dtype = data.dtype
+
+         if data.dtype == torch.bfloat16:
+             data = data.to(torch.float32).numpy()
+         # this is so we don't break torch 2.0.X
+         elif data.dtype in [getattr(torch, "float8_e4m3fn", "_invalid"), getattr(torch, "float8_e5m2", "_invalid")]:
+             data = data.to(torch.float16).numpy()
+         else:
+             data = data.numpy()
+
+         n_dims = len(data.shape)
+         data_shape = data.shape
+         data_qtype = getattr(
+             gguf.GGMLQuantizationType,
+             "BF16" if old_dtype == torch.bfloat16 else "F16"
+         )
+
+         # get number of parameters (AKA elements) in this tensor
+         n_params = 1
+         for dim_size in data_shape:
+             n_params *= dim_size
+
+         # keys to keep as max precision
+         blacklist = {
+             "time_embedding.",
+             "add_embedding.",
+             "time_in.",
+             "txt_in.",
+             "vector_in.",
+             "img_in.",
+             "guidance_in.",
+             "final_layer.",
+         }
+
+         if old_dtype in (torch.float32, torch.bfloat16):
+             if n_dims == 1:
+                 # one-dimensional tensors should be kept in F32
+                 # also speeds up inference due to not dequantizing
+                 data_qtype = gguf.GGMLQuantizationType.F32
+
+             elif n_params <= QUANTIZATION_THRESHOLD:
+                 # very small tensors
+                 data_qtype = gguf.GGMLQuantizationType.F32
+
+             elif ".weight" in key and any(x in key for x in blacklist):
+                 data_qtype = gguf.GGMLQuantizationType.F32
+
+         if (model_arch.shape_fix # NEVER reshape for models such as flux
+             and n_dims > 1 # Skip one-dimensional tensors
+             and n_params >= REARRANGE_THRESHOLD # Only rearrange tensors meeting the size requirement
+             and (n_params / 256).is_integer() # Rearranging only makes sense if total elements is divisible by 256
+             and not (data.shape[-1] / 256).is_integer() # Only need to rearrange if the last dimension is not divisible by 256
+         ):
+             orig_shape = data.shape
+             data = data.reshape(n_params // 256, 256)
+             writer.add_array(f"comfy.gguf.orig_shape.{key}", tuple(int(dim) for dim in orig_shape))
+
+         try:
+             data = gguf.quants.quantize(data, data_qtype)
+         except (AttributeError, gguf.QuantError) as e:
+             tqdm.write(f"falling back to F16: {e}")
+             data_qtype = gguf.GGMLQuantizationType.F16
+             data = gguf.quants.quantize(data, data_qtype)
+
+         new_name = key # do we need to rename?
+
+         shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}"
+         tqdm.write(f"{f'%-{max_name_len + 4}s' % f'{new_name}'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
+
+         writer.add_tensor(new_name, data, raw_dtype=data_qtype)
+
+ if __name__ == "__main__":
+     args = parse_args()
+     path = args.src
+     writer, state_dict, model_arch = load_model(path)
+
+     writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
+     if next(iter(state_dict.values())).dtype == torch.bfloat16:
+         out_path = f"{os.path.splitext(path)[0]}-BF16.gguf"
+         writer.add_file_type(gguf.LlamaFileType.MOSTLY_BF16)
+     else:
+         out_path = f"{os.path.splitext(path)[0]}-F16.gguf"
+         writer.add_file_type(gguf.LlamaFileType.MOSTLY_F16)
+
+     out_path = args.dst or out_path
+     if os.path.isfile(out_path):
+         input("Output exists enter to continue or ctrl+c to abort!")
+
+     handle_tensors(path, writer, state_dict, model_arch)
+     writer.write_header_to_file(path=out_path)
+     writer.write_kv_data_to_file()
+     writer.write_tensors_to_file(progress=True)
+     writer.close()
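
For a quick sanity check of the updated clip_g detection keys, here is a minimal sketch (hypothetical: the dummy tensor names and shapes are illustrative only, and it assumes convert_mod.py is importable from the working directory). It builds a bare state dict carrying the new keys and confirms that detect_arch resolves it to clip_g; the templates ahead of it in arch_list (flux, sd3, aura, ltxv, sdxl, sd1) fail to match because none of their key tuples are present.

# Hypothetical smoke test for the updated clip_g detection keys.
# Assumes convert_mod.py is importable; the tensor values are dummies,
# since detection only inspects the key names, not the data.
import torch
from convert_mod import detect_arch

dummy_sd = {
    "text_model.encoder.layers.22.ln_final.bias": torch.zeros(1280),
    "text_model.encoder.layers.0.attn.in_proj_weight": torch.zeros(3840, 1280),
    "text_model.encoder.layers.1.attn.in_proj_weight": torch.zeros(3840, 1280),
}

print(detect_arch(dummy_sd).arch)  # -> "clip_g"

A full conversion is driven from the command line, e.g. python convert_mod.py --src model.safetensors, which writes model-F16.gguf (or model-BF16.gguf when the first tensor in the state dict is bfloat16) next to the input unless --dst overrides the output path.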