|
import torch
|
|
|
|
from metrics.utils import haversine, reverse
|
|
|
|
from torchmetrics import Metric
|
|
|
|
|
|
class HaversineMetrics(Metric):
    """Geolocation metrics accumulated over batches of GPS predictions.

    Across ``update`` calls this metric tracks:

    * the mean haversine distance between predicted and ground-truth points;
    * the mean GeoGuessr-style score, ``5000 * exp(-d / 1492.7)`` per sample;
    * for each radius in ``acc_radiuses``, the fraction of samples whose
      haversine distance is strictly below that radius;
    * for each area level in ``acc_area``, the fraction of samples whose
      predicted area (as resolved by ``metrics.utils.reverse``) equals the
      ground-truth ``unique_{area}`` entry;
    * optionally, per-column scores for auxiliary predictions (``aux_data``).

    Args:
        acc_radiuses (list): radii to compute the radius accuracy for. Each
            radius is compared directly against ``haversine``'s output, so it
            must be in the same unit (the output key says km — confirm in
            ``metrics.utils.haversine``). Defaults to no radii.
        acc_area (list): area levels to compute the area accuracy for.
            Defaults to ``["country", "region", "sub-region", "city"]``; pass
            ``[]`` to disable area accuracy.
        aux_data (list): auxiliary columns to score. Only the columns listed
            in ``_KNOWN_AUX_COLUMNS`` are updated; any other name gets a state
            that stays at zero. Defaults to no auxiliary columns.
    """

    _DEFAULT_ACC_AREA = ("country", "region", "sub-region", "city")

    # Auxiliary columns that ``update`` knows how to score, in the order they
    # are processed.
    _KNOWN_AUX_COLUMNS = (
        "land_cover",
        "road_index",
        "drive_side",
        "climate",
        "soil",
        "dist_sea",
    )

    def __init__(
        self,
        acc_radiuses=None,
        acc_area=None,
        aux_data=None,
    ):
        super().__init__()
        # ``None`` defaults plus copies: a shared mutable default (or a
        # caller-owned list) must never be aliased into the metric, because
        # the lists below have to stay in sync with the registered states.
        acc_radiuses = [] if acc_radiuses is None else list(acc_radiuses)
        acc_area = (
            list(self._DEFAULT_ACC_AREA) if acc_area is None else list(acc_area)
        )
        aux_data = [] if aux_data is None else list(aux_data)

        self.add_state("haversine_sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("geoguessr_sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
        for radius in acc_radiuses:
            self.add_state(
                f"close_enough_points_{radius}",
                default=torch.tensor(0.0),
                dist_reduce_fx="sum",
            )
        for area in acc_area:
            self.add_state(
                f"close_enough_points_{area}",
                default=torch.tensor(0.0),
                dist_reduce_fx="sum",
            )
            # Per-area denominator: the number of ground-truth entries
            # actually compared for that area.
            self.add_state(f"count_{area}", default=torch.tensor(0), dist_reduce_fx="sum")
        self.acc_radius = acc_radiuses
        self.acc_area = acc_area
        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")

        self.aux = len(aux_data) > 0
        self.aux_list = aux_data
        for col in aux_data:
            self.add_state(f"aux_{col}", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def _add_to_state(self, name, value):
        """Add ``value`` in place to the metric state called ``name``."""
        # In-place add keeps the state tensor object and dtype unchanged.
        state = getattr(self, name)
        state += value

    @staticmethod
    def _aux_score(col, pred_col, gt_col):
        """Return the batch contribution of auxiliary column ``col``."""
        if col == "drive_side":
            # Prediction thresholded at 0.5, compared with the target values.
            return ((pred_col > 0.5).float() == gt_col).sum()
        if col == "dist_sea":
            # NOTE: summed squared error, not an accuracy, even though
            # ``compute`` reports it under ``Accuracy_dist_sea``.
            return (pred_col - gt_col).pow(2).sum(dim=1).sum()
        # Remaining known columns: top-1 agreement of argmax along dim 1.
        return (pred_col.argmax(dim=1) == gt_col.argmax(dim=1)).sum()

    def update(self, pred, gt):
        """Accumulate statistics for one batch.

        Args:
            pred (dict): predictions; needs ``"gps"`` plus one entry per
                requested auxiliary column.
            gt (dict): ground truth; needs ``"gps"``, one ``"unique_{area}"``
                entry per area level, the requested auxiliary columns, and
                whatever ``reverse`` reads from it.
        """
        haversine_distance = haversine(pred["gps"], gt["gps"])

        for radius in self.acc_radius:
            self._add_to_state(
                f"close_enough_points_{radius}", (haversine_distance < radius).sum()
            )

        if self.acc_area:
            area_pred, area_gt = reverse(pred["gps"], gt, self.acc_area)
            for area in self.acc_area:
                area_true = area_gt[f"unique_{area}"]
                self._add_to_state(
                    f"close_enough_points_{area}", (area_pred[area] == area_true).sum()
                )
                self._add_to_state(f"count_{area}", len(area_true))

        self._add_to_state("haversine_sum", haversine_distance.sum())
        self._add_to_state(
            "geoguessr_sum", 5000 * torch.exp(-haversine_distance / 1492.7).sum()
        )

        if self.aux:
            for col in self._KNOWN_AUX_COLUMNS:
                if col in self.aux_list:
                    self._add_to_state(
                        f"aux_{col}", self._aux_score(col, pred[col], gt[col])
                    )

        self._add_to_state("count", pred["gps"].shape[0])

    def compute(self):
        """Return a dict of averaged metrics over everything seen so far.

        Keys: ``Haversine``, ``Geoguessr``, ``Accuracy_{r}_km_radius`` per
        radius, ``Accuracy_{area}`` per area level, and ``Accuracy_{col}`` per
        auxiliary column (for ``dist_sea`` this is a mean per-sample summed
        squared error rather than an accuracy).
        """
        count = self.count
        output = {
            "Haversine": self.haversine_sum / count,
            "Geoguessr": self.geoguessr_sum / count,
        }
        for radius in self.acc_radius:
            output[f"Accuracy_{radius}_km_radius"] = (
                getattr(self, f"close_enough_points_{radius}") / count
            )
        for area in self.acc_area:
            output[f"Accuracy_{area}"] = getattr(
                self, f"close_enough_points_{area}"
            ) / getattr(self, f"count_{area}")
        if self.aux:
            for col in self.aux_list:
                output[f"Accuracy_{col}"] = getattr(self, f"aux_{col}") / count
        return output
|
|
|