image-classification-model / tests /test_prediction.py
SupremoUGH's picture
first commit
ab8b628 unverified
raw
history blame contribute delete
546 Bytes
import unittest
import os
from image_classification_model.predict import predict
from image_classification_model.utils import ROOT_DIR
# Directory holding the image fixtures these tests read, located under
# ROOT_DIR. Components are passed to os.path.join separately so the platform's
# separator is used rather than a hard-coded "/".
DATA_DIR = os.path.join(ROOT_DIR, "tests", "data")
class TestPrediction(unittest.TestCase):
    """End-to-end check of ``predict`` against a bundled fixture image."""

    def test_prediction_label_3(self):
        """The fixture ``number_3.jpg`` in DATA_DIR must be classified as 3."""
        image_file = os.path.join(DATA_DIR, "number_3.jpg")
        predicted_label = predict(image_file)
        # Message is built after prediction, exactly as the inline f-string was.
        failure_msg = f"Expected label 3, but got {predicted_label}"
        self.assertEqual(predicted_label, 3, failure_msg)
# Allow running this test module directly (python test_prediction.py) in
# addition to discovery by a test runner.
if __name__ == "__main__":
    unittest.main()