Skip to content

Commit

Permalink
set default as imagenet
Browse files Browse the repository at this point in the history
  • Loading branch information
leondgarse committed Aug 20, 2021
1 parent fac0230 commit 9fd144d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 25 deletions.
19 changes: 4 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
```py
pip install -U git+https://github.com/leondgarse/keras_efficientnet_v2
```
- **Define model and load pretrained weights** Parameter `pretrained` is added in value `[None, "imagenet", "imagenet21k", "imagenet21k-ft1k"]`, default is `imagenet21k-ft1k`.
- **Define model and load pretrained weights** Parameter `pretrained` is added in value `[None, "imagenet", "imagenet21k", "imagenet21k-ft1k"]`, default is `imagenet`.
```py
# Will download and load `imagenet` pretrained weights.
# Model weight is loaded with `by_name=True, skip_mismatch=True`.
Expand Down Expand Up @@ -142,24 +142,13 @@
from tensorflow import keras
from keras_efficientnet_v2 import progressive_train_test

num_classes = 10
ev2_s = keras_efficientnet_v2.EfficientNetV2("s", input_shape=(None, None, 3), num_classes=0)
out = ev2_s.output

nn = keras.layers.GlobalAveragePooling2D(name="avg_pool")(out)
nn = keras.layers.Dropout(0.1)(nn)
nn = keras.layers.Dense(num_classes, activation="softmax", name="predictions", dtype="float32")(nn)
model = keras.models.Model(ev2_s.inputs[0], nn)

lr_scheduler = None
optimizer = "adam"
loss = "categorical_crossentropy"
model.compile(loss=loss, optimizer=optimizer, metrics=["accuracy"])
model = keras_efficientnet_v2.EfficientNetV2S(input_shape=(None, None, 3), num_classes=10, classifier_activation='softmax', dropout=0.1)
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

hhs = progressive_train_test.progressive_with_dropout_randaug(
model,
data_name="cifar10",
lr_scheduler=lr_scheduler,
lr_scheduler=None,
total_epochs=36,
batch_size=64,
dropout_layer=-2,
Expand Down
20 changes: 10 additions & 10 deletions keras_efficientnet_v2/efficientnet_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def EfficientNetV2(
first_strides=2,
drop_connect_rate=0,
classifier_activation="softmax",
pretrained="imagenet21k-ft1k",
pretrained="imagenet",
model_name="EfficientNetV2",
kwargs=None, # Not used, just recieving parameter
):
Expand Down Expand Up @@ -253,7 +253,7 @@ def EfficientNetV2(
nn = GlobalAveragePooling2D(name="avg_pool")(nn)
if dropout > 0 and dropout < 1:
nn = Dropout(dropout)(nn)
nn = Dense(num_classes, activation=classifier_activation, name="predictions")(nn)
nn = Dense(num_classes, activation=classifier_activation, dtype="float32", name="predictions")(nn)

model = Model(inputs=inputs, outputs=nn, name=model_name)
reload_model_weights(model, model_type, pretrained)
Expand All @@ -279,39 +279,39 @@ def reload_model_weights(model, model_type, pretrained="imagenet"):
model.load_weights(pretrained_model, by_name=True, skip_mismatch=True)


def EfficientNetV2B0(input_shape=(224, 224, 3), num_classes=1000, dropout=0.2, classifier_activation="softmax", pretrained="imagenet21k-ft1k", **kwargs):
def EfficientNetV2B0(input_shape=(224, 224, 3), num_classes=1000, dropout=0.2, classifier_activation="softmax", pretrained="imagenet", **kwargs):
return EfficientNetV2(model_type="b0", model_name="EfficientNetV2B0", **locals(), **kwargs)


def EfficientNetV2B1(input_shape=(240, 240, 3), num_classes=1000, dropout=0.2, classifier_activation="softmax", pretrained="imagenet21k-ft1k", **kwargs):
def EfficientNetV2B1(input_shape=(240, 240, 3), num_classes=1000, dropout=0.2, classifier_activation="softmax", pretrained="imagenet", **kwargs):
return EfficientNetV2(model_type="b1", model_name="EfficientNetV2B1", **locals(), **kwargs)


def EfficientNetV2B2(input_shape=(260, 260, 3), num_classes=1000, dropout=0.3, classifier_activation="softmax", pretrained="imagenet21k-ft1k", **kwargs):
def EfficientNetV2B2(input_shape=(260, 260, 3), num_classes=1000, dropout=0.3, classifier_activation="softmax", pretrained="imagenet", **kwargs):
return EfficientNetV2(model_type="b2", model_name="EfficientNetV2B2", **locals(), **kwargs)


def EfficientNetV2B3(input_shape=(300, 300, 3), num_classes=1000, dropout=0.3, classifier_activation="softmax", pretrained="imagenet21k-ft1k", **kwargs):
def EfficientNetV2B3(input_shape=(300, 300, 3), num_classes=1000, dropout=0.3, classifier_activation="softmax", pretrained="imagenet", **kwargs):
return EfficientNetV2(model_type="b3", model_name="EfficientNetV2B3", **locals(), **kwargs)


def EfficientNetV2T(input_shape=(320, 320, 3), num_classes=1000, dropout=0.2, classifier_activation="softmax", pretrained="imagenet", **kwargs):
return EfficientNetV2(model_type="t", model_name="EfficientNetV2T", **locals(), **kwargs)


def EfficientNetV2S(input_shape=(384, 384, 3), num_classes=1000, dropout=0.2, classifier_activation="softmax", pretrained="imagenet21k-ft1k", **kwargs):
def EfficientNetV2S(input_shape=(384, 384, 3), num_classes=1000, dropout=0.2, classifier_activation="softmax", pretrained="imagenet", **kwargs):
return EfficientNetV2(model_type="s", model_name="EfficientNetV2S", **locals(), **kwargs)


def EfficientNetV2M(input_shape=(480, 480, 3), num_classes=1000, dropout=0.3, classifier_activation="softmax", pretrained="imagenet21k-ft1k", **kwargs):
def EfficientNetV2M(input_shape=(480, 480, 3), num_classes=1000, dropout=0.3, classifier_activation="softmax", pretrained="imagenet", **kwargs):
return EfficientNetV2(model_type="m", model_name="EfficientNetV2M", **locals(), **kwargs)


def EfficientNetV2L(input_shape=(480, 480, 3), num_classes=1000, dropout=0.4, classifier_activation="softmax", pretrained="imagenet21k-ft1k", **kwargs):
def EfficientNetV2L(input_shape=(480, 480, 3), num_classes=1000, dropout=0.4, classifier_activation="softmax", pretrained="imagenet", **kwargs):
return EfficientNetV2(model_type="l", model_name="EfficientNetV2L", **locals(), **kwargs)


def EfficientNetV2XL(input_shape=(512, 512, 3), num_classes=1000, dropout=0.4, classifier_activation="softmax", pretrained="imagenet21k-ft1k", **kwargs):
def EfficientNetV2XL(input_shape=(512, 512, 3), num_classes=1000, dropout=0.4, classifier_activation="softmax", pretrained="imagenet", **kwargs):
return EfficientNetV2(model_type="xl", model_name="EfficientNetV2XL", **locals(), **kwargs)


Expand Down

0 comments on commit 9fd144d

Please sign in to comment.