#!/usr/bin/env python3
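"""Convert a clip_g text encoder checkpoint (ComfyUI key layout) to GGUF.

Example invocation (file names are illustrative):
    python convert_clip_g_to_gguf.py --src clip_g.safetensors --dst clip_g-F16.gguf
"""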
import os
import torch
import gguf  # gguf Python package shipped with llama.cpp
import argparse
from tqdm import tqdm
from safetensors.torch import load_file
# Configuration constants
QUANTIZATION_THRESHOLD = 1024
MAX_TENSOR_NAME_LENGTH = 127
# Base model template class
class ModelTemplate:
    arch = "invalid"
    shape_fix = False
    keys_detect = []
    keys_banned = []
# Specific template for clip_g using ComfyUI standard keys
class ModelClipG(ModelTemplate):
    arch = "clip_g"
    shape_fix = False  # No rearrangement for text encoder models
    keys_detect = [
        ("logit_scale",),
        ("text_model.embeddings.position_embedding.weight",),
        ("text_model.encoder.layers.0.self_attn.in_proj_weight",),
    ]
    keys_banned = []
# Only clip_g in this conversion script
arch_list = [ModelClipG]
def is_model_arch(model, state_dict):
    for key_tuple in model.keys_detect:
        if all(key in state_dict for key in key_tuple):
            # Optionally check for banned keys
            if any(key in state_dict for key in model.keys_banned):
                raise ValueError("Model architecture not allowed for conversion!")
            return True
    return False
def detect_arch(state_dict):
    for model in arch_list:
        if is_model_arch(model, state_dict):
            return model
    raise ValueError("Unknown model architecture!")
def parse_args():
    parser = argparse.ArgumentParser(description="Convert clip_g model (ComfyUI standard) to GGUF")
    parser.add_argument("--src", required=True, help="Source model file (.safetensors, .pt, etc)")
    parser.add_argument("--dst", help="Output GGUF file")
    return parser.parse_args()
def load_state_dict(path):
    if any(path.endswith(ext) for ext in [".ckpt", ".pt", ".bin", ".pth"]):
        state_dict = torch.load(path, map_location="cpu", weights_only=True)
        state_dict = state_dict.get("model", state_dict)
    else:
        state_dict = load_file(path)

    # Remove unwanted prefixes if they exist.
    prefix = None
    for pfx in ["model.diffusion_model.", "model."]:
        if any(k.startswith(pfx) for k in state_dict.keys()):
            prefix = pfx
            break

    new_state = {}
    for k, v in state_dict.items():
        if prefix:
            if not k.startswith(prefix):
                continue
            # Strip only the leading prefix; str.replace would also remove
            # later occurrences inside the key (e.g. within "text_model.").
            k = k[len(prefix):]
        new_state[k] = v
    return new_state
def load_model(path):
    state_dict = load_state_dict(path)
    model_arch = detect_arch(state_dict)
    print(f"Detected architecture: {model_arch.arch}")
    writer = gguf.GGUFWriter(path=None, arch=model_arch.arch)
    return writer, state_dict, model_arch
def handle_tensors(writer, state_dict, model_arch):
    # Check that all tensor names are within the allowed length.
    for key in state_dict.keys():
        if len(key) > MAX_TENSOR_NAME_LENGTH:
            raise ValueError(f"Tensor name {key} exceeds maximum length {MAX_TENSOR_NAME_LENGTH}")

    for key, tensor in tqdm(state_dict.items(), desc="Processing tensors"):
        if isinstance(tensor, torch.Tensor):
            # numpy has no bfloat16, so upcast before conversion.
            if tensor.dtype == torch.bfloat16:
                tensor = tensor.to(torch.float32)
            data = tensor.detach().cpu().numpy()
        else:
            data = tensor

        # Determine quantization based on key name: keep 1D, bias, norm and
        # embedding tensors (and anything small) at F32, the rest at F16.
        key_lower = key.lower()
        if data.ndim == 1 or "bias" in key_lower or "layer_norm" in key_lower or "ln_" in key_lower:
            data_qtype = gguf.GGMLQuantizationType.F32
        elif "embeddings" in key_lower:
            data_qtype = gguf.GGMLQuantizationType.F32
        else:
            data_qtype = gguf.GGMLQuantizationType.F16
        if data.size <= QUANTIZATION_THRESHOLD:
            data_qtype = gguf.GGMLQuantizationType.F32

        try:
            quantized = gguf.quants.quantize(data, data_qtype)
        except Exception as e:
            tqdm.write(f"Quantization failed for {key} with error {e}; falling back to F16")
            data_qtype = gguf.GGMLQuantizationType.F16
            quantized = gguf.quants.quantize(data, data_qtype)

        writer.add_tensor(key, quantized, raw_dtype=data_qtype)
        tqdm.write(f"Processed {key}: {data.dtype} -> {data_qtype.name}, shape = {data.shape}")
def main():
    args = parse_args()
    writer, state_dict, model_arch = load_model(args.src)

    # Determine file type based on the first tensor's dtype.
    first_tensor = next(iter(state_dict.values()))
    if first_tensor.dtype == torch.bfloat16:
        out_path = args.dst or os.path.splitext(args.src)[0] + "-BF16.gguf"
        writer.add_file_type(gguf.LlamaFileType.MOSTLY_BF16)
    else:
        out_path = args.dst or os.path.splitext(args.src)[0] + "-F16.gguf"
        writer.add_file_type(gguf.LlamaFileType.MOSTLY_F16)

    if os.path.isfile(out_path):
        input("Output exists. Press enter to continue or Ctrl+C to abort")

    handle_tensors(writer, state_dict, model_arch)
    writer.write_header_to_file(path=out_path)
    writer.write_kv_data_to_file()
    writer.write_tensors_to_file(progress=True)
    writer.close()
if __name__ == "__main__":
    main()