Safetensors
wav2vec2
mms
hausa-fine-tune-facebook-mms / create_adapters.py
Asakrg's picture
Upload create_adapters.py with huggingface_hub
f86be34 verified
#!/usr/bin/env python3
import torch
from safetensors.torch import save_file as safe_save_file
from transformers.models.wav2vec2.convert_wav2vec2_original_pytorch_checkpoint_to_pytorch import load_wav2vec2_layer
# ISO 639-3 (plus script-qualified) language codes shipped with the
# fairseq `mms1b_fl102` checkpoint — one adapter per language.
langs = ["afr", "amh", "ara", "asm", "ast", "azj-script_latin", "bel", "ben", "bos", "bul", "cat", "ceb", "ces", "ckb", "cmn-script_simplified", "cym", "dan", "deu", "ell", "eng", "est", "fas", "fin", "fra", "ful", "gle", "glg", "guj", "hau", "heb", "hin", "hrv", "hun", "hye", "ibo", "ind", "isl", "ita", "jav", "jpn", "kam", "kan", "kat", "kaz", "kea", "khm", "kir", "kor", "lao", "lav", "lin", "lit", "ltz", "lug", "luo", "mal", "mar", "mkd", "mlt", "mon", "mri", "mya", "nld", "nob", "npi", "nso", "nya", "oci", "orm", "ory", "pan", "pol", "por", "pus", "ron", "rus", "slk", "slv", "sna", "snd", "som", "spa", "srp-script_latin", "swe", "swh", "tam", "tel", "tgk", "tgl", "tha", "tur", "ukr", "umb", "urd-script_arabic", "uzb-script_latin", "vie", "wol", "xho", "yor", "yue-script_traditional", "zlm", "zul"]

# Load the fairseq checkpoint onto CPU explicitly: without map_location a
# checkpoint saved from a CUDA process would fail to load on a GPU-less host.
sd = torch.load("../mms1b_fl102.pt", map_location="cpu")

# For each language, extract its adapter weights, rename the fairseq keys to
# the Hugging Face wav2vec2 naming scheme, and save both a legacy .bin
# (torch pickle) and a .safetensors file.
for lang in langs:
    hf_dict = {}
    # fairseq nests the per-language adapter state dict under
    # sd["adapter"][<lang>]["model"].
    fsq_adapters = sd["adapter"][lang]["model"]
    for k, v in fsq_adapters.items():
        # load_wav2vec2_layer populates hf_dict in place with the renamed
        # tensor (its return value was unused in the original, so the dead
        # assignment is dropped here).
        load_wav2vec2_layer(k, v, hf_dict=hf_dict)
    torch.save(hf_dict, f"./adapter.{lang}.bin")
    safe_save_file(hf_dict, f"./adapter.{lang}.safetensors", metadata={"format": "pt"})