SupremoUGH committed · Commit ab8b628 · unverified · 1 Parent(s): 0e6c91c

first commit

.gitignore ADDED
@@ -0,0 +1,8 @@
+ # Python metadata
+ image_classification_model.egg-info/
+ venv/
+ __pycache__/
+
+ # Model
+ results/
+ model/
README.md CHANGED
@@ -1,3 +1,37 @@
- ---
- license: mit
- ---
+ ---
+ license: mit
+ ---
+ # Image Classification Model (ViT)
+
+ This is an image classification model based on the **Vision Transformer (ViT)**, fine-tuned on the **MNIST** dataset. The model classifies images into one of 10 classes (handwritten digits 0-9). The code is compatible with Hugging Face's inference providers and can be deployed easily.
+
+ ## Model Details
+
+ - **Model Type**: Vision Transformer (ViT)
+ - **Base Model**: `google/vit-base-patch16-224`
+ - **Task**: Image Classification
+ - **Dataset**: MNIST (handwritten digits)
+ - **Labels**: 10 classes (0-9)
+
+ ## How to Use
+
+ ### Install Requirements
+
+ Make sure you have the following dependencies installed:
+
+ ```bash
+ pip3 install -r requirements.txt
+ ```
+
+ ### Run unit tests
+ ```bash
+ python3 -m unittest discover -s tests
+ ```
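+
+ ### Run a prediction
+
+ The repo includes a CLI wrapper (`predict.py`) and a sample digit image under `tests/data/`; for the bundled image this should print `{'label': 3}`:
+
+ ```bash
+ python3 predict.py tests/data/number_3.jpg
+ ```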
bin/train.py ADDED
@@ -0,0 +1,4 @@
+ from image_classification_model.train import train
+
+ if __name__ == "__main__":
+     train()
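This wrapper just calls `train()` from the package, so (assuming the dependencies above are installed) a fine-tuning run starts with:

```bash
python3 bin/train.py
```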
model/config.json ADDED
@@ -0,0 +1,48 @@
+ {
+   "_name_or_path": "google/vit-base-patch16-224",
+   "architectures": [
+     "ViTForImageClassification"
+   ],
+   "attention_probs_dropout_prob": 0.0,
+   "encoder_stride": 16,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.0,
+   "hidden_size": 768,
+   "id2label": {
+     "0": "0",
+     "1": "1",
+     "2": "2",
+     "3": "3",
+     "4": "4",
+     "5": "5",
+     "6": "6",
+     "7": "7",
+     "8": "8",
+     "9": "9"
+   },
+   "image_size": 224,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "label2id": {
+     "0": 0,
+     "1": 1,
+     "2": 2,
+     "3": 3,
+     "4": 4,
+     "5": 5,
+     "6": 6,
+     "7": 7,
+     "8": 8,
+     "9": 9
+   },
+   "layer_norm_eps": 1e-12,
+   "model_type": "vit",
+   "num_attention_heads": 12,
+   "num_channels": 3,
+   "num_hidden_layers": 12,
+   "patch_size": 16,
+   "problem_type": "single_label_classification",
+   "qkv_bias": true,
+   "torch_dtype": "float32",
+   "transformers_version": "4.48.2"
+ }
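One detail worth noting in this config: `id2label` is what maps an argmax index back to a label string at inference time. A minimal sketch of how it is used (assuming the `model/` directory above is loaded as-is; `transformers` converts the JSON's string keys to ints on load):

```python
from transformers import ViTForImageClassification

model = ViTForImageClassification.from_pretrained("model")  # the directory above
print(model.config.id2label[3])  # -> "3"
```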
model/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:02e9496d47c1edd9c65ffbb29c62f2238aa998d10c73ef901252524d5f619d7c
+ size 343248584
model/preprocessor_config.json ADDED
@@ -0,0 +1,23 @@
+ {
+   "do_convert_rgb": null,
+   "do_normalize": true,
+   "do_rescale": true,
+   "do_resize": true,
+   "image_mean": [
+     0.5,
+     0.5,
+     0.5
+   ],
+   "image_processor_type": "ViTImageProcessor",
+   "image_std": [
+     0.5,
+     0.5,
+     0.5
+   ],
+   "resample": 2,
+   "rescale_factor": 0.00392156862745098,
+   "size": {
+     "height": 224,
+     "width": 224
+   }
+ }
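Taken together, these settings resize to 224x224 (bilinear, `"resample": 2`), rescale by 1/255, and normalize with mean and std 0.5, putting pixel values in [-1, 1]. A hand-written sketch of the equivalent tensor math, for illustration only (the repo's real pipeline is `ViTImageProcessor`):

```python
import torch
from PIL import Image

img = Image.open("tests/data/number_3.jpg").convert("RGB")
img = img.resize((224, 224), resample=2)  # 2 = bilinear, matching the config
x = torch.tensor(list(img.getdata()), dtype=torch.float32).reshape(224, 224, 3)
x = x / 255.0        # do_rescale with rescale_factor = 1/255
x = (x - 0.5) / 0.5  # do_normalize with image_mean = image_std = 0.5
x = x.permute(2, 0, 1).unsqueeze(0)  # HWC -> NCHW: shape (1, 3, 224, 224)
```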
predict.py ADDED
@@ -0,0 +1,20 @@
+ import sys
+ from image_classification_model.predict import predict
+
+
+ def main():
+     if len(sys.argv) < 2:
+         print("Usage: python predict.py <image_path>")
+         sys.exit(1)
+
+     image_path = sys.argv[1]
+
+     # Run prediction (handles preprocessing internally)
+     predicted_label = predict(image_path)
+
+     # Print output in Hugging Face-compatible format
+     print({"label": predicted_label})
+
+
+ if __name__ == "__main__":
+     main()
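The same prediction is available as a library call, bypassing the CLI wrapper. A minimal sketch (assuming the package is installed, e.g. via `pip3 install -e .`, and the weights are present under `model/`):

```python
from image_classification_model.predict import predict

label = predict("tests/data/number_3.jpg")  # returns the class index as an int
print(label)  # 3 for the bundled test image
```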
pyproject.toml ADDED
@@ -0,0 +1,25 @@
+ [project]
+ name = "image-classification-model"
+ version = "0.1.0"
+ description = "MNIST-compatible image classification model"
+ requires-python = ">=3.8"
+ dependencies = [
+     "torch>=2.0.0",
+     "transformers>=4.30.0",
+     "Pillow>=9.0.0",
+     "datasets>=2.0.0",
+     "accelerate>=0.26.0"
+ ]
+
+ [project.scripts]
+ train = "image_classification_model.train:train"
+
+ [build-system]
+ requires = ["setuptools>=65.5.1", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [tool.setuptools]
+ packages = ["image_classification_model"]
+
+ [tool.setuptools.package-dir]
+ image_classification_model = "src"
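With the `package-dir` mapping above, an editable install exposes `src/` under the package name `image_classification_model`, and `[project.scripts]` puts a `train` command on the PATH. A typical workflow (a sketch, assuming a virtualenv is active):

```bash
pip3 install -e .  # installs dependencies and maps src/ -> image_classification_model
train              # console script defined in [project.scripts]
```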
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch>=2.0.0
+ transformers>=4.30.0
+ Pillow>=9.0.0
+ datasets>=2.0.0
+ accelerate>=0.26.0
src/__init__.py ADDED
File without changes
src/predict.py ADDED
@@ -0,0 +1,29 @@
+ import torch
+ from PIL import Image
+ from .preprocess import preprocess_image
+ from .utils import load_model
+
+
+ def predict_with_model(model, inputs):
+     """Runs inference and returns the predicted class."""
+     model.eval()  # Ensure the model is in evaluation mode
+     with torch.no_grad():  # Disable gradient calculation
+         outputs = model(**inputs)
+         logits = outputs.logits
+         predicted_class = logits.argmax(dim=-1).item()  # Get predicted class index
+     return predicted_class
+
+
+ def predict(image_path):
+     """Loads an image, preprocesses it, runs the model, and returns the prediction."""
+     image = Image.open(image_path).convert("RGB")
+     inputs = preprocess_image(image)
+
+     # Load model
+     model = load_model()
+
+     # Ensure inputs are on the same device as the model
+     device = model.device
+     inputs = {key: tensor.to(device) for key, tensor in inputs.items()}
+
+     return predict_with_model(model, inputs)
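Note that `predict()` calls `load_model()` on every invocation, reloading roughly 343 MB of weights each time. An optional variant, sketched here rather than taken from this commit, caches the model across calls:

```python
from functools import lru_cache
from image_classification_model.utils import load_model

@lru_cache(maxsize=1)
def get_model():
    # Load the fine-tuned ViT once and reuse it across predictions
    return load_model()
```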
src/preprocess.py ADDED
@@ -0,0 +1,10 @@
+ from transformers import ViTImageProcessor
+ from .utils import MODEL_DIR
+
+ processor = ViTImageProcessor.from_pretrained(MODEL_DIR)
+
+
+ def preprocess_image(image):
+     """Preprocesses a single image for ViT inference."""
+     inputs = processor(images=image, return_tensors="pt")
+     return inputs
src/train.py ADDED
@@ -0,0 +1,71 @@
+ from transformers import (
+     ViTForImageClassification,
+     ViTImageProcessor,
+     TrainingArguments,
+     Trainer,
+ )
+ from datasets import load_dataset
+ from .utils import MODEL_DIR
+
+
+ def train():
+     # Load dataset
+     dataset = load_dataset("mnist")
+     dataset = dataset.rename_column("label", "labels")  # Trainer expects a "labels" column
+
+     # Reduce dataset size for faster training
+     small_train_size = 2000  # Use only 2,000 training examples
+     small_test_size = 500  # Use only 500 test examples
+
+     dataset["train"] = dataset["train"].select(range(small_train_size))
+     dataset["test"] = dataset["test"].select(range(small_test_size))
+
+     # Initialize processor
+     processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
+
+     def transform(examples):
+         # Convert grayscale to RGB and process
+         images = [img.convert("RGB") for img in examples["image"]]
+         inputs = processor(images=images, return_tensors="pt")
+         inputs["labels"] = examples["labels"]
+         return inputs
+
+     # Apply preprocessing lazily, at access time
+     dataset.set_transform(transform)
+
+     # Load model, replacing the pretrained head with a fresh 10-class head
+     model = ViTForImageClassification.from_pretrained(
+         "google/vit-base-patch16-224",
+         num_labels=10,
+         id2label={str(i): str(i) for i in range(10)},
+         label2id={str(i): i for i in range(10)},
+         ignore_mismatched_sizes=True,
+     )
+
+     # Training arguments; remove_unused_columns=False keeps the transformed inputs
+     training_args = TrainingArguments(
+         output_dir="./results",
+         remove_unused_columns=False,  # Preserve input data
+         per_device_train_batch_size=16,
+         eval_strategy="steps",
+         num_train_epochs=3,
+         fp16=False,  # Disable fp16 mixed precision
+         save_steps=500,
+         eval_steps=500,
+         logging_steps=100,
+         learning_rate=2e-4,
+         push_to_hub=False,
+     )
+
+     trainer = Trainer(
+         model=model,
+         args=training_args,
+         train_dataset=dataset["train"],
+         eval_dataset=dataset["test"],
+     )
+
+     trainer.train()
+
+     # Save model and processor
+     model.save_pretrained(MODEL_DIR)
+     processor.save_pretrained(MODEL_DIR)
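As written, the Trainer reports only the eval loss, since no `compute_metrics` is passed. A sketch of adding accuracy, assuming the standard `Trainer` callback signature:

```python
import numpy as np

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    return {"accuracy": (preds == labels).mean()}

# then: Trainer(..., compute_metrics=compute_metrics)
```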
src/utils.py ADDED
@@ -0,0 +1,9 @@
+ from transformers import ViTForImageClassification
+ import os
+
+ ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+ MODEL_DIR = os.path.join(ROOT_DIR, "model")
+
+
+ def load_model():
+     return ViTForImageClassification.from_pretrained(MODEL_DIR)
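`from_pretrained` loads the model on CPU by default, and `src/predict.py` only moves the inputs to wherever the model already is; to actually use a GPU, the model itself must be moved first. A minimal sketch:

```python
import torch
from image_classification_model.utils import load_model

device = "cuda" if torch.cuda.is_available() else "cpu"
model = load_model().to(device)
```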
tests/__init__.py ADDED
File without changes
tests/data/number_3.jpg ADDED
tests/test_prediction.py ADDED
@@ -0,0 +1,19 @@
+ import unittest
+ import os
+ from image_classification_model.predict import predict
+ from image_classification_model.utils import ROOT_DIR
+
+ DATA_DIR = os.path.join(ROOT_DIR, "tests", "data")
+
+
+ class TestPrediction(unittest.TestCase):
+     def test_prediction_label_3(self):
+         test_image_path = os.path.join(DATA_DIR, "number_3.jpg")
+         predicted_label = predict(test_image_path)
+         self.assertEqual(
+             predicted_label, 3, f"Expected label 3, but got {predicted_label}"
+         )
+
+
+ if __name__ == "__main__":
+     unittest.main()
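To run just this test with verbose output (standard `unittest` CLI, assuming the package has been installed via `pip3 install -e .` and the command is run from the repo root):

```bash
python3 -m unittest tests.test_prediction -v
```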