from transformers import (
    ViTForImageClassification,
    ViTImageProcessor,
    TrainingArguments,
    Trainer,
)
from datasets import load_dataset

from .utils import ROOT_DIR


def train():
    # Load dataset
    dataset = load_dataset("mnist")
    dataset = dataset.rename_column("label", "labels")  # Critical rename

    # Reduce dataset size for faster training
    small_train_size = 2000  # Use only 2,000 training examples
    small_test_size = 500  # Use only 500 test examples
    dataset["train"] = dataset["train"].select(range(small_train_size))
    dataset["test"] = dataset["test"].select(range(small_test_size))

    # Initialize processor
    processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

    def transform(examples):
        # Convert grayscale to RGB and process
        images = [img.convert("RGB") for img in examples["image"]]
        inputs = processor(images=images, return_tensors="pt")
        inputs["labels"] = examples["labels"]
        return inputs

    # Apply preprocessing
    dataset.set_transform(transform)

    # Load model with proper initialization
    model = ViTForImageClassification.from_pretrained(
        "google/vit-base-patch16-224",
        num_labels=10,
        id2label={str(i): str(i) for i in range(10)},
        label2id={str(i): i for i in range(10)},
        ignore_mismatched_sizes=True,
    )

    # Training arguments with critical parameter
    training_args = TrainingArguments(
        output_dir="./results",
        remove_unused_columns=False,  # Preserve input data
        per_device_train_batch_size=16,  # Reduce batch size for efficiency
        eval_strategy="steps",
        num_train_epochs=3,
        fp16=False,  # Disable fp16 mixed precision
        save_steps=500,
        eval_steps=500,
        logging_steps=100,
        learning_rate=2e-4,
        push_to_hub=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
    )

    trainer.train()

    # Save model and processor
    model.save_pretrained(ROOT_DIR)
    processor.save_pretrained(ROOT_DIR)
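

# Example invocation (a sketch only): the enclosing package layout is not shown in
# this file; all that is known is the relative import `from .utils import ROOT_DIR`,
# so "<package>" below is a placeholder for whatever package this module lives in.
#
#     from <package>.train import train
#     train()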