Old-Fisherman committed
Commit e1ebdf9 · verified · 1 Parent(s): adfffe7

Rename convert_mod.py to convert_g.py

Files changed (2)
  1. convert_g.py +141 -0
  2. convert_mod.py +0 -261
convert_g.py ADDED
@@ -0,0 +1,141 @@
+ #!/usr/bin/env python3
+ import os
+ import torch
+ import gguf  # llama.cpp's specific gguf python module
+ import argparse
+ from tqdm import tqdm
+ from safetensors.torch import load_file
+
+ # Configuration constants
+ QUANTIZATION_THRESHOLD = 1024
+ MAX_TENSOR_NAME_LENGTH = 127
+
+ # Base model template class
+ class ModelTemplate:
+     arch = "invalid"
+     shape_fix = False
+     keys_detect = []
+     keys_banned = []
+
+ # Specific template for clip_g using ComfyUI standard keys
+ class ModelClipG(ModelTemplate):
+     arch = "clip_g"
+     shape_fix = False  # No rearrangement for text encoder models
+     keys_detect = [
+         ("logit_scale",),
+         ("text_model.embeddings.position_embedding.weight",),
+         ("text_model.encoder.layers.0.self_attn.in_proj_weight",),
+     ]
+     keys_banned = []
+
+ # Only clip_g in this conversion script
+ arch_list = [ModelClipG]
+
+ def is_model_arch(model, state_dict):
+     for key_tuple in model.keys_detect:
+         if all(key in state_dict for key in key_tuple):
+             # Reject checkpoints that contain banned keys
+             if any(key in state_dict for key in model.keys_banned):
+                 raise ValueError("Model architecture not allowed for conversion!")
+             return True
+     return False
+
+ def detect_arch(state_dict):
+     for model in arch_list:
+         if is_model_arch(model, state_dict):
+             return model
+     raise ValueError("Unknown model architecture!")
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Convert a clip_g model (ComfyUI standard) to GGUF")
+     parser.add_argument("--src", required=True, help="Source model file (.safetensors, .pt, etc.)")
+     parser.add_argument("--dst", help="Output GGUF file")
+     args = parser.parse_args()
+     if not os.path.isfile(args.src):
+         parser.error("No input provided!")
+     return args
+
+ def load_state_dict(path):
+     if any(path.endswith(ext) for ext in [".ckpt", ".pt", ".bin", ".pth"]):
+         state_dict = torch.load(path, map_location="cpu", weights_only=True)
+         state_dict = state_dict.get("model", state_dict)
+     else:
+         state_dict = load_file(path)
+
+     # Strip a known prefix from the keys if one is present.
+     prefix = None
+     for pfx in ["model.diffusion_model.", "model."]:
+         if any(k.startswith(pfx) for k in state_dict.keys()):
+             prefix = pfx
+             break
+     new_state = {}
+     for k, v in state_dict.items():
+         if prefix:
+             if not k.startswith(prefix):
+                 continue
+             k = k[len(prefix):]  # strip only the leading prefix, not later occurrences
+         new_state[k] = v
+     return new_state
+
+ def load_model(path):
+     state_dict = load_state_dict(path)
+     model_arch = detect_arch(state_dict)
+     print(f"Detected architecture: {model_arch.arch}")
+     writer = gguf.GGUFWriter(path=None, arch=model_arch.arch)
+     return writer, state_dict, model_arch
+
+ def handle_tensors(writer, state_dict, model_arch):
+     # Check that all tensor names are within the allowed length.
+     for key in state_dict.keys():
+         if len(key) > MAX_TENSOR_NAME_LENGTH:
+             raise ValueError(f"Tensor name {key} exceeds maximum length {MAX_TENSOR_NAME_LENGTH}")
+
+     for key, tensor in tqdm(state_dict.items(), desc="Processing tensors"):
+         if isinstance(tensor, torch.Tensor):
+             old_dtype = tensor.dtype
+             # numpy has no bfloat16, so upcast before converting
+             if tensor.dtype == torch.bfloat16:
+                 tensor = tensor.to(torch.float32)
+             data = tensor.detach().cpu().numpy()
+         else:
+             old_dtype = None
+             data = tensor
+
+         # Keep norms, biases and embeddings in full precision; store the rest
+         # as BF16/F16 to match the file type chosen in main().
+         key_lower = key.lower()
+         if data.ndim == 1 or "bias" in key_lower or "layer_norm" in key_lower or "ln_" in key_lower:
+             data_qtype = gguf.GGMLQuantizationType.F32
+         elif "embeddings" in key_lower:
+             data_qtype = gguf.GGMLQuantizationType.F32
+         elif old_dtype == torch.bfloat16:
+             data_qtype = gguf.GGMLQuantizationType.BF16
+         else:
+             data_qtype = gguf.GGMLQuantizationType.F16
+
+         # Very small tensors are not worth converting down
+         if data.size <= QUANTIZATION_THRESHOLD:
+             data_qtype = gguf.GGMLQuantizationType.F32
+
+         try:
+             quantized = gguf.quants.quantize(data, data_qtype)
+         except Exception as e:
+             tqdm.write(f"Quantization failed for {key} with error {e}; falling back to F16")
+             data_qtype = gguf.GGMLQuantizationType.F16
+             quantized = gguf.quants.quantize(data, data_qtype)
+
+         writer.add_tensor(key, quantized, raw_dtype=data_qtype)
+         tqdm.write(f"Processed {key}: {old_dtype or data.dtype} -> {data_qtype.name}, shape = {data.shape}")
+
+ def main():
+     args = parse_args()
+     writer, state_dict, model_arch = load_model(args.src)
+
+     writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
+
+     # Determine output name and file type based on the first tensor's dtype.
+     first_tensor = next(iter(state_dict.values()))
+     if first_tensor.dtype == torch.bfloat16:
+         out_path = args.dst or os.path.splitext(args.src)[0] + "-BF16.gguf"
+         writer.add_file_type(gguf.LlamaFileType.MOSTLY_BF16)
+     else:
+         out_path = args.dst or os.path.splitext(args.src)[0] + "-F16.gguf"
+         writer.add_file_type(gguf.LlamaFileType.MOSTLY_F16)
+
+     if os.path.isfile(out_path):
+         input("Output exists. Press enter to continue or Ctrl+C to abort")
+
+     handle_tensors(writer, state_dict, model_arch)
+     writer.write_header_to_file(path=out_path)
+     writer.write_kv_data_to_file()
+     writer.write_tensors_to_file(progress=True)
+     writer.close()
+
+ if __name__ == "__main__":
+     main()
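A minimal end-to-end check of the new script might look like the sketch below. It is an assumption-laden illustration, not part of the commit: `clip_g.safetensors` is a hypothetical local ComfyUI-style clip_g checkpoint, and the reader comes from the same llama.cpp gguf python module the converter imports. It runs the converter, then reopens the output with `gguf.GGUFReader` to confirm tensor names and quantization types.

    # Sketch only: file names here are illustrative assumptions.
    import subprocess
    from gguf import GGUFReader  # llama.cpp's gguf python module

    # Run the converter on a hypothetical clip_g checkpoint.
    subprocess.run(
        ["python3", "convert_g.py", "--src", "clip_g.safetensors", "--dst", "clip_g-F16.gguf"],
        check=True,
    )

    # Re-open the result and list what was written.
    reader = GGUFReader("clip_g-F16.gguf")
    for tensor in reader.tensors:
        # Each entry exposes the tensor name, shape, and GGML quantization type.
        print(tensor.name, tensor.shape, tensor.tensor_type.name)

If the conversion behaved as intended, norms, biases, and embeddings should report F32, while the large projection weights report F16 (or BF16 for bfloat16 sources).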
convert_mod.py DELETED
@@ -1,261 +0,0 @@
- # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
- import os
- import torch
- import gguf # This needs to be the llama.cpp one specifically!
- import argparse
- from tqdm import tqdm
-
- from safetensors.torch import load_file
-
- QUANTIZATION_THRESHOLD = 1024
- REARRANGE_THRESHOLD = 512
- MAX_TENSOR_NAME_LENGTH = 127
-
- class ModelTemplate:
-     arch = "invalid" # string describing architecture
-     shape_fix = False # whether to reshape tensors
-     keys_detect = [] # list of lists to match in state dict
-     keys_banned = [] # list of keys that should mark model as invalid for conversion
-
- class ModelFlux(ModelTemplate):
-     arch = "flux"
-     keys_detect = [
-         ("transformer_blocks.0.attn.norm_added_k.weight",),
-         ("double_blocks.0.img_attn.proj.weight",),
-     ]
-     keys_banned = ["transformer_blocks.0.attn.norm_added_k.weight",]
-
- class ModelSD3(ModelTemplate):
-     arch = "sd3"
-     keys_detect = [
-         ("transformer_blocks.0.attn.add_q_proj.weight",),
-         ("joint_blocks.0.x_block.attn.qkv.weight",),
-     ]
-     keys_banned = ["transformer_blocks.0.attn.add_q_proj.weight",]
-
- class ModelAura(ModelTemplate):
-     arch = "aura"
-     keys_detect = [
-         ("double_layers.3.modX.1.weight",),
-         ("joint_transformer_blocks.3.ff_context.out_projection.weight",),
-     ]
-     keys_banned = ["joint_transformer_blocks.3.ff_context.out_projection.weight",]
-
- class ModelLTXV(ModelTemplate):
-     arch = "ltxv"
-     keys_detect = [
-         (
-             "adaln_single.emb.timestep_embedder.linear_2.weight",
-             "transformer_blocks.27.scale_shift_table",
-             "caption_projection.linear_2.weight",
-         )
-     ]
-
- class ModelSDXL(ModelTemplate):
-     arch = "sdxl"
-     shape_fix = True
-     keys_detect = [
-         ("down_blocks.0.downsamplers.0.conv.weight", "add_embedding.linear_1.weight",),
-         (
-             "input_blocks.3.0.op.weight", "input_blocks.6.0.op.weight",
-             "output_blocks.2.2.conv.weight", "output_blocks.5.2.conv.weight",
-         ), # Non-diffusers
-         ("label_emb.0.0.weight",),
-     ]
-
- class ModelSD1(ModelTemplate):
-     arch = "sd1"
-     shape_fix = True
-     keys_detect = [
-         ("down_blocks.0.downsamplers.0.conv.weight",),
-         (
-             "input_blocks.3.0.op.weight", "input_blocks.6.0.op.weight", "input_blocks.9.0.op.weight",
-             "output_blocks.2.1.conv.weight", "output_blocks.5.2.conv.weight", "output_blocks.8.2.conv.weight"
-         ), # Non-diffusers
-     ]
-
- class ModelClipG(ModelTemplate):
-     arch = "clip_g"
-     keys_detect = [
-         ("text_model.final_layer_norm.bias",), # Example key, adjust as needed
-         (
-             "text_model.encoder.layers.0.self_attn.k_proj.weight", # Example key, adjust as needed
-             "text_model.encoder.layers.1.self_attn.q_proj.bias", # Example key, adjust as needed
-         ),
-     ]
-     keys_banned = [] # Add any banned keys if necessary
-
- # The architectures are checked in order and the first successful match terminates the search.
- arch_list = [ModelFlux, ModelSD3, ModelAura, ModelLTXV, ModelSDXL, ModelSD1, ModelClipG]
-
- def is_model_arch(model, state_dict):
-     # check if model is correct
-     matched = False
-     invalid = False
-     for match_list in model.keys_detect:
-         print(f"Checking match list: {match_list}")
-         if all(key in state_dict for key in match_list):
-             print(f"Match found for {match_list}")
-             matched = True
-             invalid = any(key in state_dict for key in model.keys_banned)
-             break
-     assert not invalid, "Model architecture not allowed for conversion! (i.e. reference VS diffusers format)"
-     return matched
-
- def detect_arch(state_dict):
-     model_arch = None
-     for arch in arch_list:
-         if is_model_arch(arch, state_dict):
-             model_arch = arch
-             break
-     assert model_arch is not None, "Unknown model architecture!"
-     return model_arch
-
- def parse_args():
-     parser = argparse.ArgumentParser(description="Generate F16 GGUF files from single UNET")
-     parser.add_argument("--src", required=True, help="Source model ckpt file.")
-     parser.add_argument("--dst", help="Output unet gguf file.")
-     args = parser.parse_args()
-
-     if not os.path.isfile(args.src):
-         parser.error("No input provided!")
-
-     return args
-
- def load_state_dict(path):
-     if any(path.endswith(x) for x in [".ckpt", ".pt", ".bin", ".pth"]):
-         state_dict = torch.load(path, map_location="cpu", weights_only=True)
-         state_dict = state_dict.get("model", state_dict)
-     else:
-         state_dict = load_file(path)
-
-     # only keep unet with no prefix!
-     prefix = None
-     for pfx in ["model.diffusion_model.", "model."]:
-         if any([x.startswith(pfx) for x in state_dict.keys()]):
-             prefix = pfx
-             break
-
-     sd = {}
-     for k, v in state_dict.items():
-         if prefix and prefix not in k:
-             continue
-         if prefix:
-             k = k.replace(prefix, "")
-         sd[k] = v
-
-     return sd
-
- def load_model(path):
-     state_dict = load_state_dict(path)
-     model_arch = detect_arch(state_dict)
-     print(f"* Architecture detected from input: {model_arch.arch}")
-     writer = gguf.GGUFWriter(path=None, arch=model_arch.arch)
-     return (writer, state_dict, model_arch)
-
- def handle_tensors(args, writer, state_dict, model_arch):
-     name_lengths = tuple(sorted(
-         ((key, len(key)) for key in state_dict.keys()),
-         key=lambda item: item[1],
-         reverse=True,
-     ))
-     if not name_lengths:
-         return
-     max_name_len = name_lengths[0][1]
-     if max_name_len > MAX_TENSOR_NAME_LENGTH:
-         bad_list = ", ".join(f"{key!r} ({namelen})" for key, namelen in name_lengths if namelen > MAX_TENSOR_NAME_LENGTH)
-         raise ValueError(f"Can only handle tensor names up to {MAX_TENSOR_NAME_LENGTH} characters. Tensors exceeding the limit: {bad_list}")
-     for key, data in tqdm(state_dict.items()):
-         old_dtype = data.dtype
-
-         if data.dtype == torch.bfloat16:
-             data = data.to(torch.float32).numpy()
-         # this is so we don't break torch 2.0.X
-         elif data.dtype in [getattr(torch, "float8_e4m3fn", "_invalid"), getattr(torch, "float8_e5m2", "_invalid")]:
-             data = data.to(torch.float16).numpy()
-         else:
-             data = data.numpy()
-
-         n_dims = len(data.shape)
-         data_shape = data.shape
-         data_qtype = getattr(
-             gguf.GGMLQuantizationType,
-             "BF16" if old_dtype == torch.bfloat16 else "F16"
-         )
-
-         # get number of parameters (AKA elements) in this tensor
-         n_params = 1
-         for dim_size in data_shape:
-             n_params *= dim_size
-
-         # keys to keep as max precision
-         blacklist = {
-             "time_embedding.",
-             "add_embedding.",
-             "time_in.",
-             "txt_in.",
-             "vector_in.",
-             "img_in.",
-             "guidance_in.",
-             "final_layer.",
-         }
-
-         if old_dtype in (torch.float32, torch.bfloat16):
-             if n_dims == 1:
-                 # one-dimensional tensors should be kept in F32
-                 # also speeds up inference due to not dequantizing
-                 data_qtype = gguf.GGMLQuantizationType.F32
-
-             elif n_params <= QUANTIZATION_THRESHOLD:
-                 # very small tensors
-                 data_qtype = gguf.GGMLQuantizationType.F32
-
-             elif ".weight" in key and any(x in key for x in blacklist):
-                 data_qtype = gguf.GGMLQuantizationType.F32
-
-         if (model_arch.shape_fix # NEVER reshape for models such as flux
-             and n_dims > 1 # Skip one-dimensional tensors
-             and n_params >= REARRANGE_THRESHOLD # Only rearrange tensors meeting the size requirement
-             and (n_params / 256).is_integer() # Rearranging only makes sense if total elements is divisible by 256
-             and not (data.shape[-1] / 256).is_integer() # Only need to rearrange if the last dimension is not divisible by 256
-         ):
-             orig_shape = data.shape
-             data = data.reshape(n_params // 256, 256)
-             writer.add_array(f"comfy.gguf.orig_shape.{key}", tuple(int(dim) for dim in orig_shape))
-
-         try:
-             data = gguf.quants.quantize(data, data_qtype)
-         except (AttributeError, gguf.QuantError) as e:
-             tqdm.write(f"falling back to F16: {e}")
-             data_qtype = gguf.GGMLQuantizationType.F16
-             data = gguf.quants.quantize(data, data_qtype)
-
-         new_name = key # do we need to rename?
-
-         shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}"
-         tqdm.write(f"{f'%-{max_name_len + 4}s' % f'{new_name}'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
-
-         writer.add_tensor(new_name, data, raw_dtype=data_qtype)
-
- if __name__ == "__main__":
-     args = parse_args()
-     path = args.src
-     writer, state_dict, model_arch = load_model(path)
-
-     writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
-     if next(iter(state_dict.values())).dtype == torch.bfloat16:
-         out_path = f"{os.path.splitext(path)[0]}-BF16.gguf"
-         writer.add_file_type(gguf.LlamaFileType.MOSTLY_BF16)
-     else:
-         out_path = f"{os.path.splitext(path)[0]}-F16.gguf"
-         writer.add_file_type(gguf.LlamaFileType.MOSTLY_F16)
-
-     out_path = args.dst or out_path
-     if os.path.isfile(out_path):
-         input("Output exists enter to continue or ctrl+c to abort!")
-
-     handle_tensors(path, writer, state_dict, model_arch)
-     writer.write_header_to_file(path=out_path)
-     writer.write_kv_data_to_file()
-     writer.write_tensors_to_file(progress=True)
-     writer.close()
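One detail of the deleted script that the clip_g converter intentionally drops is the shape_fix rearrangement: GGML block quantization works on 256-element rows, so UNet tensors whose total element count is divisible by 256 but whose last axis is not get flattened to (n_params // 256, 256), with the original shape stored under comfy.gguf.orig_shape.<key>. Below is a standalone sketch of just that decision, using a made-up tensor shape; it is an illustration of the deleted logic, not code from the commit.

    import numpy as np

    REARRANGE_THRESHOLD = 512  # same constant the deleted script used

    def maybe_rearrange(data: np.ndarray):
        # Mirrors the deleted shape_fix branch: flatten to 256-wide rows when
        # the total element count allows it but the last axis does not.
        n_params = data.size
        if (
            data.ndim > 1
            and n_params >= REARRANGE_THRESHOLD
            and n_params % 256 == 0
            and data.shape[-1] % 256 != 0
        ):
            return data.reshape(n_params // 256, 256), data.shape
        return data, None

    weights = np.zeros((320, 320), dtype=np.float32)  # 102400 elements, last axis 320
    flat, orig_shape = maybe_rearrange(weights)
    print(flat.shape, orig_shape)  # (400, 256) (320, 320)

Text encoder tensors like clip_g's are already 256-aligned or kept in F32, which is why the new script sets shape_fix = False and omits this path entirely.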