import os
import unittest

from image_classification_model.predict import predict
from image_classification_model.utils import ROOT_DIR
# Directory holding the image fixtures used by this test module. Each path
# component is passed separately so os.path.join inserts the platform's
# separator rather than relying on a hard-coded "/".
DATA_DIR = os.path.join(ROOT_DIR, "tests", "data")
class TestPrediction(unittest.TestCase):
    """End-to-end checks of ``predict`` against image fixtures in ``DATA_DIR``."""

    def _assert_predicts(self, filename, expected_label):
        """Assert that ``predict`` labels the fixture ``filename`` as ``expected_label``.

        Args:
            filename: Name of an image file inside ``DATA_DIR``.
            expected_label: Label ``predict`` is expected to return.
        """
        test_image_path = os.path.join(DATA_DIR, filename)
        # Fail with an explicit message when the fixture is missing, instead
        # of surfacing whatever error predict() raises for a nonexistent path.
        self.assertTrue(
            os.path.isfile(test_image_path),
            f"Test fixture not found: {test_image_path}",
        )
        predicted_label = predict(test_image_path)
        self.assertEqual(
            predicted_label,
            expected_label,
            f"Expected label {expected_label}, but got {predicted_label}",
        )

    def test_prediction_label_3(self):
        """The ``number_3.jpg`` fixture should be classified as label 3."""
        self._assert_predicts("number_3.jpg", 3)
# Run this module's tests when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main(module="__main__")