Skip to content

Commit

Permalink
Enable the reporting of cross-entropy or nll loss value when training…
Browse files Browse the repository at this point in the history
… CNN network using the models defined by example/image-classification (apache#9805)

* Enable the reporting of cross-entropy or nll loss value during training

* Set the default value of loss to '' to avoid a Python runtime issue when the loss argument is not set
  • Loading branch information
juliusshufan authored and szha committed Feb 22, 2018
1 parent 6d2b6e3 commit e1d8c66
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions example/image-classification/common/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ def add_fit_args(parser):
help='load the model on an epoch using the model-load-prefix')
train.add_argument('--top-k', type=int, default=0,
help='report the top-k accuracy. 0 means no report.')
train.add_argument('--loss', type=str, default='',
                       help='show the cross-entropy or nll loss. ce stands for cross-entropy, nll-loss stands for likelihood loss')
train.add_argument('--test-io', type=int, default=0,
help='1 means test reading speed without training')
train.add_argument('--dtype', type=str, default='float32',
Expand Down Expand Up @@ -260,6 +262,23 @@ def fit(args, network, data_loader, **kwargs):
eval_metrics.append(mx.metric.create(
'top_k_accuracy', top_k=args.top_k))

supported_loss = ['ce', 'nll_loss']
if len(args.loss) > 0:
# ce or nll loss is only applicable to softmax output
loss_type_list = args.loss.split(',')
if 'softmax_output' in network.list_outputs():
for loss_type in loss_type_list:
loss_type = loss_type.strip()
if loss_type == 'nll':
loss_type = 'nll_loss'
if loss_type not in supported_loss:
                    logging.warning(loss_type + ' is not a valid loss type, only cross-entropy or ' \
                                    'negative log-likelihood loss is supported!')
else:
eval_metrics.append(mx.metric.create(loss_type))
else:
logging.warning("The output is not softmax_output, loss argument will be skipped!")

# callbacks that run after each batch
batch_end_callbacks = [mx.callback.Speedometer(
args.batch_size, args.disp_batches)]
Expand Down

0 comments on commit e1d8c66

Please sign in to comment.