Spaces:
Runtime error
Runtime error
from transformers import ( | |
AutoModelForImageClassification, | |
AutoImageProcessor, | |
TrainingArguments, | |
Trainer, | |
) | |
from datasets import load_dataset | |
import os | |
def train(
    dataset_id: str = "ylecun/mnist",
    model_id: str = "SupremoUGH/image-classification-model",
    output_dir: str = "./results",
    save_dir: str = "./saved_model",
) -> None:
    """Fine-tune an image-classification checkpoint on a Hub dataset and save it.

    Loads ``dataset_id`` and preprocesses its images lazily with the image
    processor of ``model_id``. It then fine-tunes the model with ``Trainer`` and
    writes both the model and its image processor to ``save_dir``.

    Args:
        dataset_id: Hugging Face Hub dataset to train and evaluate on. It must
            provide ``"train"`` and ``"test"`` splits with ``"image"`` and
            ``"label"`` columns.
        model_id: Hub checkpoint supplying both the model and its image processor.
        output_dir: Directory for Trainer checkpoints and logs.
        save_dir: Directory the final model and image processor are written to.
    """
    # Load dataset
    dataset = load_dataset(dataset_id)

    # Derive the label mapping from the dataset so the classification head
    # matches its classes. This assumes "label" is a ClassLabel feature (true
    # for MNIST); TODO confirm for other datasets.
    label_names = dataset["train"].features["label"].names
    id2label = dict(enumerate(label_names))
    label2id = {name: idx for idx, name in id2label.items()}

    # Load processor and apply preprocessing to the dataset
    processor = AutoImageProcessor.from_pretrained(model_id)

    def process(examples):
        # MNIST images are single-channel. Convert to RGB, presumably because
        # the checkpoint expects 3-channel input (TODO confirm).
        images = [img.convert("RGB") for img in examples["image"]]
        inputs = processor(images=images, return_tensors="pt")
        inputs["labels"] = examples["label"]
        return inputs

    # set_transform runs `process` lazily each time rows are accessed,
    # rather than precomputing and caching tensors as `map` would.
    dataset.set_transform(process)

    # Load model with a head sized for this dataset's labels. If the checkpoint
    # was trained with a different number of classes, re-initialise the head
    # instead of failing on a weight-shape mismatch.
    model = AutoModelForImageClassification.from_pretrained(
        model_id,
        num_labels=len(label_names),
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes=True,
    )

    training_args = TrainingArguments(
        output_dir=output_dir,
        # Keep the raw "image" column. Trainer would otherwise drop columns
        # the model's forward() does not accept, before the transform reads them.
        remove_unused_columns=False,
        per_device_train_batch_size=16,  # Small batch to keep memory use modest
        eval_strategy="steps",
        num_train_epochs=3,
        fp16=False,  # Disable fp16 mixed precision
        save_steps=500,
        eval_steps=500,
        logging_steps=100,
        learning_rate=2e-4,
        push_to_hub=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        # MNIST ships only train/test splits; other datasets may call
        # their held-out split "validation".
        eval_dataset=dataset["test"],
    )
    trainer.train()

    # Save the fine-tuned model together with its image processor so the
    # directory can be reloaded for inference.
    os.makedirs(save_dir, exist_ok=True)
    model.save_pretrained(save_dir)
    processor.save_pretrained(save_dir)
    print(f"Model saved to {save_dir}")
if __name__ == "__main__":
    # Run the fine-tuning pipeline only when executed as a script, not on import.
    train()