Skip to content

Commit

Permalink
Implementing mean pointwise l2 distance (nutonomy#339)
Browse files Browse the repository at this point in the history
  • Loading branch information
freddyaboulton authored Apr 1, 2020
1 parent 6888661 commit 711b2eb
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
6 changes: 3 additions & 3 deletions python-sdk/nuscenes/prediction/models/covernet.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,15 @@ def forward(self, image_tensor: torch.Tensor,
return logits


def l1_distance(lattice: torch.Tensor, ground_truth: torch.Tensor) -> torch.Tensor:
def mean_pointwise_l2_distance(lattice: torch.Tensor, ground_truth: torch.Tensor) -> torch.Tensor:
"""
Computes the index of the closest trajectory in the lattice as measured by l1 distance.
:param lattice: Lattice of pre-generated trajectories. Shape [num_modes, n_timesteps, state_dim]
:param ground_truth: Ground truth trajectory of agent. Shape [1, n_timesteps, state_dim].
:return: Index of closest mode in the lattice.
"""
stacked_ground_truth = ground_truth.repeat(lattice.shape[0], 1, 1)
return f.l1_loss(lattice, stacked_ground_truth, reduction='none').sum(dim=2).mean(dim=1).argmin()
return torch.pow(lattice - stacked_ground_truth, 2).sum(dim=2).sqrt().mean(dim=1).argmin()


class ConstantLatticeLoss:
Expand All @@ -82,7 +82,7 @@ class ConstantLatticeLoss:
"""

def __init__(self, lattice: Union[np.ndarray, torch.Tensor],
similarity_function: Callable[[torch.Tensor, torch.Tensor], int] = l1_distance):
similarity_function: Callable[[torch.Tensor, torch.Tensor], int] = mean_pointwise_l2_distance):
"""
Inits the loss.
:param lattice: numpy array of shape [n_modes, n_timesteps, state_dim]
Expand Down
10 changes: 5 additions & 5 deletions python-sdk/nuscenes/prediction/tests/test_covernet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch.nn.functional import cross_entropy

from nuscenes.prediction.models.backbone import ResNetBackbone
from nuscenes.prediction.models.covernet import l1_distance, ConstantLatticeLoss, CoverNet
from nuscenes.prediction.models.covernet import mean_pointwise_l2_distance, ConstantLatticeLoss, CoverNet


class TestCoverNet(unittest.TestCase):
Expand Down Expand Up @@ -36,15 +36,15 @@ def test_l1_distance(self):

# Should select the first mode
ground_truth = torch.arange(1, 13).reshape(6, 2).unsqueeze(0) + 2
self.assertEqual(l1_distance(lattice, ground_truth), 0)
self.assertEqual(mean_pointwise_l2_distance(lattice, ground_truth), 0)

# Should select the second mode
ground_truth = torch.arange(1, 13).reshape(6, 2).unsqueeze(0) * 3 + 4
self.assertEqual(l1_distance(lattice, ground_truth), 1)
self.assertEqual(mean_pointwise_l2_distance(lattice, ground_truth), 1)

# Should select the third mode
ground_truth = torch.arange(1, 13).reshape(6, 2).unsqueeze(0) * 6 + 10
self.assertEqual(l1_distance(lattice, ground_truth), 2)
self.assertEqual(mean_pointwise_l2_distance(lattice, ground_truth), 2)

def test_constant_lattice_loss(self):

Expand Down Expand Up @@ -75,7 +75,7 @@ def generate_trajectory(theta: float) -> torch.Tensor:

answer = cross_entropy(logits, torch.LongTensor([1, 1, 2, 0, 0]))

loss = ConstantLatticeLoss(lattice, l1_distance)
loss = ConstantLatticeLoss(lattice, mean_pointwise_l2_distance)
loss_value = loss(logits, ground_truth)

self.assertAlmostEqual(float(loss_value.detach().numpy()), float(answer.detach().numpy()))

0 comments on commit 711b2eb

Please sign in to comment.