class QuantizableSqueezeExcitation(SElayer):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
+       kwargs["scale_activation"] = nn.Hardswish
        super().__init__(*args, **kwargs)
        self.skip_mul = nn.quantized.FloatFunctional()
@@ -80,11 +81,12 @@ def _load_weights(
    model: QuantizableMobileNetV3,
    model_url: Optional[str],
    progress: bool,
+   strict: bool
) -> None:
    if model_url is None:
        raise ValueError("No checkpoint is available for {}".format(arch))
    state_dict = load_state_dict_from_url(model_url, progress=progress)
-   model.load_state_dict(state_dict)
+   model.load_state_dict(state_dict, strict=strict)


def _mobilenet_v3_model(
@@ -108,13 +110,13 @@ def _mobilenet_v3_model(
        torch.quantization.prepare_qat(model, inplace=True)

        if pretrained:
-           _load_weights(arch, model, quant_model_urls.get(arch + '_' + backend, None), progress)
+           _load_weights(arch, model, quant_model_urls.get(arch + '_' + backend, None), progress, False)

        torch.quantization.convert(model, inplace=True)
        model.eval()
    else:
        if pretrained:
-           _load_weights(arch, model, model_urls.get(arch, None), progress)
+           _load_weights(arch, model, model_urls.get(arch, None), progress, True)

    return model
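For reference, a minimal standalone sketch of what the forwarded strict flag controls in torch.nn.Module.load_state_dict: with strict=True any missing or unexpected key raises a RuntimeError, while strict=False loads whatever matches and returns the mismatches instead, which is presumably why the QAT-prepared quantized path above switches to non-strict loading while the float path stays strict. The tiny nn.Sequential model and the partial checkpoint below are hypothetical and only illustrate the flag's behaviour; they are not part of this change.

# Illustrative only: behaviour of load_state_dict(..., strict=...)
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
# Hypothetical partial checkpoint: "0.bias" is deliberately missing.
checkpoint = {"0.weight": torch.zeros(4, 4)}

# Non-strict loading reports mismatches instead of raising.
result = model.load_state_dict(checkpoint, strict=False)
print(result.missing_keys)     # ['0.bias']
print(result.unexpected_keys)  # []

# Strict loading would raise a RuntimeError for the missing key:
# model.load_state_dict(checkpoint, strict=True)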