✅ [Pass] loss function test, with new return shape
Browse files
tests/test_utils/test_loss.py
CHANGED
@@ -26,14 +26,15 @@ def loss_function(cfg) -> YOLOLoss:
|
|
26 |
@pytest.fixture
|
27 |
def data():
|
28 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
29 |
-
targets = torch.zeros(20,
|
30 |
predicts = [[torch.zeros(1, 144, 80 // i, 80 // i, device=device) for i in [1, 2, 4]] for _ in range(2)]
|
31 |
return predicts, targets
|
32 |
|
33 |
|
34 |
def test_yolo_loss(loss_function, data):
|
35 |
predicts, targets = data
|
36 |
-
loss_iou, loss_dfl, loss_cls = loss_function(predicts, targets)
|
|
|
37 |
assert torch.isnan(loss_iou)
|
38 |
assert torch.isnan(loss_dfl)
|
39 |
assert torch.isinf(loss_cls)
|
|
|
26 |
@pytest.fixture
|
27 |
def data():
|
28 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
29 |
+
targets = torch.zeros(1, 20, 5, device=device)
|
30 |
predicts = [[torch.zeros(1, 144, 80 // i, 80 // i, device=device) for i in [1, 2, 4]] for _ in range(2)]
|
31 |
return predicts, targets
|
32 |
|
33 |
|
34 |
def test_yolo_loss(loss_function, data):
|
35 |
predicts, targets = data
|
36 |
+
loss, (loss_iou, loss_dfl, loss_cls) = loss_function(predicts, targets)
|
37 |
+
assert torch.isnan(loss)
|
38 |
assert torch.isnan(loss_iou)
|
39 |
assert torch.isnan(loss_dfl)
|
40 |
assert torch.isinf(loss_cls)
|