Skip to content

Commit

Permalink
[Improve] Using train_step instead of forward in PreciseBNHook (o…
Browse files Browse the repository at this point in the history
…pen-mmlab#964)

* fix precise BN hook when using MLU

* fix unit tests
  • Loading branch information
Ezra-Yu authored Aug 11, 2022
1 parent b366897 commit e54cfd6
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
2 changes: 1 addition & 1 deletion mmcls/core/hook/precise_bn_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def update_bn_stats(model: nn.Module,
prog_bar = mmcv.ProgressBar(num_iter)

for data in itertools.islice(loader, num_iter):
model(**data)
model.train_step(data)
for i, bn in enumerate(bn_layers):
running_means[i] += bn.running_mean / num_iter
running_vars[i] += bn.running_var / num_iter
Expand Down
21 changes: 15 additions & 6 deletions tests/test_runtime/test_preciseBN_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from torch.utils.data import DataLoader, Dataset

from mmcls.core.hook import PreciseBNHook
from mmcls.models.classifiers import BaseClassifier


class ExampleDataset(Dataset):
Expand Down Expand Up @@ -41,7 +42,7 @@ def __len__(self):
return 12


class ExampleModel(nn.Module):
class ExampleModel(BaseClassifier):

def __init__(self):
super().__init__()
Expand All @@ -52,7 +53,17 @@ def __init__(self):
def forward(self, imgs, return_loss=False):
    """Test stub: pass `imgs` through conv then bn; `return_loss` is ignored."""
    return self.bn(self.conv(imgs))

def train_step(self, data_batch, optimizer, **kwargs):
def simple_test(self, img, img_metas=None, **kwargs):
    """Test stub satisfying the BaseClassifier API: always returns an empty dict."""
    return {}

def extract_feat(self, img, stage='neck'):
    """Test stub satisfying the BaseClassifier API: returns an empty tuple (no features)."""
    return ()

def forward_train(self, img, gt_label, **kwargs):
    """Test stub: returns a fixed loss dict regardless of input."""
    return {'loss': 0.5}

def train_step(self, data_batch, optimizer=None, **kwargs):
self.forward(**data_batch)
outputs = {
'loss': 0.5,
'log_vars': {
Expand Down Expand Up @@ -234,10 +245,8 @@ def test_precise_bn():
mean = np.mean([np.mean(batch) for batch in imgs_list])
# Bessel's correction used in PyTorch, therefore ddof=1
var = np.mean([np.var(batch, ddof=1) for batch in imgs_list])
assert np.equal(mean, np.array(
model.bn.running_mean)), (mean, np.array(model.bn.running_mean))
assert np.equal(var, np.array(
model.bn.running_var)), (var, np.array(model.bn.running_var))
assert np.equal(mean, model.bn.running_mean)
assert np.equal(var, model.bn.running_var)

@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
Expand Down

0 comments on commit e54cfd6

Please sign in to comment.