Skip to content

Commit 00d4889

Browse files
committed
Fix Head params to accept classifier_activation
1 parent bce4ac9 commit 00d4889

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

keras/applications/regnet.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -833,11 +833,12 @@ def apply(inputs):
833833
return apply
834834

835835

836-
def Head(num_classes=1000, name=None):
836+
def Head(num_classes=1000, classifier_activation=None, name=None):
837837
"""Implementation of classification head of RegNet.
838838
839839
Args:
840840
num_classes: number of classes for Dense layer
841+
classifier_activation: activation function for the Dense layer
841842
name: name prefix
842843
843844
Returns:
@@ -848,7 +849,11 @@ def Head(num_classes=1000, name=None):
848849

849850
def apply(x):
850851
x = layers.GlobalAveragePooling2D(name=name + "_head_gap")(x)
851-
x = layers.Dense(num_classes, name=name + "head_dense")(x)
852+
x = layers.Dense(
853+
num_classes,
854+
activation=classifier_activation,
855+
name=name + "head_dense",
856+
)(x)
852857
return x
853858

854859
return apply
@@ -977,8 +982,12 @@ def RegNet(
977982
in_channels = out_channels
978983

979984
if include_top:
980-
x = Head(num_classes=classes)(x)
981985
imagenet_utils.validate_activation(classifier_activation, weights)
986+
x = Head(
987+
num_classes=classes,
988+
classifier_activation=classifier_activation,
989+
name=model_name,
990+
)(x)
982991

983992
else:
984993
if pooling == "avg":

0 commit comments

Comments
 (0)