Skip to content

Commit

Permalink
check dataset type by subclass instead of names
Browse files Browse the repository at this point in the history
  • Loading branch information
hellock committed Apr 29, 2019
1 parent 6fe5ccd commit ebc8312
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions mmdet/apis/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from mmcv.runner import Runner, DistSamplerSeedHook
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel

from mmdet import datasets
from mmdet.core import (DistOptimizerHook, DistEvalmAPHook,
CocoDistEvalRecallHook, CocoDistEvalmAPHook)
from mmdet.datasets import build_dataloader
Expand Down Expand Up @@ -80,14 +81,16 @@ def _dist_train(model, dataset, cfg, validate=False):
runner.register_hook(DistSamplerSeedHook())
# register eval hooks
if validate:
val_dataset_cfg = cfg.data.val
if isinstance(model.module, RPN):
# TODO: implement recall hooks for other datasets
runner.register_hook(CocoDistEvalRecallHook(cfg.data.val))
runner.register_hook(CocoDistEvalRecallHook(val_dataset_cfg))
else:
if cfg.data.val.type == 'CocoDataset':
runner.register_hook(CocoDistEvalmAPHook(cfg.data.val))
dataset_type = getattr(datasets, val_dataset_cfg.type)
if issubclass(dataset_type, datasets.CocoDataset):
runner.register_hook(CocoDistEvalmAPHook(val_dataset_cfg))
else:
runner.register_hook(DistEvalmAPHook(cfg.data.val))
runner.register_hook(DistEvalmAPHook(val_dataset_cfg))

if cfg.resume_from:
runner.resume(cfg.resume_from)
Expand Down

0 comments on commit ebc8312

Please sign in to comment.