#!/usr/bin/env python3
import os
import torch
import gguf  # gguf-py, the GGUF Python module that ships with llama.cpp
import argparse
from tqdm import tqdm
from safetensors.torch import load_file
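
# Typical invocation (file names are illustrative):
#   python convert_g.py --src clip_g.safetensors
# writes clip_g-F16.gguf (or clip_g-BF16.gguf for bfloat16 sources) next to
# the input, unless an explicit output path is given via --dst.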

# Configuration constants
QUANTIZATION_THRESHOLD = 1024  # tensors with <= this many elements stay in F32
MAX_TENSOR_NAME_LENGTH = 127   # longest tensor name this script will accept

# Base model template class
class ModelTemplate:
    arch = "invalid"  # overridden by subclasses
    shape_fix = False
    keys_detect = []
    keys_banned = []

# Specific template for clip_g using ComfyUI standard keys
class ModelClipG(ModelTemplate):
    arch = "clip_g"
    shape_fix = False  # no rearrangement for text encoder models
    keys_detect = [
        ("logit_scale",),
        ("text_model.embeddings.position_embedding.weight",),
        ("text_model.encoder.layers.0.self_attn.in_proj_weight",),
    ]
    keys_banned = []

# Only clip_g in this conversion script
arch_list = [ModelClipG]
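
# Additional encoders could be registered the same way; a sketch (the class
# and detection key below are hypothetical and would need verifying against
# a real checkpoint):
#   class ModelClipL(ModelTemplate):
#       arch = "clip_l"
#       keys_detect = [("text_model.embeddings.token_embedding.weight",)]
#   arch_list = [ModelClipG, ModelClipL]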

def is_model_arch(model, state_dict):
    for key_tuple in model.keys_detect:
        if all(key in state_dict for key in key_tuple):
            # Reject matches that also contain banned keys
            if any(key in state_dict for key in model.keys_banned):
                raise ValueError("Model architecture not allowed for conversion!")
            return True
    return False

def detect_arch(state_dict):
    for model in arch_list:
        if is_model_arch(model, state_dict):
            return model
    raise ValueError("Unknown model architecture!")

def parse_args():
    parser = argparse.ArgumentParser(description="Convert clip_g model (ComfyUI standard) to GGUF")
    parser.add_argument("--src", required=True, help="Source model file (.safetensors, .pt, etc)")
    parser.add_argument("--dst", help="Output GGUF file")
    return parser.parse_args()

def load_state_dict(path):
    if any(path.endswith(ext) for ext in [".ckpt", ".pt", ".bin", ".pth"]):
        state_dict = torch.load(path, map_location="cpu", weights_only=True)
        state_dict = state_dict.get("model", state_dict)
    else:
        state_dict = load_file(path)

    # Remove unwanted prefixes if they exist.
    prefix = None
    for pfx in ["model.diffusion_model.", "model."]:
        if any(k.startswith(pfx) for k in state_dict.keys()):
            prefix = pfx
            break

    new_state = {}
    for k, v in state_dict.items():
        if prefix:
            if not k.startswith(prefix):
                continue
            k = k[len(prefix):]  # strip only the leading prefix, not every occurrence
        new_state[k] = v
    return new_state
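
# e.g. a wrapped checkpoint key "model.logit_scale" comes back as "logit_scale";
# once a prefix is detected, keys outside it are dropped entirely.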

def load_model(path):
    state_dict = load_state_dict(path)
    model_arch = detect_arch(state_dict)
    print(f"Detected architecture: {model_arch.arch}")
    # The output path is supplied later, when the header is written in main().
    writer = gguf.GGUFWriter(path=None, arch=model_arch.arch)
    return writer, state_dict, model_arch

def handle_tensors(writer, state_dict, model_arch):
    # Check that all tensor names are within the allowed length.
    for key in state_dict.keys():
        if len(key) > MAX_TENSOR_NAME_LENGTH:
            raise ValueError(f"Tensor name {key} exceeds maximum length {MAX_TENSOR_NAME_LENGTH}")

    for key, tensor in tqdm(state_dict.items(), desc="Processing tensors"):
        if isinstance(tensor, torch.Tensor):
            data = tensor.detach().cpu().numpy()
        else:
            data = tensor

        # Pick target precision from the key name: biases, norms, embeddings,
        # and 1-D tensors stay in F32; everything else is stored as F16.
        key_lower = key.lower()
        if data.ndim == 1 or "bias" in key_lower or "layer_norm" in key_lower or "ln_" in key_lower:
            data_qtype = gguf.GGMLQuantizationType.F32
        elif "embeddings" in key_lower:
            data_qtype = gguf.GGMLQuantizationType.F32
        else:
            data_qtype = gguf.GGMLQuantizationType.F16

        # Tiny tensors are not worth converting; keep them in F32.
        if data.size <= QUANTIZATION_THRESHOLD:
            data_qtype = gguf.GGMLQuantizationType.F32

        try:
            quantized = gguf.quants.quantize(data, data_qtype)
        except Exception as e:
            tqdm.write(f"Quantization failed for {key} with error {e}; falling back to F16")
            data_qtype = gguf.GGMLQuantizationType.F16
            quantized = gguf.quants.quantize(data, data_qtype)

        writer.add_tensor(key, quantized, raw_dtype=data_qtype)
        tqdm.write(f"Processed {key}: {data.dtype} -> {data_qtype.name}, shape = {data.shape}")

def main():
    args = parse_args()
    writer, state_dict, model_arch = load_model(args.src)

    # Determine file type based on the first tensor's dtype.
    first_tensor = next(iter(state_dict.values()))
    if first_tensor.dtype == torch.bfloat16:
        out_path = args.dst or os.path.splitext(args.src)[0] + "-BF16.gguf"
        writer.add_file_type(gguf.LlamaFileType.MOSTLY_BF16)
    else:
        out_path = args.dst or os.path.splitext(args.src)[0] + "-F16.gguf"
        writer.add_file_type(gguf.LlamaFileType.MOSTLY_F16)

    if os.path.isfile(out_path):
        input("Output exists. Press enter to continue or Ctrl+C to abort")

    handle_tensors(writer, state_dict, model_arch)
    writer.write_header_to_file(path=out_path)
    writer.write_kv_data_to_file()
    writer.write_tensors_to_file(progress=True)
    writer.close()

if __name__ == "__main__":
    main()
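
# A quick sanity check of the output (a sketch; assumes the gguf package also
# exposes GGUFReader, as recent gguf-py releases do):
#   from gguf import GGUFReader
#   reader = GGUFReader("clip_g-F16.gguf")  # illustrative path
#   for t in reader.tensors:
#       print(t.name, t.tensor_type, t.shape)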