This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Backport of #16827, #16791 and #16888 to 1.6 branch #16901

Merged: 3 commits on Nov 26, 2019

Add evaluation_loss to the estimator base class. (#16888)
* Add evaluation_loss to the estimator base class.

* Update the base estimator class to support the separate evaluation loss.

* Add evaluation loss to the base estimator class.

* Add unittest for evaluation loss in the test_evaluation function

* Update estimator.py

* Update estimator.py
liuzh47 authored and ptrendx committed Nov 25, 2019
commit 1232c75207bca805546bbf81ae2aba3b33bf8908
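
For orientation, here is a minimal usage sketch of the behavior this commit adds. It is not part of the diff: the Dense(1) network and the L2/L1 loss pairing are illustrative assumptions, while the constructor arguments (net, loss, trainer, context, evaluation_loss) come from the code changes below. When evaluation_loss is omitted, evaluation reuses the training loss.

import mxnet as mx
from mxnet import gluon
from mxnet.gluon.contrib.estimator import Estimator

# Toy network with two different objectives: L2 for training, L1 for evaluation.
net = gluon.nn.Dense(1)
net.initialize(ctx=mx.cpu())
train_loss = gluon.loss.L2Loss()
eval_loss = gluon.loss.L1Loss()

trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
est = Estimator(net=net,
                loss=train_loss,            # drives parameter updates in fit()
                trainer=trainer,
                context=mx.cpu(),
                evaluation_loss=eval_loss)  # reported on validation data; falls back to `loss` if omitted
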
11 changes: 9 additions & 2 deletions python/mxnet/gluon/contrib/estimator/estimator.py
@@ -59,6 +59,9 @@ class Estimator(object):
         Trainer to apply optimizer on network parameters.
     context : Context or list of Context
         Device(s) to run the training on.
+    evaluation_loss : gluon.loss.Loss
+        Loss (objective) function to calculate during evaluation. If evaluation_loss
+        is None, it will use the same loss function as self.loss.
 
     """
 
@@ -85,12 +88,16 @@ def __init__(self, net,
                  metrics=None,
                  initializer=None,
                  trainer=None,
-                 context=None):
+                 context=None,
+                 evaluation_loss=None):
         self.net = net
         self.loss = self._check_loss(loss)
         self._train_metrics = _check_metrics(metrics)
         self._add_default_training_metrics()
         self._add_validation_metrics()
+        self.evaluation_loss = self.loss
+        if evaluation_loss is not None:
+            self.evaluation_loss = self._check_loss(evaluation_loss)
 
         self.logger = logging.Logger(name='Estimator', level=logging.INFO)
         self.logger.addHandler(logging.StreamHandler(sys.stdout))
@@ -228,7 +235,7 @@ def evaluate_batch(self,
         """
         data, label = self._get_data_and_label(val_batch, self.context, batch_axis)
         pred = [self.net(x) for x in data]
-        loss = [self.loss(y_hat, y) for y_hat, y in zip(pred, label)]
+        loss = [self.evaluation_loss(y_hat, y) for y_hat, y in zip(pred, label)]
         # update metrics
         for metric in val_metrics:
             if isinstance(metric, metric_loss):
4 changes: 3 additions & 1 deletion tests/python/unittest/test_gluon_estimator.py
@@ -83,13 +83,15 @@ def test_validation():
     ctx = mx.cpu()
     loss = gluon.loss.L2Loss()
     acc = mx.metric.Accuracy()
+    evaluation_loss = gluon.loss.L1Loss()
     net.initialize(ctx=ctx)
     trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
     est = Estimator(net=net,
                     loss=loss,
                     metrics=acc,
                     trainer=trainer,
-                    context=ctx)
+                    context=ctx,
+                    evaluation_loss=evaluation_loss)
     # Input dataloader
     est.fit(train_data=dataloader,
             val_data=dataloader,