Skip to content

Commit e269817

Browse files
committed
Passing the right activation on quantization.
1 parent 72cecb1 commit e269817

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

torchvision/models/mobilenetv3.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def __init__(self, input_channels: int, squeeze_factor: int = 4):
2626
squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8)
2727
super().__init__(input_channels, squeeze_channels, scale_activation=nn.Hardsigmoid)
2828
self.relu = self.activation
29+
delattr(self, 'activation')
2930
warnings.warn(
3031
"This SqueezeExcitation class is deprecated and will be removed in future versions.", FutureWarning)
3132

torchvision/models/quantization/mobilenetv3.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
class QuantizableSqueezeExcitation(SElayer):
2121
def __init__(self, *args: Any, **kwargs: Any) -> None:
22+
kwargs["scale_activation"]=nn.Hardswish
2223
super().__init__(*args, **kwargs)
2324
self.skip_mul = nn.quantized.FloatFunctional()
2425

@@ -80,11 +81,12 @@ def _load_weights(
8081
model: QuantizableMobileNetV3,
8182
model_url: Optional[str],
8283
progress: bool,
84+
strict: bool
8385
) -> None:
8486
if model_url is None:
8587
raise ValueError("No checkpoint is available for {}".format(arch))
8688
state_dict = load_state_dict_from_url(model_url, progress=progress)
87-
model.load_state_dict(state_dict)
89+
model.load_state_dict(state_dict, strict=strict)
8890

8991

9092
def _mobilenet_v3_model(
@@ -108,13 +110,13 @@ def _mobilenet_v3_model(
108110
torch.quantization.prepare_qat(model, inplace=True)
109111

110112
if pretrained:
111-
_load_weights(arch, model, quant_model_urls.get(arch + '_' + backend, None), progress)
113+
_load_weights(arch, model, quant_model_urls.get(arch + '_' + backend, None), progress, False)
112114

113115
torch.quantization.convert(model, inplace=True)
114116
model.eval()
115117
else:
116118
if pretrained:
117-
_load_weights(arch, model, model_urls.get(arch, None), progress)
119+
_load_weights(arch, model, model_urls.get(arch, None), progress, True)
118120

119121
return model
120122

0 commit comments

Comments
 (0)