Skip to content

Commit

Permalink
Update IROS 2023 evaluation code.
Browse files Browse the repository at this point in the history
  • Loading branch information
julian-wiederer-mb committed Jan 2, 2024
1 parent dfbc5dc commit a338238
Show file tree
Hide file tree
Showing 8 changed files with 23 additions and 86 deletions.
2 changes: 0 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,6 @@ conda activate joodu
Install the API for the Shifts Vehicle Motion Prediction Dataset as described here:\
https://github.com/Shifts-Project/shifts/tree/main/sdc



### Install Argoverse API
The argoverse-api is used to convert the HD-map provided by Shifts into the Argoverse format, which is consumed by the trajectory prediction model.

Expand Down
6 changes: 3 additions & 3 deletions src/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ def log_likelihood(
y: np.ndarray,
y_hat: np.ndarray,
pi: np.ndarray,
sigma: np.ndarray
sigma: np.ndarray = None,
) -> np.ndarray:
"""
Compute the Gaussian mixture log-likelihood.
Expand All @@ -21,13 +21,13 @@ def log_likelihood(
"""
displacement_norms_squared = np.sum(((y - y_hat)) ** 2 , axis=-1)

normalizing_const = np.log(2. * math.pi * sigma ** 2)

if isinstance(sigma, np.ndarray):
normalizing_const = np.log(2. * math.pi * sigma ** 2)
lse_args = np.log(pi) - np.sum(normalizing_const + np.divide(0.5 * displacement_norms_squared, sigma**2), axis=-1)

else:
sigma = 1.0
normalizing_const = np.log(2. * math.pi * sigma ** 2)
lse_args = np.log(pi) - np.sum(normalizing_const + 0.5 * displacement_norms_squared / sigma ** 2, axis=-1)

max_arg = lse_args.max()
Expand Down
14 changes: 1 addition & 13 deletions src/model/decoder.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,4 @@
# Copyright (c) 2022, Zikang Zhou. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Code adopted from https://github.com/ZikangZhou/HiVT
from typing import Tuple

import torch
Expand Down
14 changes: 1 addition & 13 deletions src/model/embedding.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,4 @@
# Copyright (c) 2022, Zikang Zhou. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Code adopted from https://github.com/ZikangZhou/HiVT
from typing import List, Optional

import torch
Expand Down
14 changes: 1 addition & 13 deletions src/model/global_interactor.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,4 @@
# Copyright (c) 2022, Zikang Zhou. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Code adopted from https://github.com/ZikangZhou/HiVT
from typing import Optional

import torch
Expand Down
14 changes: 1 addition & 13 deletions src/model/local_encoder.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,4 @@
# Copyright (c) 2022, Zikang Zhou. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Code adopted from https://github.com/ZikangZhou/HiVT
from typing import Optional, Tuple

import torch
Expand Down
18 changes: 3 additions & 15 deletions src/model/traj_pred.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,4 @@
# Copyright (c) 2022, Zikang Zhou. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Code adopted from https://github.com/ZikangZhou/HiVT
import pytorch_lightning as pl
import numpy as np
import torch
Expand All @@ -21,7 +9,7 @@
from src.model.global_interactor import GlobalInteractor
from src.model.local_encoder import LocalEncoder
from src.model.decoder import MLPDecoder
from utils import TemporalData, rotate_back_prediction
from utils import TemporalData, rotate_trajectory

class TrajPredEncoderDecoder(pl.LightningModule):
""" Trajectory prediction model, adopted from HiVT.
Expand Down Expand Up @@ -177,7 +165,7 @@ def validation_step(self, data, batch_idx):
y_hat, scale = loc_scale.chunk(2, dim=-1)

# log predictions
self.mu = torch.swapaxes(rotate_back_prediction(y_hat, data.rotate_mat), 0, 1) # N x K x T x 2
self.mu = torch.swapaxes(rotate_trajectory(y_hat, data.rotate_mat), 0, 1) # N x K x T x 2
self.pi = pi # N x K
self.scale = torch.swapaxes(scale, 0, 1) # N x K x T x 1

Expand Down
27 changes: 13 additions & 14 deletions utils.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,14 @@
# Copyright (c) 2022, Zikang Zhou. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
from torch_geometric.data import Data


class TemporalData(Data):
"""
Code adopted from https://github.com/ZikangZhou/HiVT
"""


def __init__(self,
x: Optional[torch.Tensor] = None,
Expand Down Expand Up @@ -60,6 +51,9 @@ def __inc__(self, key, value):


class DistanceDropEdge(object):
"""
Code adopted from https://github.com/ZikangZhou/HiVT
"""

def __init__(self, max_distance: Optional[float] = None) -> None:
self.max_distance = max_distance
Expand All @@ -77,6 +71,11 @@ def __call__(self,


def init_weights(m: nn.Module) -> None:
"""
Initialize network weights for the trajectory prediction model.
Code adopted from https://github.com/ZikangZhou/HiVT
"""
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
Expand Down Expand Up @@ -144,7 +143,7 @@ def init_weights(m: nn.Module) -> None:
nn.init.zeros_(param)


def rotate_back_prediction(traj: torch.Tensor, rotate_mat: torch.Tensor) -> torch.Tensor:
def rotate_trajectory(traj: torch.Tensor, rotate_mat: torch.Tensor) -> torch.Tensor:
"""
Rotate trajectory given the rotation matrix.
Expand Down

0 comments on commit a338238

Please sign in to comment.