Commit: Adding evaluator
eldakms committed Mar 28, 2017
1 parent 0ee09cf commit 6b547e1
Showing 6 changed files with 142 additions and 0 deletions.
1 change: 1 addition & 0 deletions bindings/python/cntk/__init__.py
@@ -16,6 +16,7 @@
from .ops import *
from .device import *
from .train import *
from .eval import *
from .learners import *
from .losses import *
from .metrics import *
7 changes: 7 additions & 0 deletions bindings/python/cntk/eval/__init__.py
@@ -0,0 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.

# Licensed under the MIT license. See LICENSE.md file in the project root
# for full license information.
# ==============================================================================

from .evaluator import *
91 changes: 91 additions & 0 deletions bindings/python/cntk/eval/evaluator.py
@@ -0,0 +1,91 @@
# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license. See LICENSE.md file in the project root
# for full license information.
# ==============================================================================

from .. import cntk_py
from ..device import use_default_device
from cntk.internal import sanitize_var_map, sanitize_function, typemap
from ..io import MinibatchData

__doc__= '''\
An evaluator provides functionality to evaluate minibatches against the specified evaluation function.
'''

class Evaluator(cntk_py.Evaluator):
    '''
    Class for evaluation of minibatches against the specified evaluation function.

    Args:
        eval_function (:class:`~cntk.ops.functions.Function`): evaluation function.
        progress_writers (list): optionally, a list of progress writers from :mod:`cntk.utils` to track
          training progress.
    '''

    def __init__(self, eval_function, progress_writers=None):
        if eval_function is not None:
            eval_function = sanitize_function(eval_function)

        if progress_writers is None:
            progress_writers = []
        elif not isinstance(progress_writers, list):
            progress_writers = [progress_writers]

        evaluator = cntk_py.create_evaluator(eval_function, progress_writers)
        # transplant into this class instance
        self.__dict__ = evaluator.__dict__

    def test_minibatch(self, arguments, device=None):
        '''
        Test the evaluation function on the specified batch of samples.

        Args:
            arguments: maps variables to their input data. The
              interpretation depends on the input type:

              * `dict`: keys are input variables or their names, and values are the input data.
                See :meth:`~cntk.ops.functions.Function.forward` for details on passing input data.
              * any other type: if the node has a unique input, ``arguments`` is mapped to this input.
                For nodes with more than one input, only `dict` is allowed.

              In both cases, every sample in the data will be interpreted
              as a new sequence. To mark samples as continuations of the
              previous sequence, specify ``arguments`` as a `tuple`: the
              first element will be used as ``arguments``, and the second one will
              be used as a list of bools, denoting whether a sequence is a new
              one (`True`) or a continuation of the previous one (`False`).
              Data should be either NumPy arrays or a
              :class:`~cntk.io.MinibatchData` instance.
            device (:class:`~cntk.device.DeviceDescriptor`): the device descriptor that
              contains the type and id of the device on which the computation is
              to be performed.

        Note:
            See :meth:`~cntk.ops.functions.Function.forward` for examples on
            passing input data.

        Returns:
            `float`: the average evaluation criterion value per sample for the
            tested minibatch.
        '''
        if not device:
            device = use_default_device()

        arguments = sanitize_var_map(tuple(self.evaluation_function.arguments), arguments)
        return super(Evaluator, self).test_minibatch(arguments, device)

    @property
    @typemap
    def evaluation_function(self):
        '''
        The evaluation function that the evaluator is using.
        '''
        return super(Evaluator, self).evaluation_function()

    def summarize_test_progress(self):
        '''
        Updates the progress writers with the summary of test progress since start and resets the internal
        accumulators.
        '''
        return super(Evaluator, self).summarize_test_progress()
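
For reference, a minimal usage sketch of the new Evaluator class, assembled from the docstrings above and the unit test added later in this commit; the two-dimensional model and the hand-written minibatch are purely illustrative, not part of the commit itself.

# Editor's sketch: evaluate a single minibatch with the Evaluator added in this change.
from cntk import input, parameter, times
from cntk.metrics import classification_error
from cntk.eval import Evaluator

# Tiny illustrative model: identity projection plus a bias.
x = input(shape=(2,))
labels = input(shape=(2,))
W = parameter(shape=(2, 2), init=[[1, 0], [0, 1]])
B = parameter(shape=(2,), init=[[0, 1]])
z = times(x, W) + B

evaluator = Evaluator(classification_error(z, labels))

# test_minibatch returns the average criterion value per sample for this batch.
error = evaluator.test_minibatch({x: [[0, 1], [2, 2]], labels: [[0, 1], [1, 0]]})

# Push the accumulated statistics to any attached progress writers and reset them.
evaluator.summarize_test_progress()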
5 changes: 5 additions & 0 deletions bindings/python/cntk/eval/tests/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) Microsoft. All rights reserved.

# Licensed under the MIT license. See LICENSE.md file in the project root
# for full license information.
# ==============================================================================
37 changes: 37 additions & 0 deletions bindings/python/cntk/eval/tests/evaluator_test.py
@@ -0,0 +1,37 @@
# Copyright (c) Microsoft. All rights reserved.

# Licensed under the MIT license. See LICENSE.md file in the project root
# for full license information.
# ==============================================================================

import os
import math
import numpy as np
from cntk import *
from cntk.ops.tests.ops_test_utils import cntk_device
from ..evaluator import *
from cntk.metrics import classification_error
from cntk import parameter, input, times, plus, reduce_sum, Axis, cntk_py
import pytest

def test_eval():
    input_dim = 2
    proj_dim = 2

    x = input(shape=(input_dim,))
    W = parameter(shape=(input_dim, proj_dim), init=[[1, 0], [0, 1]])
    B = parameter(shape=(proj_dim,), init=[[0, 1]])
    t = times(x, W)
    z = t + B

    labels = input(shape=(proj_dim,))
    pe = classification_error(z, labels)

    tester = Evaluator(pe)

    x_value = [[0, 1], [2, 2]]
    label_value = [[0, 1], [1, 0]]
    arguments = {x: x_value, labels: label_value}
    eval_error = tester.test_minibatch(arguments)

    assert np.allclose(eval_error, .5)
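
The test covers only the bare Evaluator. Below is a further sketch of how the progress_writers argument and summarize_test_progress are meant to work together, reusing pe, x, and labels from test_eval above; it assumes ProgressPrinter is importable from cntk.utils as the class docstring indicates (it lives in cntk.logging in later releases) and a hypothetical iterable `batches` of minibatch pairs.

# Editor's sketch: attach a progress writer and report aggregate test progress.
# `pe`, `x`, and `labels` come from test_eval above; `batches` is a hypothetical
# iterable of (features, labels) minibatch pairs.
from cntk.utils import ProgressPrinter  # cntk.logging.ProgressPrinter in later releases

writer = ProgressPrinter()
tester = Evaluator(pe, progress_writers=writer)  # a single writer is wrapped in a list internally

for x_batch, y_batch in batches:
    tester.test_minibatch({x: x_batch, labels: y_batch})

# Writes the accumulated evaluation summary to the writer and resets the accumulators.
tester.summarize_test_progress()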
1 change: 1 addition & 0 deletions bindings/python/doc/apireference.rst
@@ -14,4 +14,5 @@ Python API Reference
Metrics <cntk.metrics.rst>
Ops <cntk.ops.rst>
Train <cntk.train>
Eval <cntk.eval>
Module reference <modules>
