training code release
huang-yh committed Dec 16, 2023
1 parent f8c1d01 commit c115c8b
Showing 14 changed files with 1,532 additions and 2 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -10,6 +10,7 @@
SelfOcc empowers 3D autonomous driving world models (e.g., [OccWorld](https://github.com/wzzheng/OccWorld)) with scalable 3D representations, paving the way for **interpretable end-to-end large driving models**.

## News
- **[2023/12/16]** Training code release.
- **[2023/11/28]** Evaluation code release.
- **[2023/11/20]** Paper released on [arXiv](https://arxiv.org/abs/2311.12754).
- **[2023/11/20]** Demo release.
50 changes: 49 additions & 1 deletion docs/get_started.md
@@ -4,20 +4,36 @@

**Please ensure you have prepared the environment and datasets.**

Training code will be released soon.
[23/12/16 Update] Please update the timm package to 0.9.2 to run the training script.


# 3D Occupancy Prediction

## NuScenes

### Training

```
python train.py --py-config config/nuscenes/nuscenes_occ.py --work-dir out/nuscenes/occ_train --depth-metric
```

### Evaluation

Download the model weights [HERE](https://cloud.tsinghua.edu.cn/f/831c104c82a244e9878a/) and put them under out/nuscenes/occ/
```
python eval_iou.py --py-config config/nuscenes/nuscenes_occ.py --work-dir out/nuscenes/occ --resume-from out/nuscenes/occ/model_state_dict.pth --occ3d --resolution 0.4 --sem --use-mask --scene-size 4
```

## SemanticKITTI

### Training

```
python train.py --py-config config/kitti/kitti_occ.py --work-dir out/kitti/occ_train --depth-metric --dataset kitti
```

### Evaluation

Download the model weights [HERE](https://cloud.tsinghua.edu.cn/f/3c09a5e8f5b94fa29289/) and put them under out/kitti/occ/
```
python eval_iou_kitti.py --py-config config/kitti/kitti_occ.py --work-dir out/kitti/occ --resume-from out/kitti/occ/model_state_dict.pth
@@ -27,13 +43,29 @@ python eval_iou_kitti.py --py-config config/kitti/kitti_occ.py --work-dir out/ki

## NuScenes

### Training

```
python train.py --py-config config/nuscenes/nuscenes_novel_depth.py --work-dir out/nuscenes/novel_depth_train --depth-metric
```

### Evaluation

Download the model weights [HERE](https://cloud.tsinghua.edu.cn/f/2d217cd298a34ed19039/) and put them under out/nuscenes/novel_depth/
```
python eval_novel_depth.py --py-config config/nuscenes/nuscenes_novel_depth.py --work-dir out/nuscenes/novel_depth --resume-from out/nuscenes/novel_depth/model_state_dict.pth
```

## SemanticKITTI

### Training

```
python train.py --py-config config/kitti/kitti_novel_depth.py --work-dir out/kitti/novel_depth_train --depth-metric --dataset kitti
```

### Evaluation

Download the model weights [HERE](https://cloud.tsinghua.edu.cn/f/7280a44340fd440cba7c/) and put them under out/kitti/novel_depth/
```
python eval_novel_depth_kitti.py --py-config config/kitti/kitti_novel_depth.py --work-dir out/kitti/novel_depth --resume-from out/kitti/novel_depth/model_state_dict.pth
@@ -44,6 +76,14 @@ python eval_novel_depth_kitti.py --py-config config/kitti/kitti_novel_depth.py -

## nuScenes

### Training

```
python train.py --py-config config/nuscenes/nuscenes_depth.py --work-dir out/nuscenes/depth_train --depth-metric
```

### Evaluation

Download the model weights [HERE](https://cloud.tsinghua.edu.cn/f/1a722b9139234542ae1e/) and put them under out/nuscenes/depth/
```
python eval_depth.py --py-config config/nuscenes/nuscenes_depth.py --work-dir out/nuscenes/depth --resume-from out/nuscenes/depth/model_state_dict.pth --depth-metric --batch 90000
@@ -54,6 +94,14 @@ Note that evaluating at a resolution (450\*800) of 1:2 against the raw image (90

## KITTI-2015

### Training

```
python train.py --py-config config/kitti_raw/kitti_raw_depth.py --work-dir out/kitti_raw/depth_train --depth-metric --dataset kitti
```

### Evaluation

Download the model weights [HERE](https://cloud.tsinghua.edu.cn/f/f87f6876569e4fdeb967/) and put them under out/kitti_raw/depth/
```
python eval_depth.py --py-config config/kitti_raw/kitti_raw_depth.py --work-dir out/kitti_raw/depth --resume-from out/kitti_raw/depth/model_state_dict.pth --depth-metric --dataset kitti_raw
2 changes: 1 addition & 1 deletion docs/installation.md
@@ -50,7 +50,7 @@ pip install -e .

**f. Install other packages and deal with package versions.**
```shell
-pip install pillow==8.4.0 typing_extensions==4.8.0 torchmetrics==0.9.3
+pip install pillow==8.4.0 typing_extensions==4.8.0 torchmetrics==0.9.3 timm==0.9.2
```
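
One way to confirm that the pinned versions in the install command above are actually in place is to compare them against the installed distributions. The sketch below is not part of the repository: `PINNED`, `installed_version`, and `find_mismatches` are hypothetical names introduced for illustration, and the lookup function is injectable so the comparison logic can be exercised without the packages installed.

```python
from importlib import metadata

# Pins copied from the install command above.
PINNED = {
    "pillow": "8.4.0",
    "typing_extensions": "4.8.0",
    "torchmetrics": "0.9.3",
    "timm": "0.9.2",
}


def installed_version(name):
    """Return the installed version of a distribution, or None if absent."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(pinned, lookup=installed_version):
    """Map each package whose installed version differs from its pin
    to a (wanted, found) pair; `found` is None when it is not installed."""
    mismatches = {}
    for name, wanted in pinned.items():
        found = lookup(name)
        if found != wanted:
            mismatches[name] = (wanted, found)
    return mismatches


if __name__ == "__main__":
    for name, (wanted, found) in find_mismatches(PINNED).items():
        print(f"{name}: expected {wanted}, found {found}")
```

An empty result means every pin matches; any printed line names a package to reinstall at the pinned version.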


11 changes: 11 additions & 0 deletions loss/__init__.py
@@ -0,0 +1,11 @@
from mmengine.registry import Registry
OPENOCC_LOSS = Registry('openocc_loss')

from .multi_loss import MultiLoss
from .rgb_loss_ms import RGBLossMS, SemLossMS, SemCELossMS
from .reproj_loss_mono_multi_new import ReprojLossMonoMultiNew
from .reproj_loss_mono_multi_new_combine import ReprojLossMonoMultiNewCombine
from .edge_loss_3d_ms import EdgeLoss3DMS
from .eikonal_loss import EikonalLoss
from .sparsity_loss import SparsityLoss, HardSparsityLoss, SoftSparsityLoss, AdaptiveSparsityLoss
from .second_grad_loss import SecondGradLoss
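
The registry above lets loss classes be selected by name from a config: in mmengine, a class registered with `register_module()` can be instantiated through `Registry.build` from a dict whose `type` key names the class. The sketch below illustrates that pattern with a tiny pure-Python stand-in; `MiniRegistry`, `LOSSES`, and `ConstantLoss` are hypothetical and are not mmengine's implementation or part of this commit.

```python
class MiniRegistry:
    """Tiny stand-in for a name-to-class registry (not mmengine's implementation)."""

    def __init__(self, name):
        self.name = name
        self._modules = {}

    def register_module(self):
        """Return a class decorator that records the class under its own name."""
        def decorator(cls):
            self._modules[cls.__name__] = cls
            return cls
        return decorator

    def build(self, cfg):
        """Instantiate the class named by cfg['type'] with the remaining keys."""
        kwargs = dict(cfg)  # copy so the caller's config is left untouched
        cls = self._modules[kwargs.pop("type")]
        return cls(**kwargs)


LOSSES = MiniRegistry("demo_loss")  # hypothetical registry, for illustration only


@LOSSES.register_module()
class ConstantLoss:
    """Hypothetical loss that ignores its inputs and returns weight * value."""

    def __init__(self, weight=1.0, value=0.0):
        self.weight = weight
        self.value = value

    def __call__(self, inputs):
        return self.weight * self.value


cfg = {"type": "ConstantLoss", "weight": 0.5, "value": 4.0}
loss = LOSSES.build(cfg)
print(loss({}))  # 0.5 * 4.0
```

Because each loss module imports the registry from the package `__init__`, the registry has to be created before the submodule imports run, which matches the ordering in the file above.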
39 changes: 39 additions & 0 deletions loss/base_loss.py
@@ -0,0 +1,39 @@
import torch.nn as nn
from utils.tb_wrapper import WrappedTBWriter
if 'selfocc' in WrappedTBWriter._instance_dict:
    writer = WrappedTBWriter.get_instance('selfocc')
else:
    writer = None

class BaseLoss(nn.Module):

    """ Base loss class.
    args:
        weight: weight of current loss.
        input_dict: maps each keyword argument of loss_func to the key of
            the corresponding field in "inputs". Since "inputs" may contain
            many different fields, we use input_dict to select them.
        loss_func: the actual loss func to calculate loss.
    """

    def __init__(
            self,
            weight=1.0,
            input_dict={
                'input': 'input'},
            **kwargs):
        super().__init__()
        self.weight = weight
        self.input_dict = input_dict
        self.loss_func = lambda: 0
        self.writer = writer

    # def calculate_loss(self, **kwargs):
    #     return self.loss_func(*[kwargs[key] for key in self.input_keys])

    def forward(self, inputs):
        actual_inputs = {}
        for input_key, input_val in self.input_dict.items():
            actual_inputs.update({input_key: inputs[input_val]})
        # return self.weight * self.calculate_loss(**actual_inputs)
        return self.weight * self.loss_func(**actual_inputs)
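
The routing in `forward` can be easier to follow on a torch-free example. The sketch below mirrors the idea under simplifying assumptions: `KeyMappedLoss` and `AbsDiffLoss` are hypothetical classes, plain floats replace tensors, and the field names `render_depth` and `gt_depth` are made up for illustration.

```python
class KeyMappedLoss:
    """Torch-free sketch of the input_dict routing used by BaseLoss.forward."""

    def __init__(self, weight=1.0, input_dict=None):
        self.weight = weight
        # loss_func argument name -> key in the shared `inputs` dict
        self.input_dict = input_dict or {"input": "input"}
        self.loss_func = lambda **kwargs: 0.0

    def __call__(self, inputs):
        actual_inputs = {arg: inputs[key] for arg, key in self.input_dict.items()}
        return self.weight * self.loss_func(**actual_inputs)


class AbsDiffLoss(KeyMappedLoss):
    """Hypothetical loss: mean absolute difference between two sequences."""

    def __init__(self, weight=1.0, input_dict=None):
        super().__init__(weight, input_dict or {"pred": "pred", "target": "target"})
        self.loss_func = self.abs_diff

    def abs_diff(self, pred, target):
        return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


# The shared dict stores its fields under different names, so remap them.
loss = AbsDiffLoss(weight=2.0, input_dict={"pred": "render_depth", "target": "gt_depth"})
outputs = {"render_depth": [1.0, 2.0], "gt_depth": [1.5, 1.0], "unused": None}
print(loss(outputs))  # 2.0 * mean(|1.0 - 1.5|, |2.0 - 1.0|) = 1.5
```

Fields not named in `input_dict` are simply ignored, so several losses can read from one shared dict. In the source, the default `loss_func` is `lambda: 0`, which accepts no arguments, so subclasses are expected to override it before `forward` is called.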
79 changes: 79 additions & 0 deletions loss/edge_loss_3d_ms.py
@@ -0,0 +1,79 @@
import torch
from .base_loss import BaseLoss
from . import OPENOCC_LOSS
import torch.nn.functional as F


def get_smooth_loss(disp, img):
    """Computes the smoothness loss for a disparity image
    The color image is used for edge-aware smoothness
    """
    grad_disp_x = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:])
    grad_disp_y = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :])

    grad_img_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), 1, keepdim=True)
    grad_img_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), 1, keepdim=True)

    grad_disp_x *= torch.exp(-grad_img_x)
    grad_disp_y *= torch.exp(-grad_img_y)

    return grad_disp_x.mean() + grad_disp_y.mean()


@OPENOCC_LOSS.register_module()
class EdgeLoss3DMS(BaseLoss):

    def __init__(self, weight=1.0, input_dict=None, **kwargs):
        super().__init__(weight)

        if input_dict is None:
            self.input_dict = {
                'curr_imgs': 'curr_imgs',
                'ms_depths': 'ms_depths',
                'ms_rays': 'ms_rays'
            }
        else:
            self.input_dict = input_dict
        self.img_size = kwargs.get('img_size', [768, 1600])
        self.ray_resize = kwargs.get('ray_resize', None)
        self.use_inf_mask = kwargs.get('use_inf_mask', False)
        # self.inf_dist = kwargs.get('inf_dist', 1e6)
        assert self.ray_resize is not None
        self.loss_func = self.edge_loss

    def edge_loss(self, curr_imgs, ms_depths, ms_rays, ms_accs=None, max_depths=None):
        # curr_imgs: B, N, C, H, W
        # depth: B, N, R
        # rays: R, 2
        if self.use_inf_mask:
            assert ms_accs is not None and max_depths is not None
        if not isinstance(ms_rays, list):
            ms_rays = [ms_rays] * len(ms_depths)
        bs, num_cams, num_rays = ms_depths[0].shape

        tot_loss = 0.
        for scale, (depth, rays) in enumerate(zip(ms_depths, ms_rays)):
            pixel_curr = rays.clone().reshape(1, 1, num_rays, 2).repeat(
                bs * num_cams, 1, 1, 1)  # bs*N, 1, R, 2
            pixel_curr[..., 0] /= self.img_size[1]
            pixel_curr[..., 1] /= self.img_size[0]
            pixel_curr = pixel_curr * 2 - 1
            rgb_curr = F.grid_sample(
                curr_imgs.flatten(0, 1),
                pixel_curr,
                mode='bilinear',
                padding_mode='border',
                align_corners=True)  # bs*N, 3, 1, R
            rgb_curr = rgb_curr.reshape(bs * num_cams, -1, *self.ray_resize)

            if self.use_inf_mask:
                depth = depth * ms_accs[scale] + max_depths[scale] * (1 - ms_accs[scale])

            depth = depth.reshape(bs * num_cams, 1, *self.ray_resize)
            mean_depth = depth.mean(2, True).mean(3, True)
            norm_depth = depth / (mean_depth + 1e-6)
            smooth_loss = get_smooth_loss(norm_depth, rgb_curr)

            tot_loss += smooth_loss

        return tot_loss / len(ms_depths)
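
The effect of the edge-aware weighting in `get_smooth_loss` can be seen on a single row of values. The sketch below is a 1D, torch-free reduction of the same idea (depth differences scaled by `exp(-|image difference|)`, after dividing depth by its mean); the helper names and the toy inputs are hypothetical, and the real loss additionally sums the horizontal and vertical terms over 2D ray grids.

```python
import math


def edge_aware_smoothness_1d(depth, intensity):
    """Mean of |d[i] - d[i+1]| * exp(-|I[i] - I[i+1]|) along one row."""
    terms = []
    for i in range(len(depth) - 1):
        grad_depth = abs(depth[i] - depth[i + 1])
        grad_img = abs(intensity[i] - intensity[i + 1])
        terms.append(grad_depth * math.exp(-grad_img))
    return sum(terms) / len(terms)


def mean_normalize(depth, eps=1e-6):
    """Divide by the mean so the penalty does not depend on absolute scale."""
    mean = sum(depth) / len(depth)
    return [d / (mean + eps) for d in depth]


depth = [2.0, 2.0, 6.0, 6.0]        # one depth jump in the middle
flat_image = [0.5, 0.5, 0.5, 0.5]   # no image edge at the jump
edge_image = [0.0, 0.0, 5.0, 5.0]   # strong image edge at the jump

norm = mean_normalize(depth)
penalty_flat = edge_aware_smoothness_1d(norm, flat_image)  # jump fully penalized
penalty_edge = edge_aware_smoothness_1d(norm, edge_image)  # jump scaled by exp(-5)
print(penalty_flat, penalty_edge)
```

A depth discontinuity that coincides with an image edge costs far less than the same discontinuity in a textureless region, which is the behavior the loss relies on.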
22 changes: 22 additions & 0 deletions loss/eikonal_loss.py
@@ -0,0 +1,22 @@
from .base_loss import BaseLoss
from . import OPENOCC_LOSS


@OPENOCC_LOSS.register_module()
class EikonalLoss(BaseLoss):

    def __init__(self, weight=1.0, input_dict=None, **kwargs):
        super().__init__(weight)

        if input_dict is None:
            self.input_dict = {
                'eik_grad': 'eik_grad',
            }
        else:
            self.input_dict = input_dict
        self.loss_func = self.eikonal_loss

    def eikonal_loss(self, eik_grad):
        grad_theta = eik_grad
        eikonal_loss = ((grad_theta.norm(2, dim=-1) - 1) ** 2).mean()
        return eikonal_loss
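
The eikonal term penalizes gradients whose L2 norm deviates from 1, the property a signed distance field satisfies almost everywhere. A worked example on plain tuples is sketched below, assuming (as the `dim=-1` norm in the source suggests) that the last axis of `eik_grad` holds the spatial components; `eikonal_penalty` and the sample vectors are hypothetical.

```python
import math


def eikonal_penalty(gradients):
    """Mean of (||g||_2 - 1)^2 over a list of gradient vectors."""
    total = 0.0
    for g in gradients:
        norm = math.sqrt(sum(c * c for c in g))
        total += (norm - 1.0) ** 2
    return total / len(gradients)


unit_grads = [(1.0, 0.0, 0.0), (0.0, 0.6, 0.8)]    # both have norm 1
scaled_grads = [(2.0, 0.0, 0.0), (0.0, 0.0, 0.0)]  # norms 2 and 0

print(eikonal_penalty(unit_grads))    # close to 0
print(eikonal_penalty(scaled_grads))  # ((2 - 1)^2 + (0 - 1)^2) / 2 = 1.0
```

Both over-steep and vanishing gradients are penalized equally here, since the penalty depends only on the distance of the norm from 1.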
44 changes: 44 additions & 0 deletions loss/multi_loss.py
@@ -0,0 +1,44 @@
import torch.nn as nn
from . import OPENOCC_LOSS
from utils.tb_wrapper import WrappedTBWriter
if 'selfocc' in WrappedTBWriter._instance_dict:
    writer = WrappedTBWriter.get_instance('selfocc')
else:
    writer = None

@OPENOCC_LOSS.register_module()
class MultiLoss(nn.Module):

    def __init__(self, loss_cfgs):
        super().__init__()

        assert isinstance(loss_cfgs, list)
        self.num_losses = len(loss_cfgs)

        losses = []
        for loss_cfg in loss_cfgs:
            losses.append(OPENOCC_LOSS.build(loss_cfg))
        self.losses = nn.ModuleList(losses)
        self.iter_counter = 0

    def forward(self, inputs):

        loss_dict = {}
        tot_loss = 0.
        for loss_func in self.losses:
            loss = loss_func(inputs)
            tot_loss += loss
            loss_dict.update({
                loss_func.__class__.__name__: \
                    loss.detach().item()
            })
            if writer and self.iter_counter % 10 == 0:
                writer.add_scalar(
                    f'loss/{loss_func.__class__.__name__}',
                    loss.detach().item(), self.iter_counter)
        if writer and self.iter_counter % 10 == 0:
            writer.add_scalar(
                'loss/total', tot_loss.detach().item(), self.iter_counter)
        self.iter_counter += 1

        return tot_loss, loss_dict
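
The aggregation and logging cadence of `MultiLoss.forward` can be checked without torch or TensorBoard. The sketch below is a simplified stand-in, not the repository's code: `ListWriter`, `MultiLossSketch`, `L1Term`, and `L2Term` are hypothetical, plain floats replace tensors, and the writer only records its calls in a list.

```python
class ListWriter:
    """Hypothetical stand-in for a TensorBoard-style writer; stores calls in a list."""

    def __init__(self):
        self.records = []

    def add_scalar(self, tag, value, step):
        self.records.append((tag, value, step))


class MultiLossSketch:
    """Sum several loss callables and log every `log_every` iterations."""

    def __init__(self, losses, writer=None, log_every=10):
        self.losses = losses
        self.writer = writer
        self.log_every = log_every
        self.iter_counter = 0

    def __call__(self, inputs):
        loss_dict = {}
        tot_loss = 0.0
        should_log = self.writer is not None and self.iter_counter % self.log_every == 0
        for loss_func in self.losses:
            loss = loss_func(inputs)
            tot_loss += loss
            name = loss_func.__class__.__name__
            loss_dict[name] = loss
            if should_log:
                self.writer.add_scalar(f"loss/{name}", loss, self.iter_counter)
        if should_log:
            self.writer.add_scalar("loss/total", tot_loss, self.iter_counter)
        self.iter_counter += 1
        return tot_loss, loss_dict


class L1Term:
    def __call__(self, inputs):
        return abs(inputs["pred"] - inputs["target"])


class L2Term:
    def __call__(self, inputs):
        return (inputs["pred"] - inputs["target"]) ** 2


writer = ListWriter()
criterion = MultiLossSketch([L1Term(), L2Term()], writer=writer)
for _ in range(11):
    total, parts = criterion({"pred": 3.0, "target": 1.0})
print(total, parts)
print(len(writer.records))  # scalars are written at iterations 0 and 10 only
```

As in the source, the per-loss dictionary is keyed by class name, so two configured losses of the same class would overwrite each other's entry in `loss_dict` even though both still contribute to the total.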