File tree Expand file tree Collapse file tree 1 file changed +12
-3
lines changed Expand file tree Collapse file tree 1 file changed +12
-3
lines changed Original file line number Diff line number Diff line change @@ -833,11 +833,12 @@ def apply(inputs):
833
833
return apply
834
834
835
835
836
- def Head (num_classes = 1000 , name = None ):
836
+ def Head (num_classes = 1000 , classifier_activation = None , name = None ):
837
837
"""Implementation of classification head of RegNet.
838
838
839
839
Args:
840
840
num_classes: number of classes for Dense layer
841
+ classifier_activation: activation function for the Dense layer
841
842
name: name prefix
842
843
843
844
Returns:
@@ -848,7 +849,11 @@ def Head(num_classes=1000, name=None):
848
849
849
850
def apply (x ):
850
851
x = layers .GlobalAveragePooling2D (name = name + "_head_gap" )(x )
851
- x = layers .Dense (num_classes , name = name + "head_dense" )(x )
852
+ x = layers .Dense (
853
+ num_classes ,
854
+ activation = classifier_activation ,
855
+ name = name + "head_dense" ,
856
+ )(x )
852
857
return x
853
858
854
859
return apply
@@ -977,8 +982,12 @@ def RegNet(
977
982
in_channels = out_channels
978
983
979
984
if include_top :
980
- x = Head (num_classes = classes )(x )
981
985
imagenet_utils .validate_activation (classifier_activation , weights )
986
+ x = Head (
987
+ num_classes = classes ,
988
+ classifier_activation = classifier_activation ,
989
+ name = model_name ,
990
+ )(x )
982
991
983
992
else :
984
993
if pooling == "avg" :
You can’t perform that action at this time.
0 commit comments