#!/usr/bin/env python3
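"""Convert a clip_g text encoder checkpoint (ComfyUI key layout) to GGUF.

Weights are written as F16, except 1-D tensors, biases, layer norms,
embeddings, and tensors with at most QUANTIZATION_THRESHOLD elements,
which are kept in F32.
"""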
import os
import torch
import gguf  # llama.cpp's specific gguf python module
import argparse
from tqdm import tqdm
from safetensors.torch import load_file

# Configuration constants
QUANTIZATION_THRESHOLD = 1024
MAX_TENSOR_NAME_LENGTH = 127

# Base model template class
class ModelTemplate:
    arch = "invalid"
    shape_fix = False
    keys_detect = []
    keys_banned = []

# Specific template for clip_g using ComfyUI standard keys
class ModelClipG(ModelTemplate):
    arch = "clip_g"
    shape_fix = False  # No rearrangement for text encoder models
    keys_detect = [
        ("logit_scale",),
        ("text_model.embeddings.position_embedding.weight",),
        ("text_model.encoder.layers.0.self_attn.in_proj_weight",),
    ]
    keys_banned = []

# Only clip_g in this conversion script
arch_list = [ModelClipG]

def is_model_arch(model, state_dict):
    for key_tuple in model.keys_detect:
        if all(key in state_dict for key in key_tuple):
            # Optionally check for banned keys
            if any(key in state_dict for key in model.keys_banned):
                raise ValueError("Model architecture not allowed for conversion!")
            return True
    return False

def detect_arch(state_dict):
    for model in arch_list:
        if is_model_arch(model, state_dict):
            return model
    raise ValueError("Unknown model architecture!")

def parse_args():
    parser = argparse.ArgumentParser(description="Convert clip_g model (ComfyUI standard) to GGUF")
    parser.add_argument("--src", required=True, help="Source model file (.safetensors, .pt, etc)")
    parser.add_argument("--dst", help="Output GGUF file")
    return parser.parse_args()

def load_state_dict(path):
    if any(path.endswith(ext) for ext in [".ckpt", ".pt", ".bin", ".pth"]):
        state_dict = torch.load(path, map_location="cpu", weights_only=True)
        state_dict = state_dict.get("model", state_dict)
    else:
        state_dict = load_file(path)
    
    # Remove unwanted prefixes if they exist.
    prefix = None
    for pfx in ["model.diffusion_model.", "model."]:
        if any(k.startswith(pfx) for k in state_dict.keys()):
            prefix = pfx
            break
    new_state = {}
    for k, v in state_dict.items():
        if prefix:
            if not k.startswith(prefix):
                continue
            # Strip only the leading prefix, not other occurrences of the substring.
            k = k[len(prefix):]
        new_state[k] = v
    return new_state

def load_model(path):
    state_dict = load_state_dict(path)
    model_arch = detect_arch(state_dict)
    print(f"Detected architecture: {model_arch.arch}")
    writer = gguf.GGUFWriter(path=None, arch=model_arch.arch)
    return writer, state_dict, model_arch

def handle_tensors(writer, state_dict, model_arch):
    # Check that all tensor names are within allowed length.
    for key in state_dict.keys():
        if len(key) > MAX_TENSOR_NAME_LENGTH:
            raise ValueError(f"Tensor name {key} exceeds maximum length {MAX_TENSOR_NAME_LENGTH}")

    for key, tensor in tqdm(state_dict.items(), desc="Processing tensors"):
        if isinstance(tensor, torch.Tensor):
            # numpy has no bfloat16 dtype, so upcast those tensors before conversion.
            if tensor.dtype == torch.bfloat16:
                tensor = tensor.to(torch.float32)
            data = tensor.detach().cpu().numpy()
        else:
            data = tensor

        # Determine quantization based on key name
        key_lower = key.lower()
        if data.ndim == 1 or "bias" in key_lower or "layer_norm" in key_lower or "ln_" in key_lower:
            data_qtype = gguf.GGMLQuantizationType.F32
        elif "embeddings" in key_lower:
            data_qtype = gguf.GGMLQuantizationType.F32
        else:
            data_qtype = gguf.GGMLQuantizationType.F16

        if data.size <= QUANTIZATION_THRESHOLD:
            data_qtype = gguf.GGMLQuantizationType.F32

        try:
            quantized = gguf.quants.quantize(data, data_qtype)
        except Exception as e:
            tqdm.write(f"Quantization failed for {key} with error {e}; falling back to F16")
            data_qtype = gguf.GGMLQuantizationType.F16
            quantized = gguf.quants.quantize(data, data_qtype)

        writer.add_tensor(key, quantized, raw_dtype=data_qtype)
        tqdm.write(f"Processed {key}: {data.dtype} -> {data_qtype.name}, shape = {data.shape}")

def main():
    args = parse_args()
    writer, state_dict, model_arch = load_model(args.src)

    # Determine file type based on first tensor's dtype.
    first_tensor = next(iter(state_dict.values()))
    if first_tensor.dtype == torch.bfloat16:
        out_path = args.dst or os.path.splitext(args.src)[0] + "-BF16.gguf"
        writer.add_file_type(gguf.LlamaFileType.MOSTLY_BF16)
    else:
        out_path = args.dst or os.path.splitext(args.src)[0] + "-F16.gguf"
        writer.add_file_type(gguf.LlamaFileType.MOSTLY_F16)

    if os.path.isfile(out_path):
        input("Output exists. Press enter to continue or Ctrl+C to abort")

    handle_tensors(writer, state_dict, model_arch)
    writer.write_header_to_file(path=out_path)
    writer.write_kv_data_to_file()
    writer.write_tensors_to_file(progress=True)
    writer.close()

if __name__ == "__main__":
    main()
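
# Example usage (file names are illustrative):
#   python convert_clip_g_to_gguf.py --src clip_g.safetensors
#   python convert_clip_g_to_gguf.py --src clip_g.safetensors --dst clip_g-F16.gguf
# When --dst is omitted, the output name is derived from --src with an
# -F16.gguf or -BF16.gguf suffix, depending on the source tensor dtype.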