|
from transformers import ( |
|
ViTForImageClassification, |
|
ViTImageProcessor, |
|
TrainingArguments, |
|
Trainer, |
|
) |
|
from datasets import load_dataset |
|
from .utils import ROOT_DIR |
|
|
|
|
|
def train(
    train_size: int = 2000,
    test_size: int = 500,
    model_name: str = "google/vit-base-patch16-224",
    output_dir: str = "./results",
    num_train_epochs: int = 3,
    learning_rate: float = 2e-4,
):
    """Fine-tune a pretrained ViT image classifier on a subset of MNIST.

    Loads MNIST, keeps a small train/test subset for fast experimentation,
    replaces the ViT classification head with a 10-class head, trains with
    the Hugging Face ``Trainer``, and saves the resulting model and image
    processor to ``ROOT_DIR``.

    Args:
        train_size: Number of training examples to keep (clamped to the
            actual split size).
        test_size: Number of evaluation examples to keep (clamped likewise).
        model_name: Hub checkpoint to start from.
        output_dir: Directory for Trainer checkpoints and logs.
        num_train_epochs: Number of training epochs.
        learning_rate: Optimizer learning rate.

    Returns:
        The fitted ``Trainer`` instance, so callers can inspect metrics.
    """
    dataset = load_dataset("mnist")
    # ViTForImageClassification expects the target column to be "labels".
    dataset = dataset.rename_column("label", "labels")

    # Keep only a small subset of each split; clamp so that requesting more
    # examples than exist does not raise an IndexError in `select`.
    dataset["train"] = dataset["train"].select(
        range(min(train_size, len(dataset["train"])))
    )
    dataset["test"] = dataset["test"].select(
        range(min(test_size, len(dataset["test"])))
    )

    processor = ViTImageProcessor.from_pretrained(model_name)

    def transform(examples):
        """Convert a batch of PIL images to pixel-value tensors + labels."""
        # MNIST images are grayscale; ViT expects 3-channel RGB input.
        images = [img.convert("RGB") for img in examples["image"]]
        inputs = processor(images=images, return_tensors="pt")
        inputs["labels"] = examples["labels"]
        return inputs

    # Apply the transform lazily at access time instead of preprocessing
    # the whole dataset up front.
    dataset.set_transform(transform)

    model = ViTForImageClassification.from_pretrained(
        model_name,
        num_labels=10,
        id2label={str(i): str(i) for i in range(10)},
        label2id={str(i): i for i in range(10)},
        # The checkpoint ships a 1000-class ImageNet head; allow swapping
        # it for a freshly initialized 10-class head.
        ignore_mismatched_sizes=True,
    )

    training_args = TrainingArguments(
        output_dir=output_dir,
        # Required with `set_transform`: otherwise Trainer drops the raw
        # "image" column before the transform can read it.
        remove_unused_columns=False,
        per_device_train_batch_size=16,
        eval_strategy="steps",
        num_train_epochs=num_train_epochs,
        fp16=False,
        save_steps=500,
        eval_steps=500,
        logging_steps=100,
        learning_rate=learning_rate,
        push_to_hub=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
    )

    trainer.train()

    # Persist both the fine-tuned weights and the matching preprocessor so
    # the pair can be reloaded together from ROOT_DIR.
    model.save_pretrained(ROOT_DIR)
    processor.save_pretrained(ROOT_DIR)

    return trainer
|
|