Merge branch 'master' into dev_v0.6
sshaoshuai committed Aug 20, 2022
2 parents 5e21e42 + a41e331 commit 7e2d56b
Showing 4 changed files with 109 additions and 7 deletions.
1 change: 1 addition & 0 deletions docs/GETTING_STARTED.md
@@ -94,6 +94,7 @@ and you could refer to `data/waymo/waymo_processed_data_v0_5_0` to see how many
 ```python
 python -m pcdet.datasets.waymo.waymo_dataset --func create_waymo_infos \
     --cfg_file tools/cfgs/dataset_configs/waymo_dataset.yaml
+# Ignore 'CUDA_ERROR_NO_DEVICE' error as this process does not require GPU.
 ```

 Note that you do not need to install `waymo-open-dataset` if you have already processed the data before and do not need to evaluate with official Waymo Metrics.
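The `CUDA_ERROR_NO_DEVICE` message comes from TensorFlow (a dependency of `waymo-open-dataset`) probing for GPUs. If the warning is distracting, the sketch below shows the standard way to hide GPUs for a CPU-only run; setting `CUDA_VISIBLE_DEVICES` this way is a general CUDA convention, not something this commit adds, and it must happen before TensorFlow is imported:

```python
# Sketch (not part of this commit): hide all GPUs so the info-creation
# step runs purely on CPU and TensorFlow stops probing for devices.
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
```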
20 changes: 17 additions & 3 deletions tools/eval_utils/eval_utils.py
@@ -19,11 +19,11 @@ def statistics_info(cfg, ret_dict, metric, disp_dict):
         '(%d, %d) / %d' % (metric['recall_roi_%s' % str(min_thresh)], metric['recall_rcnn_%s' % str(min_thresh)], metric['gt_num'])


-def eval_one_epoch(cfg, model, dataloader, epoch_id, logger, dist_test=False, save_to_file=False, result_dir=None):
+def eval_one_epoch(cfg, args, model, dataloader, epoch_id, logger, dist_test=False, result_dir=None):
     result_dir.mkdir(parents=True, exist_ok=True)

     final_output_dir = result_dir / 'final_result' / 'data'
-    if save_to_file:
+    if args.save_to_file:
         final_output_dir.mkdir(parents=True, exist_ok=True)

     metric = {
@@ -37,6 +37,10 @@ def eval_one_epoch(cfg, model, dataloader, epoch_id, logger, dist_test=False, save_to_file=False, result_dir=None):
     class_names = dataset.class_names
     det_annos = []

+    if getattr(args, 'infer_time', False):
+        start_iter = int(len(dataloader) * 0.1)
+        infer_time_meter = common_utils.AverageMeter()
+
     logger.info('*************** EPOCH %s EVALUATION *****************' % epoch_id)
     if dist_test:
         num_gpus = torch.cuda.device_count()
@@ -53,14 +57,24 @@
     start_time = time.time()
     for i, batch_dict in enumerate(dataloader):
         load_data_to_gpu(batch_dict)
+
+        if getattr(args, 'infer_time', False):
+            start_time = time.time()
+
         with torch.no_grad():
             pred_dicts, ret_dict = model(batch_dict)
         disp_dict = {}

+        if getattr(args, 'infer_time', False):
+            inference_time = time.time() - start_time
+            infer_time_meter.update(inference_time * 1000)
+            # use ms to measure inference time
+            disp_dict['infer_time'] = f'{infer_time_meter.val:.2f}({infer_time_meter.avg:.2f})'
+
         statistics_info(cfg, ret_dict, metric, disp_dict)
         annos = dataset.generate_prediction_dicts(
             batch_dict, pred_dicts, class_names,
-            output_path=final_output_dir if save_to_file else None
+            output_path=final_output_dir if args.save_to_file else None
         )
         det_annos += annos
         if cfg.LOCAL_RANK == 0:
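The `val(avg)` readout above is produced by `common_utils.AverageMeter`, which this diff does not show. A rough sketch of the assumed behavior (the actual class lives in `pcdet/utils/common_utils.py` and may differ in detail):

```python
class AverageMeter:
    """Minimal running-average tracker, assumed to mirror common_utils.AverageMeter."""

    def __init__(self):
        self.val = 0.0   # most recent value, e.g. latency of the last batch in ms
        self.avg = 0.0   # running mean of all values seen so far
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
```

So `disp_dict['infer_time']` renders like `85.31(92.10)`: last-batch latency followed by the running mean, in milliseconds. The `start_iter` value computed before the loop appears intended to mark the first 10% of iterations as warm-up, since early batches pay one-off costs such as CUDA context setup.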
82 changes: 82 additions & 0 deletions tools/process_tools/create_integrated_databse.py
@@ -0,0 +1,82 @@
import numpy as np
import pickle as pkl
from pathlib import Path
import tqdm
import copy


def create_integrated_db_with_infos(args, root_path):
    """
    Concatenate all per-object point files of the GT database into one numpy file,
    recording each object's (start, end) row span in the db infos.

    Args:
        args: parsed arguments providing src_db_info, new_db_name and num_point_features
        root_path: root directory of the processed dataset
    Returns:
        db_info_global: db infos with an added 'global_data_offset' entry per object
        global_db: [N, num_point_features] float32 array of all object points
    """
    # prepare
    db_infos_path = root_path / args.src_db_info
    db_info_global_path = str(db_infos_path)[:-4] + '_global' + '.pkl'
    global_db_path = root_path / (args.new_db_name + '.npy')

    db_infos = pkl.load(open(db_infos_path, 'rb'))
    db_info_global = copy.deepcopy(db_infos)
    start_idx = 0
    global_db_list = []

    for category, class_info in db_infos.items():
        print('>>> Start processing %s' % category)
        for idx, info in tqdm.tqdm(enumerate(class_info), total=len(class_info)):
            obj_path = root_path / info['path']
            obj_points = np.fromfile(str(obj_path), dtype=np.float32).reshape(
                [-1, args.num_point_features])
            num_points = obj_points.shape[0]
            db_info_global[category][idx]['global_data_offset'] = (start_idx, start_idx + num_points)
            start_idx += num_points
            global_db_list.append(obj_points)

    global_db = np.concatenate(global_db_list)

    with open(global_db_path, 'wb') as f:
        np.save(f, global_db)

    with open(db_info_global_path, 'wb') as f:
        pkl.dump(db_info_global, f)

    print(f"Successfully created integrated database at {global_db_path}")
    print(f"Successfully created integrated database info at {db_info_global_path}")

    return db_info_global, global_db


def verify(info, whole_db, root_path, num_point_features):
    obj_path = root_path / info['path']
    obj_points = np.fromfile(str(obj_path), dtype=np.float32).reshape([-1, num_point_features])
    mean_origin = obj_points.mean()

    start_idx, end_idx = info['global_data_offset']
    obj_points_new = whole_db[start_idx:end_idx]
    mean_new = obj_points_new.mean()

    # The slice holds the exact same float32 values in the same order,
    # so the two means must match exactly, not merely approximately.
    assert mean_origin == mean_new

    print("Verification passed!")


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='arg parser')
    parser.add_argument('--root_path', type=str, default=None, help='specify the root path')
    parser.add_argument('--src_db_info', type=str,
                        default='waymo_processed_data_v0_5_0_waymo_dbinfos_train_sampled_1.pkl',
                        help='filename of the source db info pickle')
    parser.add_argument('--new_db_name', type=str,
                        default='waymo_processed_data_v0_5_0_gt_database_train_sampled_1_global',
                        help='filename (without extension) of the integrated database to create')
    parser.add_argument('--num_point_features', type=int, default=5,
                        help='number of feature channels for points')
    parser.add_argument('--class_name', type=str, default='Vehicle',
                        help='category name for verification')

    args = parser.parse_args()

    root_path = Path(args.root_path)

    db_infos_global, whole_db = create_integrated_db_with_infos(args, root_path)
    # simple verify
    verify(db_infos_global[args.class_name][0], whole_db, root_path, args.num_point_features)
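The point of the integrated database is that downstream code can slice one large (optionally memory-mapped) array instead of opening thousands of tiny per-object files. A read-back sketch using only what the script above writes; the file names follow its defaults, and `root_path` is assumed to be the current directory:

```python
import numpy as np
import pickle as pkl

# Open the integrated database; mmap_mode='r' avoids loading the whole file into RAM.
whole_db = np.load(
    'waymo_processed_data_v0_5_0_gt_database_train_sampled_1_global.npy', mmap_mode='r')
with open('waymo_processed_data_v0_5_0_waymo_dbinfos_train_sampled_1_global.pkl', 'rb') as f:
    db_infos = pkl.load(f)

# Recover the points of the first 'Vehicle' object from its recorded row span.
start_idx, end_idx = db_infos['Vehicle'][0]['global_data_offset']
obj_points = np.asarray(whole_db[start_idx:end_idx])  # [num_points, num_point_features]
```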
13 changes: 9 additions & 4 deletions tools/test.py
@@ -38,6 +38,7 @@ def parse_config():
     parser.add_argument('--eval_all', action='store_true', default=False, help='whether to evaluate all checkpoints')
     parser.add_argument('--ckpt_dir', type=str, default=None, help='specify a ckpt directory to be evaluated if needed')
     parser.add_argument('--save_to_file', action='store_true', default=False, help='')
+    parser.add_argument('--infer_time', action='store_true', default=False, help='calculate inference latency')

     args = parser.parse_args()

@@ -60,8 +61,8 @@ def eval_single_ckpt(model, test_loader, args, eval_output_dir, logger, epoch_id

     # start evaluation
     eval_utils.eval_one_epoch(
-        cfg, model, test_loader, epoch_id, logger, dist_test=dist_test,
-        result_dir=eval_output_dir, save_to_file=args.save_to_file
+        cfg, args, model, test_loader, epoch_id, logger, dist_test=dist_test,
+        result_dir=eval_output_dir
     )

@@ -118,8 +119,8 @@ def repeat_eval_ckpt(model, test_loader, args, eval_output_dir, logger, ckpt_dir
         # start evaluation
         cur_result_dir = eval_output_dir / ('epoch_%s' % cur_epoch_id) / cfg.DATA_CONFIG.DATA_SPLIT['test']
         tb_dict = eval_utils.eval_one_epoch(
-            cfg, model, test_loader, cur_epoch_id, logger, dist_test=dist_test,
-            result_dir=cur_result_dir, save_to_file=args.save_to_file
+            cfg, args, model, test_loader, cur_epoch_id, logger, dist_test=dist_test,
+            result_dir=cur_result_dir
         )

         if cfg.LOCAL_RANK == 0:
@@ -134,6 +135,10 @@ def repeat_eval_ckpt(model, test_loader, args, eval_output_dir, logger, ckpt_dir

 def main():
     args, cfg = parse_config()
+
+    if args.infer_time:
+        os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
+
     if args.launcher == 'none':
         dist_test = False
         total_gpus = 1
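Why `CUDA_LAUNCH_BLOCKING=1`: CUDA kernel launches are asynchronous, so a bare `time.time()` delta around `model(batch_dict)` would mostly measure launch overhead. Forcing blocking launches makes the wall-clock deltas in `eval_one_epoch` reflect real GPU work, at the cost of slower evaluation, which is why it is gated behind `--infer_time`. A common alternative, not used by this commit, is to synchronize explicitly around the timed region; a minimal sketch:

```python
import time
import torch

def timed_forward(model, batch_dict):
    torch.cuda.synchronize()   # drain pending GPU work before starting the clock
    start = time.time()
    with torch.no_grad():
        pred_dicts, ret_dict = model(batch_dict)
    torch.cuda.synchronize()   # wait until the forward pass has really finished
    return pred_dicts, ret_dict, (time.time() - start) * 1000.0  # latency in ms
```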
