From a338238e1980a88050dae3388a79792d55101d70 Mon Sep 17 00:00:00 2001 From: Julian Wiederer Date: Sun, 1 Oct 2023 17:47:47 +0200 Subject: [PATCH] Update IROS 2023 evaluation code. --- README.md | 2 -- src/metrics/metrics.py | 6 +++--- src/model/decoder.py | 14 +------------- src/model/embedding.py | 14 +------------- src/model/global_interactor.py | 14 +------------- src/model/local_encoder.py | 14 +------------- src/model/traj_pred.py | 18 +++--------------- utils.py | 27 +++++++++++++-------------- 8 files changed, 23 insertions(+), 86 deletions(-) diff --git a/README.md b/README.md index 5b5024c..5b721e9 100644 --- a/README.md +++ b/README.md @@ -55,8 +55,6 @@ conda activate joodu Install the API for the Shifts Vehicle Motion Prediction Dataset as described here:\ https://github.com/Shifts-Project/shifts/tree/main/sdc - - ### Install Argoverse API The argoverse-api is used to convert the HD-map provided by Shifts into the Argoverse format, which is consumed by the trajectory prediction model. diff --git a/src/metrics/metrics.py b/src/metrics/metrics.py index f987ad1..b44eb10 100755 --- a/src/metrics/metrics.py +++ b/src/metrics/metrics.py @@ -6,7 +6,7 @@ def log_likelihood( y: np.ndarray, y_hat: np.ndarray, pi: np.ndarray, - sigma: np.ndarray + sigma: np.ndarray = None, ) -> np.ndarray: """ Compute the Gaussian mixture log-likelihood. @@ -21,13 +21,13 @@ def log_likelihood( """ displacement_norms_squared = np.sum(((y - y_hat)) ** 2 , axis=-1) + normalizing_const = np.log(2. * math.pi * sigma ** 2) + if isinstance(sigma, np.ndarray): - normalizing_const = np.log(2. * math.pi * sigma ** 2) lse_args = np.log(pi) - np.sum(normalizing_const + np.divide(0.5 * displacement_norms_squared, sigma**2), axis=-1) else: sigma = 1.0 - normalizing_const = np.log(2. * math.pi * sigma ** 2) lse_args = np.log(pi) - np.sum(normalizing_const + 0.5 * displacement_norms_squared / sigma ** 2, axis=-1) max_arg = lse_args.max() diff --git a/src/model/decoder.py b/src/model/decoder.py index 922fdda..fd4a3dd 100644 --- a/src/model/decoder.py +++ b/src/model/decoder.py @@ -1,16 +1,4 @@ -# Copyright (c) 2022, Zikang Zhou. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Code adopted from https://github.com/ZikangZhou/HiVT from typing import Tuple import torch diff --git a/src/model/embedding.py b/src/model/embedding.py index b401766..1acb56d 100755 --- a/src/model/embedding.py +++ b/src/model/embedding.py @@ -1,16 +1,4 @@ -# Copyright (c) 2022, Zikang Zhou. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Code adopted from https://github.com/ZikangZhou/HiVT from typing import List, Optional import torch diff --git a/src/model/global_interactor.py b/src/model/global_interactor.py index e0abf83..e0be2b1 100644 --- a/src/model/global_interactor.py +++ b/src/model/global_interactor.py @@ -1,16 +1,4 @@ -# Copyright (c) 2022, Zikang Zhou. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Code adopted from https://github.com/ZikangZhou/HiVT from typing import Optional import torch diff --git a/src/model/local_encoder.py b/src/model/local_encoder.py index 9cc2656..edbce99 100755 --- a/src/model/local_encoder.py +++ b/src/model/local_encoder.py @@ -1,16 +1,4 @@ -# Copyright (c) 2022, Zikang Zhou. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Code adopted from https://github.com/ZikangZhou/HiVT from typing import Optional, Tuple import torch diff --git a/src/model/traj_pred.py b/src/model/traj_pred.py index 023decc..dd4170f 100644 --- a/src/model/traj_pred.py +++ b/src/model/traj_pred.py @@ -1,16 +1,4 @@ -# Copyright (c) 2022, Zikang Zhou. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Code adopted from https://github.com/ZikangZhou/HiVT import pytorch_lightning as pl import numpy as np import torch @@ -21,7 +9,7 @@ from src.model.global_interactor import GlobalInteractor from src.model.local_encoder import LocalEncoder from src.model.decoder import MLPDecoder -from utils import TemporalData, rotate_back_prediction +from utils import TemporalData, rotate_trajectory class TrajPredEncoderDecoder(pl.LightningModule): """ Trajectory prediction model, adopted from HiVT. @@ -177,7 +165,7 @@ def validation_step(self, data, batch_idx): y_hat, scale = loc_scale.chunk(2, dim=-1) # log predictions - self.mu = torch.swapaxes(rotate_back_prediction(y_hat, data.rotate_mat), 0, 1) # N x K x T x 2 + self.mu = torch.swapaxes(rotate_trajectory(y_hat, data.rotate_mat), 0, 1) # N x K x T x 2 self.pi = pi # N x K self.scale = torch.swapaxes(scale, 0, 1) # N x K x T x 1 diff --git a/utils.py b/utils.py index 4e45227..841646e 100755 --- a/utils.py +++ b/utils.py @@ -1,16 +1,3 @@ -# Copyright (c) 2022, Zikang Zhou. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. from typing import List, Optional, Tuple import torch import torch.nn as nn @@ -18,6 +5,10 @@ class TemporalData(Data): + """ + Code adopted from https://github.com/ZikangZhou/HiVT + """ + def __init__(self, x: Optional[torch.Tensor] = None, @@ -60,6 +51,9 @@ def __inc__(self, key, value): class DistanceDropEdge(object): + """ + Code adopted from https://github.com/ZikangZhou/HiVT + """ def __init__(self, max_distance: Optional[float] = None) -> None: self.max_distance = max_distance @@ -77,6 +71,11 @@ def __call__(self, def init_weights(m: nn.Module) -> None: + """ + Initialize network weights for the trajectory prediction model. + Code adopted from https://github.com/ZikangZhou/HiVT + + """ if isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight) if m.bias is not None: @@ -144,7 +143,7 @@ def init_weights(m: nn.Module) -> None: nn.init.zeros_(param) -def rotate_back_prediction(traj: torch.Tensor, rotate_mat: torch.Tensor) -> torch.Tensor: +def rotate_trajectory(traj: torch.Tensor, rotate_mat: torch.Tensor) -> torch.Tensor: """ Rotate trajectory given the rotation matrix.