Skip to content

Commit

Permalink
Convert to new programing API (apache#4307)
Browse files Browse the repository at this point in the history
  • Loading branch information
howard0su authored and piiswrong committed Dec 29, 2016
1 parent e18ff28 commit a9ec282
Showing 1 changed file with 27 additions and 18 deletions.
45 changes: 27 additions & 18 deletions example/image-classification/common/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ def add_fit_args(parser):
help='show progress for every n batches')
train.add_argument('--model-prefix', type=str,
help='model prefix')
parser.add_argument('--monitor', dest='monitor', type=int, default=0,
help='log network parameters every N iters if larger than 0')
train.add_argument('--load-epoch', type=int,
help='load the model on an epoch using the model-load-prefix')
train.add_argument('--top-k', type=int, default=0,
Expand Down Expand Up @@ -134,23 +136,23 @@ def fit(args, network, data_loader, **kwargs):
lr, lr_scheduler = _get_lr_scheduler(args, kv)

# create model
model = mx.model.FeedForward(
ctx = devs,
symbol = network,
begin_epoch = args.load_epoch if args.load_epoch else 0,
num_epoch = args.num_epochs,
arg_params = arg_params,
aux_params = aux_params,
learning_rate = lr,
lr_scheduler = lr_scheduler,
momentum = args.mom,
wd = args.wd,
optimizer = args.optimizer,
initializer = mx.init.Xavier(
rnd_type='gaussian', factor_type="in", magnitude=2)
# initializer = mx.init.Xavier(factor_type="in", magnitude=2.34),
model = mx.mod.Module(
context = devs,
symbol = network
)

lr_scheduler = lr_scheduler
optimizer_params = {
'learning_rate': lr,
'momentum' : args.mom,
'wd' : args.wd}

monitor = mx.mon.Monitor(args.monitor, pattern=".*") if args.monitor > 0 else None

initializer = mx.init.Xavier(
rnd_type='gaussian', factor_type="in", magnitude=2)
# initializer = mx.init.Xavier(factor_type="in", magnitude=2.34),

# evaluation metrices
eval_metrics = ['accuracy']
if args.top_k > 0:
Expand All @@ -163,10 +165,17 @@ def fit(args, network, data_loader, **kwargs):
batch_end_callbacks += cbs if isinstance(cbs, list) else [cbs]

# run
model.fit(
X = train,
model.fit(train,
begin_epoch = args.load_epoch if args.load_epoch else 0,
num_epoch = args.num_epochs,
eval_data = val,
eval_metric = eval_metrics,
kvstore = kv,
optimizer = args.optimizer,
optimizer_params = optimizer_params,
initializer = initializer,
arg_params = arg_params,
aux_params = aux_params,
batch_end_callback = batch_end_callbacks,
epoch_end_callback = checkpoint)
epoch_end_callback = checkpoint,
monitor = monitor)

0 comments on commit a9ec282

Please sign in to comment.