"""Convert a clip_g text encoder checkpoint (ComfyUI naming) to GGUF."""

import os
import argparse

import torch
import gguf
from tqdm import tqdm
from safetensors.torch import load_file

# Tensors with at most this many elements are always stored as F32;
# quantizing tiny tensors saves nothing and can cost accuracy.
QUANTIZATION_THRESHOLD = 1024
# Longest tensor name this script will accept when writing the GGUF file.
MAX_TENSOR_NAME_LENGTH = 127
|
|
class ModelTemplate:
    arch = "invalid"   # overridden by concrete architectures
    shape_fix = False
    keys_detect = []   # tuples of keys; any fully present tuple is a match
    keys_banned = []   # keys whose presence rejects the checkpoint


class ModelClipG(ModelTemplate):
    arch = "clip_g"
    shape_fix = False
    keys_detect = [
        ("logit_scale",),
        ("text_model.embeddings.position_embedding.weight",),
        ("text_model.encoder.layers.0.self_attn.in_proj_weight",),
    ]
    keys_banned = []
|
|
arch_list = [ModelClipG] |
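# To support another architecture, subclass ModelTemplate and append the new
# class here; detect_arch() tries the candidates in list order.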
|
|
|
def is_model_arch(model, state_dict):
    for key_tuple in model.keys_detect:
        if all(key in state_dict for key in key_tuple):
            # Even on a positive match, refuse checkpoints containing keys
            # this architecture explicitly bans.
            if any(key in state_dict for key in model.keys_banned):
                raise ValueError("Model architecture not allowed for conversion!")
            return True
    return False


def detect_arch(state_dict):
    for model in arch_list:
        if is_model_arch(model, state_dict):
            return model
    raise ValueError("Unknown model architecture!")
|
|
def parse_args():
    parser = argparse.ArgumentParser(description="Convert clip_g model (ComfyUI standard) to GGUF")
    parser.add_argument("--src", required=True, help="Source model file (.safetensors, .pt, etc.)")
    parser.add_argument("--dst", help="Output GGUF file")
    return parser.parse_args()
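# Example invocation (hypothetical file name):
#   python convert_clip_g.py --src clip_g.safetensors
# Without --dst, the output lands next to the input as clip_g-F16.gguf
# (or clip_g-BF16.gguf when the source tensors are BF16); see main().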
|
|
|
def load_state_dict(path):
    if any(path.endswith(ext) for ext in [".ckpt", ".pt", ".bin", ".pth"]):
        state_dict = torch.load(path, map_location="cpu", weights_only=True)
        state_dict = state_dict.get("model", state_dict)
    else:
        state_dict = load_file(path)

    # Detect a wrapper prefix; the longer "model.diffusion_model." is checked
    # before the shorter "model." so it wins when both apply.
    prefix = None
    for pfx in ["model.diffusion_model.", "model."]:
        if any(k.startswith(pfx) for k in state_dict.keys()):
            prefix = pfx
            break

    new_state = {}
    for k, v in state_dict.items():
        if prefix:
            if not k.startswith(prefix):
                continue
            # Strip only the leading prefix; str.replace() would also mangle
            # any later occurrence of the substring inside the key.
            k = k[len(prefix):]
        new_state[k] = v
    return new_state
|
|
def load_model(path):
    state_dict = load_state_dict(path)
    model_arch = detect_arch(state_dict)
    print(f"Detected architecture: {model_arch.arch}")
    writer = gguf.GGUFWriter(path=None, arch=model_arch.arch)
    return writer, state_dict, model_arch
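# The writer is created with path=None, so no file is opened here; main()
# supplies the resolved output path later via write_header_to_file().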
|
|
|
def handle_tensors(writer, state_dict, model_arch):
    # Fail early if any tensor name would not fit within the writer's limit.
    for key in state_dict.keys():
        if len(key) > MAX_TENSOR_NAME_LENGTH:
            raise ValueError(f"Tensor name {key} exceeds maximum length {MAX_TENSOR_NAME_LENGTH}")

    for key, tensor in tqdm(state_dict.items(), desc="Processing tensors"):
        if isinstance(tensor, torch.Tensor):
            old_dtype = tensor.dtype
            # Cast through float32 before .numpy(): numpy has no bfloat16
            # dtype, so a direct conversion fails on BF16 checkpoints.
            data = tensor.detach().cpu().to(torch.float32).numpy()
        else:
            data = tensor
            old_dtype = data.dtype

        # Precision-sensitive tensors (1D tensors, biases, norm parameters,
        # embeddings) stay in F32; everything else is stored as F16.
        key_lower = key.lower()
        if data.ndim == 1 or "bias" in key_lower or "layer_norm" in key_lower or "ln_" in key_lower:
            data_qtype = gguf.GGMLQuantizationType.F32
        elif "embeddings" in key_lower:
            data_qtype = gguf.GGMLQuantizationType.F32
        else:
            data_qtype = gguf.GGMLQuantizationType.F16

        # Tiny tensors are never worth quantizing.
        if data.size <= QUANTIZATION_THRESHOLD:
            data_qtype = gguf.GGMLQuantizationType.F32

        try:
            quantized = gguf.quants.quantize(data, data_qtype)
        except Exception as e:
            tqdm.write(f"Quantization failed for {key} with error {e}; falling back to F16")
            data_qtype = gguf.GGMLQuantizationType.F16
            quantized = gguf.quants.quantize(data, data_qtype)

        writer.add_tensor(key, quantized, raw_dtype=data_qtype)
        tqdm.write(f"Processed {key}: {old_dtype} -> {data_qtype.name}, shape = {data.shape}")
|
|
def main():
    args = parse_args()
    writer, state_dict, model_arch = load_model(args.src)

    # The file-type flag and the default output suffix follow the dtype of the
    # first tensor; per-tensor storage types are decided in handle_tensors().
    first_tensor = next(iter(state_dict.values()))
    if first_tensor.dtype == torch.bfloat16:
        out_path = args.dst or os.path.splitext(args.src)[0] + "-BF16.gguf"
        writer.add_file_type(gguf.LlamaFileType.MOSTLY_BF16)
    else:
        out_path = args.dst or os.path.splitext(args.src)[0] + "-F16.gguf"
        writer.add_file_type(gguf.LlamaFileType.MOSTLY_F16)

    if os.path.isfile(out_path):
        input("Output exists. Press enter to continue or Ctrl+C to abort")

    handle_tensors(writer, state_dict, model_arch)
    writer.write_header_to_file(path=out_path)
    writer.write_kv_data_to_file()
    writer.write_tensors_to_file(progress=True)
    writer.close()


if __name__ == "__main__":
    main()