|
import torch |
|
from transformers import AutoModel, AutoTokenizer, FlaxAutoModel |
|
from datasets import load_dataset |
|
from wechsel import WECHSEL, load_embeddings |
|
|
|
def main() -> None:
    """Transfer roberta-large's input embeddings to a Finnish tokenizer via WECHSEL.

    Pipeline:
      1. Load the English source model/tokenizer (roberta-large).
      2. Load a pre-trained target (Finnish) tokenizer from the current directory.
      3. Use WECHSEL (with en/fi fastText embeddings and a Finnish bilingual
         dictionary) to initialize an embedding matrix for the target vocabulary.
      4. Write the updated model back to disk, both as a PyTorch checkpoint and
         as a Flax checkpoint.

    Side effects: downloads models/embeddings, and overwrites model files in "./".
    """
    source_tokenizer = AutoTokenizer.from_pretrained("roberta-large")
    model = AutoModel.from_pretrained("roberta-large")

    # NOTE(review): assumes a target-language tokenizer was already trained and
    # saved in the current directory — confirm before running.
    target_tokenizer = AutoTokenizer.from_pretrained("./")

    wechsel = WECHSEL(
        load_embeddings("en"),
        load_embeddings("fi"),
        bilingual_dictionary="finnish",
    )

    # Compute target-vocabulary embeddings from the source model's input
    # embedding matrix. `_info` (auxiliary transfer statistics) is unused here.
    target_embeddings, _info = wechsel.apply(
        source_tokenizer,
        target_tokenizer,
        model.get_input_embeddings().weight.detach().numpy(),
    )

    # Swap the transferred matrix into the model's input embedding layer.
    model.get_input_embeddings().weight.data = torch.from_numpy(target_embeddings)

    # Persist the PyTorch checkpoint first; the Flax conversion below reads it
    # back from the same directory (from_pt=True converts PyTorch -> Flax).
    model.save_pretrained("./")

    flax_model = FlaxAutoModel.from_pretrained("./", from_pt=True)
    flax_model.save_pretrained("./")


if __name__ == "__main__":
    main()
|
|