forked from open-mmlab/OpenPCDet
-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add quick demo codes and introduction to test on custom data
- Loading branch information
1 parent
e0bc968
commit b1cefad
Showing
5 changed files
with
165 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
## Quick Demo | ||
|
||
Here we provide a quick demo to test a pretrained model on the custom point cloud data and visualize the predicted results. | ||
|
||
We suppose you already followed the [INSTALL.md](INSTALL.md) to install the `OpenPCDet` repo successfully. | ||
|
||
1. Download the provided pretrained models as shown in the [README.md](../README.md). | ||
|
||
2. Make sure you have already installed the `mayavi` visualization tools. If not, you could install it as follows: | ||
``` | ||
pip install mayavi | ||
``` | ||
|
||
3. Prepare you custom point cloud data (skip this step if you use the original KITTI data). | ||
* You need to transform the coordinate of your custom point cloud to | ||
the unified normative coordinate of `OpenPCDet`, that is, x-axis points towards to front direction, | ||
y-axis points towards to the left direction, and z-axis points towards to the top direction. | ||
* (Optional) the z-axis origin of your point cloud coordinate should be about 1.6m above the ground surface, | ||
since currently the provided models are trained on the KITTI dataset. | ||
* Note to set the intensity information, and save your transformed custom data to `numpy file`: | ||
```python | ||
# Transform your point cloud data | ||
... | ||
|
||
# Save it to the file. | ||
# The shape of points should be (num_points, 4), that is [x, y, z, intensity], | ||
# and the range of intensity should be within [0, 1]. | ||
# If you doesn't have the intensity information, just set points[:, 3] = 0, | ||
np.save(`my_data.npy`, points) | ||
``` | ||
|
||
4. Run the demo with a pretrained model (e.g. PV-RCNN) and your custom point cloud data as follows: | ||
```shell | ||
python demo.py --cfg_file cfgs/kitti_models/pv_rcnn.yaml \ | ||
--ckpt pv_rcnn_8369.pth \ | ||
--data_path ${POINT_CLOUD_DATA} | ||
``` | ||
Here `${POINT_CLOUD_DATA}` could be the following format: | ||
* Your transformed custom data with a single numpy file like `my_data.npy`. | ||
* Your transformed custom data with a directory to test with multiple point cloud data. | ||
* The original KITTI `.bin` data within `data/kitti`, like `data/kitti/training/velodyne/000008.bin`. | ||
|
||
Then you could see the predicted results with visualized point cloud as follows: | ||
|
||
<p align="center"> | ||
<img src="docs/demo.png" width="95%"> | ||
</p> |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
import torch | ||
import argparse | ||
import glob | ||
import numpy as np | ||
from pathlib import Path | ||
from pcdet.datasets import DatasetTemplate | ||
from pcdet.models import build_network, load_data_to_gpu | ||
from pcdet.config import cfg, cfg_from_yaml_file | ||
from pcdet.utils import common_utils | ||
from visual_utils import visualize_utils as V | ||
import mayavi.mlab as mlab | ||
|
||
|
||
class DemoDataset(DatasetTemplate): | ||
def __init__(self, dataset_cfg, class_names, training=True, root_path=None, logger=None, ext='.bin'): | ||
""" | ||
Args: | ||
root_path: | ||
dataset_cfg: | ||
class_names: | ||
training: | ||
logger: | ||
""" | ||
super().__init__( | ||
dataset_cfg=dataset_cfg, class_names=class_names, training=training, root_path=root_path, logger=logger | ||
) | ||
self.root_path = root_path | ||
self.ext = ext | ||
data_file_list = glob.glob(str(root_path / f'*{self.ext}')) if self.root_path.is_dir() else [self.root_path] | ||
|
||
data_file_list.sort() | ||
self.sample_file_list = data_file_list | ||
|
||
def __len__(self): | ||
return len(self.sample_file_list) | ||
|
||
def __getitem__(self, index): | ||
if self.ext == '.bin': | ||
points = np.fromfile(self.sample_file_list[index], dtype=np.float32).reshape(-1, 4) | ||
elif self.ext == '.npy': | ||
points = np.load(self.sample_file_list[index]) | ||
else: | ||
raise NotImplementedError | ||
|
||
input_dict = { | ||
'points': points, | ||
'frame_id': index, | ||
} | ||
|
||
data_dict = self.prepare_data(data_dict=input_dict) | ||
return data_dict | ||
|
||
|
||
def parse_config(): | ||
parser = argparse.ArgumentParser(description='arg parser') | ||
parser.add_argument('--cfg_file', type=str, default='cfgs/kitti_models/second.yaml', | ||
help='specify the config for demo') | ||
parser.add_argument('--data_path', type=str, default='demo_data', | ||
help='specify the point cloud data file or directory') | ||
parser.add_argument('--ckpt', type=str, default=None, help='specify the pretrained model') | ||
parser.add_argument('--ext', type=str, default='.bin', help='specify the extension of your point cloud data file') | ||
|
||
args = parser.parse_args() | ||
|
||
cfg_from_yaml_file(args.cfg_file, cfg) | ||
|
||
return args, cfg | ||
|
||
|
||
def main(): | ||
args, cfg = parse_config() | ||
logger = common_utils.create_logger() | ||
logger.info('-----------------Quick Demo of OpenPCDet-------------------------') | ||
demo_dataset = DemoDataset( | ||
dataset_cfg=cfg.DATA_CONFIG, class_names=cfg.CLASS_NAMES, training=False, | ||
root_path=Path(args.data_path), ext=args.ext, logger=logger | ||
) | ||
logger.info(f'Total number of samples: \t{len(demo_dataset)}') | ||
|
||
model = build_network(model_cfg=cfg.MODEL, num_class=len(cfg.CLASS_NAMES), dataset=demo_dataset) | ||
model.load_params_from_file(filename=args.ckpt, logger=logger, to_cpu=True) | ||
model.cuda() | ||
model.eval() | ||
with torch.no_grad(): | ||
for idx, data_dict in enumerate(demo_dataset): | ||
logger.info(f'Visualized sample index: \t{idx + 1}') | ||
data_dict = demo_dataset.collate_batch([data_dict]) | ||
load_data_to_gpu(data_dict) | ||
pred_dicts, _ = model.forward(data_dict) | ||
|
||
V.draw_scenes( | ||
points=data_dict['points'][:, 1:], ref_boxes=pred_dicts[0]['pred_boxes'], | ||
ref_scores=pred_dicts[0]['pred_scores'], ref_labels=pred_dicts[0]['pred_labels'] | ||
) | ||
mlab.show(stop=True) | ||
|
||
logger.info('Demo done.') | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
File renamed without changes.