henry000 commited on
Commit
da24bd9
·
1 Parent(s): 12dfccf

✅ [Pass] loss function test, with new return shape

Browse files
Files changed (1) hide show
  1. tests/test_utils/test_loss.py +3 -2
tests/test_utils/test_loss.py CHANGED
@@ -26,14 +26,15 @@ def loss_function(cfg) -> YOLOLoss:
26
  @pytest.fixture
27
  def data():
28
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
- targets = torch.zeros(20, 6, device=device)
30
  predicts = [[torch.zeros(1, 144, 80 // i, 80 // i, device=device) for i in [1, 2, 4]] for _ in range(2)]
31
  return predicts, targets
32
 
33
 
34
  def test_yolo_loss(loss_function, data):
35
  predicts, targets = data
36
- loss_iou, loss_dfl, loss_cls = loss_function(predicts, targets)
 
37
  assert torch.isnan(loss_iou)
38
  assert torch.isnan(loss_dfl)
39
  assert torch.isinf(loss_cls)
 
26
  @pytest.fixture
27
  def data():
28
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
+ targets = torch.zeros(1, 20, 5, device=device)
30
  predicts = [[torch.zeros(1, 144, 80 // i, 80 // i, device=device) for i in [1, 2, 4]] for _ in range(2)]
31
  return predicts, targets
32
 
33
 
34
  def test_yolo_loss(loss_function, data):
35
  predicts, targets = data
36
+ loss, (loss_iou, loss_dfl, loss_cls) = loss_function(predicts, targets)
37
+ assert torch.isnan(loss)
38
  assert torch.isnan(loss_iou)
39
  assert torch.isnan(loss_dfl)
40
  assert torch.isinf(loss_cls)