henry000 commited on
Commit
a80fd8c
·
1 Parent(s): fa2f0be

✅ [Pass] Test, mock dataset are 5 images

Browse files
tests/test_tools/test_data_loader.py CHANGED
@@ -42,15 +42,16 @@ def test_training_data_loader_correctness(train_dataloader: DataLoader):
42
 
43
  def test_validation_data_loader_correctness(validation_dataloader: DataLoader):
44
  batch_size, images, targets, reverse_tensors, image_paths = next(iter(validation_dataloader))
45
- assert batch_size == 4
46
- assert images.shape == (4, 3, 512, 768)
47
- assert targets.shape == (4, 18, 5)
48
- assert reverse_tensors.shape == (4, 5)
49
  expected_paths = [
50
- Path("tests/data/images/val/000000284106.jpg"),
51
  Path("tests/data/images/val/000000151480.jpg"),
52
- Path("tests/data/images/val/000000570456.jpg"),
53
  Path("tests/data/images/val/000000323571.jpg"),
 
 
54
  ]
55
  assert list(image_paths) == list(expected_paths)
56
 
 
42
 
43
  def test_validation_data_loader_correctness(validation_dataloader: DataLoader):
44
  batch_size, images, targets, reverse_tensors, image_paths = next(iter(validation_dataloader))
45
+ assert batch_size == 5
46
+ assert images.shape == (5, 3, 640, 640)
47
+ assert targets.shape == (5, 18, 5)
48
+ assert reverse_tensors.shape == (5, 5)
49
  expected_paths = [
 
50
  Path("tests/data/images/val/000000151480.jpg"),
51
+ Path("tests/data/images/val/000000284106.jpg"),
52
  Path("tests/data/images/val/000000323571.jpg"),
53
+ Path("tests/data/images/val/000000556498.jpg"),
54
+ Path("tests/data/images/val/000000570456.jpg"),
55
  ]
56
  assert list(image_paths) == list(expected_paths)
57
 
tests/test_tools/test_loss_functions.py CHANGED
@@ -51,7 +51,6 @@ def data():
51
  def test_yolo_loss(loss_function, data):
52
  predicts, targets = data
53
  loss, loss_dict = loss_function(predicts, predicts, targets)
54
- assert torch.isnan(loss)
55
- assert isnan(loss_dict["Loss/BoxLoss"])
56
- assert isnan(loss_dict["Loss/DFLLoss"])
57
- assert isinf(loss_dict["Loss/BCELoss"])
 
51
  def test_yolo_loss(loss_function, data):
52
  predicts, targets = data
53
  loss, loss_dict = loss_function(predicts, predicts, targets)
54
+ assert loss_dict["Loss/BoxLoss"] == 0
55
+ assert loss_dict["Loss/DFLLoss"] == 0
56
+ assert loss_dict["Loss/BCELoss"] >= 2e5