"""Regression tests for the image-classification ``predict`` entry point."""

import os
import unittest

from image_classification_model.predict import predict
from image_classification_model.utils import ROOT_DIR

# Fixture images live under <ROOT_DIR>/tests/data. The components are joined
# separately so the resulting path uses the platform's own separator.
DATA_DIR = os.path.join(ROOT_DIR, "tests", "data")


class TestPrediction(unittest.TestCase):
    """End-to-end checks that ``predict`` labels known fixture images."""

    def test_prediction_label_3(self):
        """The fixture image ``number_3.jpg`` should be classified as label 3."""
        test_image_path = os.path.join(DATA_DIR, "number_3.jpg")
        # Fail with an explicit message if the fixture is missing, rather than
        # relying on how predict() happens to handle a nonexistent path.
        self.assertTrue(
            os.path.isfile(test_image_path),
            f"Missing test fixture: {test_image_path}",
        )
        predicted_label = predict(test_image_path)
        self.assertEqual(
            predicted_label, 3, f"Expected label 3, but got {predicted_label}"
        )


if __name__ == "__main__":
    unittest.main()