first commit
- .gitignore +8 -0
- README.md +29 -3
- bin/train.py +4 -0
- model/config.json +48 -0
- model/model.safetensors +3 -0
- model/preprocessor_config.json +23 -0
- predict.py +20 -0
- pyproject.toml +26 -0
- requirements.txt +5 -0
- src/__init__.py +0 -0
- src/predict.py +29 -0
- src/preprocess.py +10 -0
- src/train.py +71 -0
- src/utils.py +9 -0
- tests/__init__.py +0 -0
- tests/data/number_3.jpg +0 -0
- tests/test_prediction.py +19 -0
.gitignore
ADDED
@@ -0,0 +1,8 @@
+# Python metadata
+image_classification_model.egg-info/
+venv/
+__pycache__/
+
+# Model
+results/
+model/
README.md
CHANGED
@@ -1,3 +1,29 @@
----
-license: mit
----
+---
+license: mit
+---
+# Image Classification Model (ViT)
+
+This is an image classification model based on **Vision Transformer (ViT)**, fine-tuned on the **MNIST** dataset. The model is designed to classify images into one of 10 possible classes (digits 0-9). The code is compatible with Hugging Face's inference providers and can be easily deployed.
+
+## Model Details
+
+- **Model Type**: Vision Transformer (ViT)
+- **Base Model**: `google/vit-base-patch16-224`
+- **Task**: Image Classification
+- **Dataset**: MNIST (handwritten digits)
+- **Labels**: 10 classes (0-9)
+
+## How to Use
+
+### Install Requirements
+
+Make sure you have the following dependencies installed:
+
+```bash
+pip3 install -r requirements.txt
+```
+
+### Run unit tests
+```bash
+python3 -m unittest discover -s tests
+```
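Once installed, the package added in this commit can also be called from Python directly. A minimal sketch, assuming an editable install (`pip3 install -e .`) and a run from the repo root:

```python
# Programmatic use of the package this commit adds; predict() returns
# the integer class index (see src/predict.py below).
from image_classification_model.predict import predict

label = predict("tests/data/number_3.jpg")
print(label)  # expected: 3, per tests/test_prediction.py
```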
bin/train.py
ADDED
@@ -0,0 +1,4 @@
+from image_classification_model.train import train
+
+if __name__ == "__main__":
+    train()
model/config.json
ADDED
@@ -0,0 +1,48 @@
+{
+  "_name_or_path": "google/vit-base-patch16-224",
+  "architectures": [
+    "ViTForImageClassification"
+  ],
+  "attention_probs_dropout_prob": 0.0,
+  "encoder_stride": 16,
+  "hidden_act": "gelu",
+  "hidden_dropout_prob": 0.0,
+  "hidden_size": 768,
+  "id2label": {
+    "0": "0",
+    "1": "1",
+    "2": "2",
+    "3": "3",
+    "4": "4",
+    "5": "5",
+    "6": "6",
+    "7": "7",
+    "8": "8",
+    "9": "9"
+  },
+  "image_size": 224,
+  "initializer_range": 0.02,
+  "intermediate_size": 3072,
+  "label2id": {
+    "0": 0,
+    "1": 1,
+    "2": 2,
+    "3": 3,
+    "4": 4,
+    "5": 5,
+    "6": 6,
+    "7": 7,
+    "8": 8,
+    "9": 9
+  },
+  "layer_norm_eps": 1e-12,
+  "model_type": "vit",
+  "num_attention_heads": 12,
+  "num_channels": 3,
+  "num_hidden_layers": 12,
+  "patch_size": 16,
+  "problem_type": "single_label_classification",
+  "qkv_bias": true,
+  "torch_dtype": "float32",
+  "transformers_version": "4.48.2"
+}
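The `id2label` map above is what turns a predicted class index back into a display label. A sketch of reading it with stock `transformers` (illustrative, not part of the commit; note that `AutoConfig` converts the JSON's string keys to ints on load):

```python
# Read the label map from the committed config; run from the repo root.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("model")  # the model/ directory above
print(config.id2label[3])  # "3"
```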
model/model.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:02e9496d47c1edd9c65ffbb29c62f2238aa998d10c73ef901252524d5f619d7c
+size 343248584
model/preprocessor_config.json
ADDED
@@ -0,0 +1,23 @@
+{
+  "do_convert_rgb": null,
+  "do_normalize": true,
+  "do_rescale": true,
+  "do_resize": true,
+  "image_mean": [
+    0.5,
+    0.5,
+    0.5
+  ],
+  "image_processor_type": "ViTImageProcessor",
+  "image_std": [
+    0.5,
+    0.5,
+    0.5
+  ],
+  "resample": 2,
+  "rescale_factor": 0.00392156862745098,
+  "size": {
+    "height": 224,
+    "width": 224
+  }
+}
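The `rescale_factor` above is 1/255: pixel bytes are first mapped to [0, 1], then normalized with mean and std 0.5, so the model sees inputs in [-1, 1]. A quick check of the arithmetic:

```python
# rescale_factor is 1/255; with mean = std = 0.5 the overall mapping is
# byte -> (byte/255 - 0.5) / 0.5, i.e. [0, 255] -> [-1, 1].
print(1 / 255)                  # 0.00392156862745098
print((255 / 255 - 0.5) / 0.5)  # 1.0  (a white pixel)
print((0 / 255 - 0.5) / 0.5)    # -1.0 (a black pixel)
```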
predict.py
ADDED
@@ -0,0 +1,20 @@
+import sys
+from image_classification_model.predict import predict
+
+
+def main():
+    if len(sys.argv) < 2:
+        print("Usage: python predict.py <image_path>")
+        sys.exit(1)
+
+    image_path = sys.argv[1]
+
+    # Run prediction (handles preprocessing internally)
+    predicted_label = predict(image_path)
+
+    # Print output in Hugging Face-compatible format
+    print({"label": predicted_label})
+
+
+if __name__ == "__main__":
+    main()
pyproject.toml
ADDED
@@ -0,0 +1,26 @@
+[project]
+name = "image-classification-model"
+version = "0.1.0"
+description = "MNIST-compatible image classification model"
+requires-python = ">=3.8"
+dependencies = [
+    "torch>=2.0.0",
+    "transformers>=4.30.0",
+    "Pillow>=9.0.0",
+    "datasets>=2.0.0",
+    "accelerate>=0.26.0"
+]
+
+[build-system]
+requires = ["setuptools>=65.5.1", "wheel"]
+build-backend = "setuptools.build_meta"
+
+
+[tool.setuptools]
+packages = ["image_classification_model"]
+
+[tool.setuptools.package-dir]
+image_classification_model = "src"
+
+[project.scripts]
+train = "image_classification_model.train:train"
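The `[tool.setuptools.package-dir]` table maps the `src/` directory onto the importable package name. A quick check after an editable install (`pip3 install -e .`):

```python
# src/ is importable under the name mapped in pyproject.toml.
from image_classification_model.utils import MODEL_DIR

print(MODEL_DIR)  # resolves to the repo's model/ directory
```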
requirements.txt
ADDED
@@ -0,0 +1,5 @@
+torch>=2.0.0
+transformers>=4.30.0
+Pillow>=9.0.0
+datasets>=2.0.0
+accelerate>=0.26.0
src/__init__.py
ADDED
File without changes
src/predict.py
ADDED
@@ -0,0 +1,29 @@
+import torch
+from PIL import Image
+from .preprocess import preprocess_image
+from .utils import load_model
+
+
+def predict_with_model(model, inputs):
+    """Runs inference and returns the predicted class."""
+    model.eval()  # Ensure the model is in evaluation mode
+    with torch.no_grad():  # Disable gradient calculation
+        outputs = model(**inputs)
+        logits = outputs.logits
+        predicted_class = logits.argmax(dim=-1).item()  # Get predicted class index
+    return predicted_class
+
+
+def predict(image_path):
+    """Loads an image, preprocesses it, runs the model, and returns the prediction."""
+    image = Image.open(image_path).convert("RGB")
+    inputs = preprocess_image(image)
+
+    # Load model
+    model = load_model()
+
+    # Ensure inputs are on the same device as the model
+    device = model.device
+    inputs = {key: tensor.to(device) for key, tensor in inputs.items()}
+
+    return predict_with_model(model, inputs)
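Note that `predict()` calls `load_model()` on every invocation, so each call re-reads roughly 343 MB of weights. A sketch of a cache for callers that classify many images (an illustration, not part of this commit):

```python
# Optional caching wrapper: load the ViT weights once and reuse them.
from functools import lru_cache

from image_classification_model.utils import load_model


@lru_cache(maxsize=1)
def get_model():
    return load_model()
```

Callers can then pass `get_model()` to `predict_with_model` directly.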
src/preprocess.py
ADDED
@@ -0,0 +1,10 @@
+from transformers import ViTImageProcessor
+from .utils import MODEL_DIR
+
+processor = ViTImageProcessor.from_pretrained(MODEL_DIR)
+
+
+def preprocess_image(image):
+    """Preprocesses a single image for ViT inference."""
+    inputs = processor(images=image, return_tensors="pt")
+    return inputs
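`preprocess_image` returns a dict-like `BatchFeature`; its `pixel_values` tensor matches the 224x224 size in `model/preprocessor_config.json`. A sketch using the committed test image:

```python
# Shape check for the preprocessing output; run from the repo root.
from PIL import Image

from image_classification_model.preprocess import preprocess_image

image = Image.open("tests/data/number_3.jpg").convert("RGB")
inputs = preprocess_image(image)
print(inputs["pixel_values"].shape)  # torch.Size([1, 3, 224, 224])
```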
src/train.py
ADDED
@@ -0,0 +1,71 @@
+from transformers import (
+    ViTForImageClassification,
+    ViTImageProcessor,
+    TrainingArguments,
+    Trainer,
+)
+from datasets import load_dataset
+from .utils import MODEL_DIR
+
+
+def train():
+    # Load dataset
+    dataset = load_dataset("mnist")
+    dataset = dataset.rename_column("label", "labels")  # Critical rename
+
+    # Reduce dataset size for faster training
+    small_train_size = 2000  # Use only 2,000 training examples
+    small_test_size = 500  # Use only 500 test examples
+
+    dataset["train"] = dataset["train"].select(range(small_train_size))
+    dataset["test"] = dataset["test"].select(range(small_test_size))
+
+    # Initialize processor
+    processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
+
+    def transform(examples):
+        # Convert grayscale to RGB and process
+        images = [img.convert("RGB") for img in examples["image"]]
+        inputs = processor(images=images, return_tensors="pt")
+        inputs["labels"] = examples["labels"]
+        return inputs
+
+    # Apply preprocessing
+    dataset.set_transform(transform)
+
+    # Load model with proper initialization
+    model = ViTForImageClassification.from_pretrained(
+        "google/vit-base-patch16-224",
+        num_labels=10,
+        id2label={str(i): str(i) for i in range(10)},
+        label2id={str(i): i for i in range(10)},
+        ignore_mismatched_sizes=True,
+    )
+
+    # Training arguments with critical parameter
+    training_args = TrainingArguments(
+        output_dir="./results",
+        remove_unused_columns=False,  # Preserve input data
+        per_device_train_batch_size=16,  # Reduce batch size for efficiency
+        eval_strategy="steps",
+        num_train_epochs=3,
+        fp16=False,  # Disable fp16 mixed precision
+        save_steps=500,
+        eval_steps=500,
+        logging_steps=100,
+        learning_rate=2e-4,
+        push_to_hub=False,
+    )
+
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=dataset["train"],
+        eval_dataset=dataset["test"],
+    )
+
+    trainer.train()
+
+    # Save model and processor
+    model.save_pretrained(MODEL_DIR)
+    processor.save_pretrained(MODEL_DIR)
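The `rename_column` call matters because `Trainer` only computes a loss when each batch carries a `labels` key for the model's forward pass. Training itself is launched through the thin wrapper in `bin/train.py`; equivalently, a programmatic sketch (assuming the package is installed and MNIST plus the base checkpoint can be downloaded):

```python
# Mirrors bin/train.py: checkpoints go to ./results (gitignored);
# the fine-tuned model and processor are saved to model/.
from image_classification_model.train import train

train()
```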
src/utils.py
ADDED
@@ -0,0 +1,9 @@
+from transformers import ViTForImageClassification
+import os
+
+ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+MODEL_DIR = os.path.join(ROOT_DIR, "model")
+
+
+def load_model():
+    return ViTForImageClassification.from_pretrained(MODEL_DIR)
tests/__init__.py
ADDED
File without changes
tests/data/number_3.jpg
ADDED
tests/test_prediction.py
ADDED
@@ -0,0 +1,19 @@
+import unittest
+import os
+from image_classification_model.predict import predict
+from image_classification_model.utils import ROOT_DIR
+
+DATA_DIR = os.path.join(ROOT_DIR, "tests/data")
+
+
+class TestPrediction(unittest.TestCase):
+    def test_prediction_label_3(self):
+        test_image_path = os.path.join(DATA_DIR, "number_3.jpg")
+        predicted_label = predict(test_image_path)
+        self.assertEqual(
+            predicted_label, 3, f"Expected label 3, but got {predicted_label}"
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()