from transformers import (
    ViTForImageClassification,
    ViTImageProcessor,
    TrainingArguments,
    Trainer,
)
from datasets import load_dataset
from .utils import ROOT_DIR


def train():
    # Load MNIST and rename the target column: the Trainer expects it to be called "labels"
    dataset = load_dataset("mnist")
    dataset = dataset.rename_column("label", "labels")

    # Reduce dataset size for faster training
    small_train_size = 2000  # Use only 2,000 training examples
    small_test_size = 500  # Use only 500 test examples

    dataset["train"] = dataset["train"].select(range(small_train_size))
    dataset["test"] = dataset["test"].select(range(small_test_size))

    # Image processor: resizes and normalizes inputs to the 224x224 RGB format ViT expects
    processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

    def transform(examples):
        # Convert grayscale to RGB and process
        images = [img.convert("RGB") for img in examples["image"]]
        inputs = processor(images=images, return_tensors="pt")
        inputs["labels"] = examples["labels"]
        return inputs

    # Apply preprocessing
    dataset.set_transform(transform)

    # Load ViT, swapping the 1000-class ImageNet head for a freshly initialized 10-class head
    model = ViTForImageClassification.from_pretrained(
        "google/vit-base-patch16-224",
        num_labels=10,
        id2label={str(i): str(i) for i in range(10)},
        label2id={str(i): i for i in range(10)},
        ignore_mismatched_sizes=True,
    )

    # Training arguments; remove_unused_columns=False is the critical setting here,
    # otherwise the Trainer drops the raw "image" column that the transform still needs
    training_args = TrainingArguments(
        output_dir="./results",
        remove_unused_columns=False,  # keep raw columns for the on-the-fly transform
        per_device_train_batch_size=16,  # small batch size to keep memory usage modest
        eval_strategy="steps",
        num_train_epochs=3,
        fp16=False,  # Disable fp16 mixed precision
        save_steps=500,
        eval_steps=500,
        logging_steps=100,
        learning_rate=2e-4,
        push_to_hub=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
    )

    trainer.train()

    # Save model and processor
    model.save_pretrained(ROOT_DIR)
    processor.save_pretrained(ROOT_DIR)
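

# Entry-point sketch (not part of the original module): because ROOT_DIR comes from a
# relative import, run this as a module from the package root, e.g.
# `python -m <package>.train`, substituting the actual package name.
if __name__ == "__main__":
    train()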