
Commit 258e62d

code cleanup
Parent: ea1ab3f

2 files changed, 12 insertions(+), 30 deletions(-)

deeplab.py (4 additions, 8 deletions)

@@ -4,7 +4,7 @@
 import torch.utils.model_zoo as model_zoo


-__all__ = ['ResNet', 'resnet50', 'resnet101', 'resnet152', 'pnasnet5', 'hdarts']
+__all__ = ['ResNet', 'resnet50', 'resnet101', 'resnet152']


 model_urls = {
@@ -115,13 +115,9 @@ def forward(self, x):

 class ResNet(nn.Module):

-  def __init__(self, block, layers, num_classes, num_groups=None, sync_bn=False, beta=False):
+  def __init__(self, block, layers, num_classes, num_groups=None, beta=False):
     self.inplanes = 64
-    if sync_bn:
-      from encoding.nn import BatchNorm2d
-    else:
-      from torch.nn import BatchNorm2d
-    self._norm = lambda planes, momentum=0.05: BatchNorm2d(planes, momentum=momentum) if num_groups is None else nn.GroupNorm(num_groups, planes)
+    self._norm = lambda planes, momentum=0.05: nn.BatchNorm2d(planes, momentum=momentum) if num_groups is None else nn.GroupNorm(num_groups, planes)

     super(ResNet, self).__init__()
     if not beta:
@@ -146,7 +142,7 @@ def __init__(self, block, layers, num_classes, num_groups=None, sync_bn=False, beta=False):
       if isinstance(m, nn.Conv2d):
         n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
         m.weight.data.normal_(0, math.sqrt(2. / n))
-      elif isinstance(m, BatchNorm2d) or isinstance(m, nn.GroupNorm):
+      elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.GroupNorm):
         m.weight.data.fill_(1)
         m.bias.data.zero_()
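For context on the deeplab.py change: the constructor now always builds its normalization layers from torch.nn, picking BatchNorm2d when num_groups is None and GroupNorm otherwise, instead of optionally importing a synchronized BatchNorm2d from the external encoding package. A minimal, self-contained sketch of that selection rule (the helper name and the toy shapes below are illustrative, not from the repo):

import torch
import torch.nn as nn

def make_norm(planes, num_groups=None, momentum=0.05):
    # Same selection rule as ResNet._norm: plain BatchNorm2d unless a
    # group count is given, in which case GroupNorm is used instead.
    if num_groups is None:
        return nn.BatchNorm2d(planes, momentum=momentum)
    return nn.GroupNorm(num_groups, planes)

x = torch.randn(2, 64, 8, 8)                   # toy feature map
print(make_norm(64)(x).shape)                  # BatchNorm2d path
print(make_norm(64, num_groups=32)(x).shape)   # GroupNorm path (64 % 32 == 0)

One consequence of dropping the encoding path: under nn.DataParallel, batch-norm statistics are now computed per GPU rather than synchronized across the whole batch.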

main.py (8 additions, 22 deletions)

@@ -40,8 +40,6 @@
                     help='train from scratch')
 parser.add_argument('--freeze_bn', action='store_true', default=False,
                     help='freeze batch normalization parameters')
-parser.add_argument('--sync_bn', action='store_true', default=False,
-                    help='sync batch normalization across gpu')
 parser.add_argument('--beta', action='store_true', default=False,
                     help='resnet101 beta')
 parser.add_argument('--crop_size', type=int, default=513,
@@ -70,39 +68,27 @@ def main():
         pretrained=(not args.scratch),
         num_classes=len(dataset.CLASSES),
         num_groups=args.groups,
-        sync_bn=args.sync_bn,
         beta=args.beta)
   else:
     raise ValueError('Unknown backbone: {}'.format(args.backbone))

   if args.train:
     criterion = nn.CrossEntropyLoss(ignore_index=255)
-    if args.sync_bn:
-      from encoding.parallel import DataParallelModel, DataParallelCriterion
-      criterion = DataParallelCriterion(criterion).cuda()
-      model = DataParallelModel(model).cuda()
-    else:
-      model = nn.DataParallel(model).cuda()
+    model = nn.DataParallel(model).cuda()
     model.train()
     if args.freeze_bn:
       for m in model.modules():
         if isinstance(m, nn.BatchNorm2d):
           m.eval()
           m.weight.requires_grad = False
           m.bias.requires_grad = False
-    if args.backbone == 'resnet101':
-      backbone_params = (
-          list(model.module.conv1.parameters()) +
-          list(model.module.bn1.parameters()) +
-          list(model.module.layer1.parameters()) +
-          list(model.module.layer2.parameters()) +
-          list(model.module.layer3.parameters()) +
-          list(model.module.layer4.parameters()))
-    else:
-      backbone_params = (
-          list(model.module.stem0.parameters()) +
-          list(model.module.stem1.parameters()) +
-          list(model.module.cells.parameters()))
+    backbone_params = (
+        list(model.module.conv1.parameters()) +
+        list(model.module.bn1.parameters()) +
+        list(model.module.layer1.parameters()) +
+        list(model.module.layer2.parameters()) +
+        list(model.module.layer3.parameters()) +
+        list(model.module.layer4.parameters()))
     last_params = list(model.module.aspp.parameters())
     optimizer = optim.SGD([
       {'params': filter(lambda p: p.requires_grad, backbone_params)},
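The hunk ends partway through the optimizer setup, but the pattern it relies on is plain PyTorch: backbone and ASPP parameters go into separate param groups, and filtering on requires_grad keeps the batch-norm weights frozen by --freeze_bn out of the update. A small sketch of that idea with a toy model (the module sizes and learning rates here are invented for illustration, not taken from the repo):

import torch.nn as nn
import torch.optim as optim

# Toy stand-in: model[0:2] plays the role of the backbone, model[2] the head.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Conv2d(8, 4, 1))
backbone_params = list(model[0].parameters()) + list(model[1].parameters())
last_params = list(model[2].parameters())

# Freeze batch-norm affine parameters, as --freeze_bn does.
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.eval()
        m.weight.requires_grad = False
        m.bias.requires_grad = False

# Two param groups; the filter drops the frozen BN weights from the backbone group.
optimizer = optim.SGD(
    [{'params': filter(lambda p: p.requires_grad, backbone_params)},
     {'params': last_params, 'lr': 10 * 2.5e-4}],   # illustrative head LR
    lr=2.5e-4, momentum=0.9, weight_decay=1e-4)

The separate group for last_params lets the head train at a higher learning rate than the pretrained backbone, which is the usual reason for splitting the parameters this way.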

0 comments