Skip to content

Commit

Permalink
Revert ml-decoder changes to model factory and train script
Browse files Browse the repository at this point in the history
  • Loading branch information
rwightman committed Mar 21, 2022
1 parent 72b5716 commit d98aa47
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 8 deletions.
5 changes: 0 additions & 5 deletions timm/models/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def create_model(
scriptable=None,
exportable=None,
no_jit=None,
use_ml_decoder_head=False,
**kwargs):
"""Create a model
Expand Down Expand Up @@ -81,10 +80,6 @@ def create_model(
with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
model = create_fn(pretrained=pretrained, **kwargs)

if use_ml_decoder_head:
from timm.models.layers.ml_decoder import add_ml_decoder_head
model = add_ml_decoder_head(model)

if checkpoint_path:
load_checkpoint(model, checkpoint_path)

Expand Down
4 changes: 1 addition & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@
help='input batch size for training (default: 128)')
parser.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N',
help='validation batch size override (default: None)')
parser.add_argument('--use-ml-decoder-head', type=int, default=0)

# Optimizer parameters
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
Expand Down Expand Up @@ -380,8 +379,7 @@ def main():
bn_momentum=args.bn_momentum,
bn_eps=args.bn_eps,
scriptable=args.torchscript,
checkpoint_path=args.initial_checkpoint,
use_ml_decoder_head=args.use_ml_decoder_head)
checkpoint_path=args.initial_checkpoint)
if args.num_classes is None:
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly
Expand Down

0 comments on commit d98aa47

Please sign in to comment.