Skip to content

Commit

Permalink
fix seed for sample_points data_processor, update configs
Browse files Browse the repository at this point in the history
  • Loading branch information
sshaoshuai committed Jul 28, 2020
1 parent c4033be commit d2d32f6
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 19 deletions.
11 changes: 5 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,13 @@ All models are trained with 8 GTX 1080Ti GPUs and are available for download.

| | training time | Car | Pedestrian | Cyclist | download |
|---------------------------------------------|----------:|:-------:|:-------:|:-------:|:---------:|
| [PointPillar](tools/cfgs/kitti_models/pointpillar.yaml) |~95 mins| 77.28 | - | - | [model-18M](https://drive.google.com/file/d/1wMxWTpU1qUoY3DsCH31WJmvJxcjFXKlm/view?usp=sharing) |
| [SECOND](tools/cfgs/kitti_models/second.yaml) | ~2 hours | 78.62 | - | - | [model-20M](https://drive.google.com/file/d/1-01zsPOsqanZQqIIyy7FpNXStL3y4jdR/view?usp=sharing) |
| [SECOND-MultiHead](tools/cfgs/kitti_models/second_multihead.yaml) | - | - | - | - | ongoing |
| [PointPillar](tools/cfgs/kitti_models/pointpillar.yaml) |~95 mins| 77.28 | 52.29 | 62.68 | [model-18M](https://drive.google.com/file/d/1wMxWTpU1qUoY3DsCH31WJmvJxcjFXKlm/view?usp=sharing) |
| [SECOND](tools/cfgs/kitti_models/second.yaml) | ~2 hours | 78.62 | 52.98 | 67.15 | [model-20M](https://drive.google.com/file/d/1-01zsPOsqanZQqIIyy7FpNXStL3y4jdR/view?usp=sharing) |
| [PointRCNN](tools/cfgs/kitti_models/pointrcnn.yaml) | ~3 hours | 78.70 | - | - | ongoing|
| [PointRCNN-IoU](tools/cfgs/kitti_models/pointrcnn_iou.yaml) | ~3 hours | 78.70 | - | - | ongoing|
| [Part-A^2-Free](tools/cfgs/kitti_models/PartA2_free.yaml) | ~4 hours| 78.72 | 65.99 | 74.29 | [model-244M](https://drive.google.com/file/d/10GK1aCkLqxGNeX3lVu8cLZyE0G8002hY/view?usp=sharing) |
| [PointRCNN-IoU](tools/cfgs/kitti_models/pointrcnn_iou.yaml) | ~3 hours | 78.90 | 54.62 | 71.52 | [model-]()|
| [Part-A^2-Free](tools/cfgs/kitti_models/PartA2_free.yaml) | ~4 hours| 78.72 | 65.99 | 74.29 | [model-](https://drive.google.com/file/d/10GK1aCkLqxGNeX3lVu8cLZyE0G8002hY/view?usp=sharing) |
| [Part-A^2-Anchor](tools/cfgs/kitti_models/PartA2.yaml) | ~5 hours| 79.40 | - | - | [model-244M](https://drive.google.com/file/d/10GK1aCkLqxGNeX3lVu8cLZyE0G8002hY/view?usp=sharing) |
| [PV-RCNN](tools/cfgs/kitti_models/pv_rcnn.yaml) | ~6 hours| 83.69 | - | - | [model-50M](https://drive.google.com/file/d/1lIOq4Hxr0W3qsX83ilQv0nk1Cls6KAr-/view?usp=sharing) |
| [PV-RCNN](tools/cfgs/kitti_models/pv_rcnn.yaml) | ~6 hours| 83.69 | 54.49 | 69.47 | [model-50M](https://drive.google.com/file/d/1lIOq4Hxr0W3qsX83ilQv0nk1Cls6KAr-/view?usp=sharing) |

### NuScenes 3D Object Detection Baselines
All models are trained with 8 GTX 1080Ti GPUs and are available for download.
Expand Down
3 changes: 3 additions & 0 deletions pcdet/datasets/processor/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ def sample_points(self, data_dict=None, config=None):
if num_points == -1:
return data_dict

if config.FIX_SEED[self.mode]:
np.random.seed(512)

points = data_dict['points']
if num_points < len(points):
pts_depth = np.linalg.norm(points[:, 0:3], axis=1)
Expand Down
2 changes: 1 addition & 1 deletion pcdet/models/backbones_2d/base_bev_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(self, model_cfg, input_channels):
self.blocks.append(nn.Sequential(*cur_layers))
if len(upsample_strides) > 0:
stride = upsample_strides[idx]
if stride > 1:
if stride >= 1:
self.deblocks.append(nn.Sequential(
nn.ConvTranspose2d(
num_filters[idx], num_upsample_filters[idx],
Expand Down
4 changes: 4 additions & 0 deletions tools/cfgs/kitti_models/pointrcnn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ DATA_CONFIG:
'train': 16384,
'test': 16384
}
FIX_SEED: {
'train': False,
'test': True
}

- NAME: shuffle_points
SHUFFLE_ENABLED: {
Expand Down
6 changes: 3 additions & 3 deletions tools/cfgs/nuscenes_models/cbgs_pp_multihead.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -217,11 +217,11 @@ MODEL:
EVAL_METRIC: kitti

NMS_CONFIG:
MULTI_CLASSES_NMS: False
MULTI_CLASSES_NMS: True
NMS_TYPE: nms_gpu
NMS_THRESH: 0.2
NMS_PRE_MAXSIZE: 4096
NMS_POST_MAXSIZE: 100
NMS_PRE_MAXSIZE: 1000
NMS_POST_MAXSIZE: 83


OPTIMIZATION:
Expand Down
4 changes: 2 additions & 2 deletions tools/cfgs/nuscenes_models/cbgs_second_multihead.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,11 @@ MODEL:
EVAL_METRIC: kitti

NMS_CONFIG:
MULTI_CLASSES_NMS: False
MULTI_CLASSES_NMS: True
NMS_TYPE: nms_gpu
NMS_THRESH: 0.2
NMS_PRE_MAXSIZE: 1000
NMS_POST_MAXSIZE: 100
NMS_POST_MAXSIZE: 83


OPTIMIZATION:
Expand Down
12 changes: 5 additions & 7 deletions tools/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import datetime
import argparse
from pathlib import Path
import torch.distributed as dist
from pcdet.datasets import build_dataloader
from pcdet.models import build_network
from pcdet.utils import common_utils
Expand All @@ -19,12 +18,10 @@ def parse_config():
parser = argparse.ArgumentParser(description='arg parser')
parser.add_argument('--cfg_file', type=str, default=None, help='specify the config for training')

parser.add_argument('--batch_size', type=int, default=16, required=False, help='batch size for training')
parser.add_argument('--epochs', type=int, default=80, required=False, help='Number of epochs to train for')
parser.add_argument('--batch_size', type=int, default=None, required=False, help='batch size for training')
parser.add_argument('--workers', type=int, default=4, help='number of workers for dataloader')
parser.add_argument('--extra_tag', type=str, default='default', help='extra tag for this experiment')
parser.add_argument('--ckpt', type=str, default=None, help='checkpoint to start from')
parser.add_argument('--mgpus', action='store_true', default=False, help='whether to use multiple gpu')
parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none')
parser.add_argument('--tcp_port', type=int, default=18888, help='tcp port for distrbuted training')
parser.add_argument('--local_rank', type=int, default=0, help='local rank for distributed training')
Expand Down Expand Up @@ -133,11 +130,13 @@ def main():
if args.launcher == 'none':
dist_test = False
else:
args.batch_size, cfg.LOCAL_RANK = getattr(common_utils, 'init_dist_%s' % args.launcher)(
args.batch_size, args.tcp_port, args.local_rank, backend='nccl'
total_gpus, cfg.LOCAL_RANK = getattr(common_utils, 'init_dist_%s' % args.launcher)(
args.tcp_port, args.local_rank, backend='nccl'
)
dist_test = True

args.batch_size = cfg.OPTIMIZATION.BATCH_SIZE_PER_GPU if args.batch_size is None else args.batch_size

output_dir = cfg.ROOT_DIR / 'output' / cfg.EXP_GROUP_PATH / cfg.TAG / args.extra_tag
output_dir.mkdir(parents=True, exist_ok=True)

Expand All @@ -163,7 +162,6 @@ def main():
logger.info('CUDA_VISIBLE_DEVICES=%s' % gpu_list)

if dist_test:
total_gpus = dist.get_world_size()
logger.info('total_batch_size: %d' % (total_gpus * args.batch_size))
for key, val in vars(args).items():
logger.info('{:16} {}'.format(key, val))
Expand Down

0 comments on commit d2d32f6

Please sign in to comment.