File size: 2,181 Bytes
ab8b628 49d0f3b ab8b628 49d0f3b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
from transformers import (
ViTForImageClassification,
ViTImageProcessor,
TrainingArguments,
Trainer,
)
from datasets import load_dataset
from .utils import ROOT_DIR
def train():
    """Fine-tune a ViT image classifier on a small MNIST subset and save it.

    Loads the "mnist" dataset, keeps the first 2,000 training and 500 test
    examples, fine-tunes the "google/vit-base-patch16-224" checkpoint with a
    10-label classification head using ``Trainer``, and finally writes both
    the model and the image processor to ``ROOT_DIR``.

    Takes no arguments and returns ``None``; its results are the files
    written under ``./results`` (Trainer output) and ``ROOT_DIR``.
    """
    checkpoint = "google/vit-base-patch16-224"
    class_names = [str(digit) for digit in range(10)]
    small_train_size = 2000  # use only 2,000 training examples
    small_test_size = 500  # use only 500 test examples

    # Load dataset. The model's forward pass takes its targets as `labels`,
    # so the dataset's "label" column is renamed to match.
    dataset = load_dataset("mnist")
    dataset = dataset.rename_column("label", "labels")

    # Reduce dataset size for faster training.
    dataset["train"] = dataset["train"].select(range(small_train_size))
    dataset["test"] = dataset["test"].select(range(small_test_size))

    # Initialize processor
    processor = ViTImageProcessor.from_pretrained(checkpoint)

    def transform(examples):
        """Turn a batch of raw examples into model inputs plus labels."""
        # The images are converted to RGB before being handed to the processor.
        images = [img.convert("RGB") for img in examples["image"]]
        inputs = processor(images=images, return_tensors="pt")
        inputs["labels"] = examples["labels"]
        return inputs

    # Apply preprocessing lazily: the transform runs when examples are read.
    dataset.set_transform(transform)

    # Load model. ignore_mismatched_sizes=True allows loading even if the
    # checkpoint's classifier head does not match num_labels.
    model = ViTForImageClassification.from_pretrained(
        checkpoint,
        num_labels=len(class_names),
        id2label={name: name for name in class_names},
        label2id={name: index for index, name in enumerate(class_names)},
        ignore_mismatched_sizes=True,
    )

    training_args = TrainingArguments(
        output_dir="./results",
        # Keep the raw "image" column: the on-the-fly transform needs it, and
        # Trainer would otherwise drop columns the model's forward does not
        # accept.
        remove_unused_columns=False,
        per_device_train_batch_size=16,
        # 2,000 examples / batch 16 * 3 epochs is 375 optimizer steps on a
        # single device, so step-based evaluation every 500 steps would never
        # run. Evaluate at the end of every epoch instead.
        eval_strategy="epoch",
        num_train_epochs=3,
        fp16=False,  # disable fp16 mixed precision
        # Intermediate checkpoints stay step-based; the final model is saved
        # explicitly below.
        save_steps=500,
        logging_steps=100,
        learning_rate=2e-4,
        push_to_hub=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
    )
    trainer.train()

    # Save model and processor side by side so they can be reloaded together.
    model.save_pretrained(ROOT_DIR)
    processor.save_pretrained(ROOT_DIR)
|