From 67613fa13f8211a4d80602f16c9d317895a03dea Mon Sep 17 00:00:00 2001
From: Chiyuan Zhang
Date: Sat, 15 Oct 2016 17:34:37 -0400
Subject: [PATCH] Doc (#3513)

* add infrastructure for symbol doctest
* add regression test demo for FullyConnected
* move utils to test_utils.py
* fix lint error
---
 python/mxnet/__init__.py        |   1 +
 python/mxnet/ndarray.py         |   5 +
 python/mxnet/symbol.py          |   5 +
 python/mxnet/symbol_doc.py      | 172 +++++++++++++++++++++-----------
 python/mxnet/test_utils.py      |  67 ++++++++++++-
 src/operator/activation.cc      |   4 +-
 src/operator/fully_connected.cc |   5 +-
 tests/python/doctest/run.py     |  39 ++++++++
 8 files changed, 233 insertions(+), 65 deletions(-)
 create mode 100644 tests/python/doctest/run.py

diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py
index a2c9f7e02cea..85cec9a3d557 100644
--- a/python/mxnet/__init__.py
+++ b/python/mxnet/__init__.py
@@ -11,6 +11,7 @@
 # use mx.sym as short for symbol
 from . import symbol as sym
 from . import symbol
+from . import symbol_doc
 from . import io
 from . import recordio
 from . import operator
diff --git a/python/mxnet/ndarray.py b/python/mxnet/ndarray.py
index 85a0cd4f5283..7c0694a6f9bc 100644
--- a/python/mxnet/ndarray.py
+++ b/python/mxnet/ndarray.py
@@ -94,6 +94,11 @@ def __init__(self, handle, writable=True):
         self.handle = handle
         self.writable = writable
 
+    def __repr__(self):
+        shape_info = 'x'.join(['%d' % x for x in self.shape])
+        return '<%s %s @%s>' % (self.__class__.__name__,
+                                shape_info, self.context)
+
     def __del__(self):
         check_call(_LIB.MXNDArrayFree(self.handle))
 
diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py
index e244ecb59cfd..6785972e760b 100644
--- a/python/mxnet/symbol.py
+++ b/python/mxnet/symbol.py
@@ -35,6 +35,11 @@ def __init__(self, handle):
         """
         self.handle = handle
 
+    def __repr__(self):
+        """Get a string representation of the symbol."""
+        return '<%s %s>' % (self.__class__.__name__,
+                            self.name)
+
     def __add__(self, other):
         if isinstance(other, Symbol):
             return _internal._Plus(self, other)
diff --git a/python/mxnet/symbol_doc.py b/python/mxnet/symbol_doc.py
index 422098b2e460..353196fe2ac3 100644
--- a/python/mxnet/symbol_doc.py
+++ b/python/mxnet/symbol_doc.py
@@ -1,94 +1,144 @@
 # coding: utf-8
-"""Extra symbol documents"""
+"""Extra symbol documents
+
+Guidelines
+----------
+
+To add extra doc to the operator `XXX`, write a class `XXXDoc`, deriving
+from the base class `SymbolDoc`, and put the extra doc as the docstring
+of `XXXDoc`.
+
+The documents added here should be Python-specific. Documents that are useful
+for all language bindings should be added to the C++ side where the operator
+is defined / registered.
+
+The code snippets in the docstrings will be run using `doctest`. During
+running, the environment will have access to
+
+- all the global names in this file (e.g. `SymbolDoc`)
+- all the operators (e.g. `FullyConnected`)
+- the name `test_utils` for `mxnet.test_utils` (e.g. `test_utils.reldiff`)
+- the name `mxnet` (e.g. `mxnet.nd.zeros`)
+
+The following documents are recommended:
+
+- *Examples*: a simple and short code snippet showing how to use this
+  operator. It should show typical calling examples and behaviors (e.g. what
+  input shape maps to what output shape).
+- *Regression Test*: longer test code for the operators. We normally do not
+  expect users to read these, but they will be executed by `doctest` to
+  ensure the behavior of each operator does not change unintentionally.
+"""
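The doctest outputs in the docstrings below rely on the two `__repr__` methods added above. A quick sketch of what they produce (a hypothetical interactive session, assuming this patch is applied and `import mxnet as mx`):

    >>> mx.nd.zeros((2, 3))      # NDArray.__repr__: class name, shape, context
    <NDArray 2x3 @cpu(0)>
    >>> mx.sym.Variable('data')  # Symbol.__repr__: class name and symbol name
    <Symbol data>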
+""" class SymbolDoc(object): - """The basic class""" + """The base class for attaching doc to operators.""" + + @staticmethod + def get_output_shape(sym, **input_shapes): + """Get user friendly information of the output shapes.""" + _, s_outputs, _ = sym.infer_shape(**input_shapes) + return dict(zip(sym.list_outputs(), s_outputs)) + + +class FullyConnectedDoc(SymbolDoc): + """ + Examples + -------- + Construct a fully connected operator with target dimension 512. + + >>> data = Variable('data') # or some constructed NN + >>> op = FullyConnected(data=data, + ... num_hidden=512, + ... name='FC1') + >>> op + + >>> SymbolDoc.get_output_shape(op, data=(128, 100)) + {'FC1_output': (128L, 512L)} + + A simple 3-layer MLP with ReLU activation: + + >>> net = Variable('data') + >>> for i, dim in enumerate([128, 64]): + ... net = FullyConnected(data=net, num_hidden=dim, name='FC%d' % i) + ... net = Activation(data=net, act_type='relu', name='ReLU%d' % i) + >>> # 10-class predictor (e.g. MNIST) + >>> net = FullyConnected(data=net, num_hidden=10, name='pred') + >>> net + + + Regression Test + --------------- + >>> dim_in, dim_out = (3, 4) + >>> x, w, b = test_utils.random_arrays((10, dim_in), (dim_out, dim_in), (dim_out,)) + >>> op = FullyConnected(num_hidden=dim_out, name='FC') + >>> out = test_utils.simple_forward(op, FC_data=x, FC_weight=w, FC_bias=b) + >>> # numpy implementation of FullyConnected + >>> out_np = numpy.dot(x, w.T) + b + >>> test_utils.almost_equal(out, out_np) + True + """ pass + class ConcatDoc(SymbolDoc): """ Examples -------- - >>> import mxnet as mx - >>> data = mx.nd.array(range(6)).reshape((2,1,3)) - >>> print "input shape = %s" % data.shape - >>> print "data = %s" % (data.asnumpy(), ) - input shape = (2L, 1L, 3L) - data = [[[ 0. 1. 2.]] - [[ 3. 4. 5.]]] - - >>> # concat two variables on different dimensions - >>> a = mx.sym.Variable('a') - >>> b = mx.sym.Variable('b') - >>> for dim in range(3): - ... cat = mx.sym.Concat(a, b, dim=dim) - ... exe = cat.bind(ctx=mx.cpu(), args={'a':data, 'b':data}) - ... exe.forward() - ... out = exe.outputs[0] - ... print "concat at dim = %d" % dim - ... print "shape = %s" % (out.shape, ) - ... print "results = %s" % (out.asnumpy(), ) - concat at dim = 0 - shape = (4L, 1L, 3L) - results = [[[ 0. 1. 2.]] - [[ 3. 4. 5.]] - [[ 0. 1. 2.]] - [[ 3. 4. 5.]]] - concat at dim = 1 - shape = (2L, 2L, 3L) - results = [[[ 0. 1. 2.] - [ 0. 1. 2.]] - [[ 3. 4. 5.] - [ 3. 4. 5.]]] - concat at dim = 2 - shape = (2L, 1L, 6L) - results = [[[ 0. 1. 2. 0. 1. 2.]] - [[ 3. 4. 5. 3. 4. 5.]]] + Concat two (or more) inputs along a specific dimension: + + >>> a = Variable('a') + >>> b = Variable('b') + >>> c = Concat(a, b, dim=1, name='my-concat') + >>> c + + >>> SymbolDoc.get_output_shape(c, a=(128, 10, 3, 3), b=(128, 15, 3, 3)) + {'my-concat_output': (128L, 25L, 3L, 3L)} + + Note the shape should be the same except on the dimension that is being + concatenated. 
""" pass -class BroadcastPlusDoc(SymbolDoc): - """add with broadcast +class BroadcastPlusDoc(SymbolDoc): + """ Examples -------- - >>> a = mx.sym.Variable('a') - >>> b = mx.sym.Variable('b') - >>> c = mx.sym.BroadcastPlus(a, b) - >>> dev = mx.cpu(); - >>> x = c.bind(dev, args={'a': mx.nd.ones((2,2)), 'b' : mx.nd.ones((2,2))}) + >>> a = Variable('a') + >>> b = Variable('b') + >>> c = broadcast_plus(a, b) + + Normal summation with matching shapes: + + >>> dev = mxnet.context.cpu(); + >>> x = c.bind(dev, args={'a': mxnet.nd.ones((2, 2)), 'b' : mxnet.nd.ones((2, 2))}) >>> x.forward() + [] >>> print x.outputs[0].asnumpy() [[ 2. 2.] [ 2. 2.]] - >>> x = c.bind(dev, args={'a': mx.nd.ones((2,2)), 'b' : mx.nd.ones((1,1))}) + + Broadcasting: + + >>> x = c.bind(dev, args={'a': mxnet.nd.ones((2, 2)), 'b' : mxnet.nd.ones((1, 1))}) >>> x.forward() + [] >>> print x.outputs[0].asnumpy() [[ 2. 2.] [ 2. 2.]] - >>> x = c.bind(dev, args={'a': mx.nd.ones((2,1)), 'b' : mx.nd.ones((1,2))}) + + >>> x = c.bind(dev, args={'a': mxnet.nd.ones((2, 1)), 'b' : mxnet.nd.ones((1, 2))}) >>> x.forward() + [] >>> print x.outputs[0].asnumpy() [[ 2. 2.] [ 2. 2.]] - >>> x = c.bind(dev, args={'a': mx.nd.ones((1,2)), 'b' : mx.nd.ones((2,1))}) + + >>> x = c.bind(dev, args={'a': mxnet.nd.ones((1, 2)), 'b' : mxnet.nd.ones((2, 1))}) >>> x.forward() + [] >>> print x.outputs[0].asnumpy() [[ 2. 2.] [ 2. 2.]] - >>> x = c.bind(dev, args={'a': mx.nd.ones((2,2,2)), 'b' : mx.nd.ones((1,2,1))} - >>> x.forward() - >>> print x.outputs[0].asnumpy() - [[[ 2. 2.] - [ 2. 2.]] - [[ 2. 2.] - [ 2. 2.]]] - >>> x = c.bind(dev, args={'a': mx.nd.ones((2,1,1)), 'b' : mx.nd.ones((2,2,2))}) - >>> x.forward() - >>> print x.outputs[0].asnumpy() - [[[ 2. 2.] - [ 2. 2.]] - [[ 2. 2.] - [ 2. 2.]]] """ diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 4ce1521ae410..96cb3bf58ba3 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -1,13 +1,42 @@ # coding: utf-8 """Tools for testing.""" -# pylint: disable=invalid-name, no-member, too-many-arguments, too-many-locals, too-many-branches, too-many-statements, broad-except, line-too-long +# pylint: disable=invalid-name, no-member, too-many-arguments, too-many-locals, too-many-branches, too-many-statements, broad-except, line-too-long, unused-import from __future__ import absolute_import, print_function, division import time import numpy as np import numpy.testing as npt import mxnet as mx + +from .context import cpu, gpu +from .ndarray import array + _rng = np.random.RandomState(1234) +def default_context(): + """Get default context for regression test.""" + # _TODO: get context from environment variable to support + # testing with GPUs + return cpu() + + +def default_dtype(): + """Get default data type for regression test.""" + # _TODO: get default dtype from environment variable + return np.float32 + + +def default_numerical_threshold(): + """Get default numerical threshold for regression test.""" + # _TODO: get from env variable, different threshold might + # be needed for different device and dtype + return 1e-6 + + +def random_arrays(*shapes): + """Generate some random numpy arrays.""" + return [np.random.randn(*s).astype(default_dtype()) + for s in shapes] + def np_reduce(dat, axis, keepdims, numpy_reduce_func): """Compatible reduce for old version numpy @@ -70,6 +99,41 @@ def reldiff(a, b): return ret +def almost_equal(a, b, threshold=None): + """Test if two numpy arrays are almost equal.""" + threshold = threshold or default_numerical_threshold() + return reldiff(a, 
b) <= threshold + + +def simple_forward(sym, ctx=None, **inputs): + """A simple forward function for a symbol. + + Primarily used in doctest to conveniently test the function + of a symbol. Takes numpy array as inputs and outputs are + also converted to numpy arrays. + + Parameters + ---------- + ctx : Context + If None, will take the default context. + inputs : keyword arguments + Mapping each input name to a numpy array. + + Returns + ------- + The result as a numpy array. Multiple results will + be returned as a list of numpy arrays. + """ + ctx = ctx or default_context() + inputs = {k: array(v) for k, v in inputs.iteritems()} + exe = sym.bind(ctx, args=inputs) + exe.forward() + outputs = [x.asnumpy() for x in exe.outputs] + if len(outputs) == 1: + outputs = outputs[0] + return outputs + + def _parse_location(sym, location, ctx): """Parse the given location to a dictionary @@ -215,6 +279,7 @@ def random_projection(shape): # otherwise too much precision is lost in numerical gradient plain = _rng.rand(*shape) + 0.1 return plain + location = _parse_location(sym=sym, location=location, ctx=ctx) location_npy = {k:v.asnumpy() for k, v in location.items()} aux_states = _parse_aux_states(sym=sym, aux_states=aux_states, ctx=ctx) diff --git a/src/operator/activation.cc b/src/operator/activation.cc index 1c0ce8555ebe..17b22565f477 100644 --- a/src/operator/activation.cc +++ b/src/operator/activation.cc @@ -46,8 +46,8 @@ Operator *ActivationProp::CreateOperatorEx(Context ctx, std::vector *in_ DMLC_REGISTER_PARAMETER(ActivationParam); MXNET_REGISTER_OP_PROPERTY(Activation, ActivationProp) -.describe("Apply activation function to input." - "Softmax Activation is only available with CUDNN on GPU" +.describe("Apply activation function to input. " + "Softmax Activation is only available with CUDNN on GPU " "and will be computed at each location across channel if input is 4D.") .add_argument("data", "Symbol", "Input data to activation function.") .add_arguments(ActivationParam::__FIELDS__()); diff --git a/src/operator/fully_connected.cc b/src/operator/fully_connected.cc index 19d2cdd1a2f9..ca08884a612e 100644 --- a/src/operator/fully_connected.cc +++ b/src/operator/fully_connected.cc @@ -39,7 +39,10 @@ Operator *FullyConnectedProp::CreateOperatorEx(Context ctx, std::vector DMLC_REGISTER_PARAMETER(FullyConnectedParam); MXNET_REGISTER_OP_PROPERTY(FullyConnected, FullyConnectedProp) -.describe("Apply matrix multiplication to input then add a bias.") +.describe(R"(Apply matrix multiplication to input then add a bias. +It maps the input of shape `(batch_size, input_dim)` to the shape of +`(batch_size, num_hidden)`. Learnable parameters include the weights +of the linear transform and an optional bias vector.)") .add_argument("data", "Symbol", "Input data to the FullyConnectedOp.") .add_argument("weight", "Symbol", "Weight matrix.") .add_argument("bias", "Symbol", "Bias parameter.") diff --git a/tests/python/doctest/run.py b/tests/python/doctest/run.py new file mode 100644 index 000000000000..93579b9fa0ee --- /dev/null +++ b/tests/python/doctest/run.py @@ -0,0 +1,39 @@ +import doctest +import logging +import mxnet + +def import_into(globs, module, names=None, error_on_overwrite=True): + """Import names from module into the globs dict. 
diff --git a/tests/python/doctest/run.py b/tests/python/doctest/run.py
new file mode 100644
index 000000000000..93579b9fa0ee
--- /dev/null
+++ b/tests/python/doctest/run.py
@@ -0,0 +1,39 @@
+import doctest
+import logging
+import mxnet
+
+def import_into(globs, module, names=None, error_on_overwrite=True):
+    """Import names from module into the globs dict.
+
+    Parameters
+    ----------
+    globs : dict
+        The dict to import the names into (e.g. a doctest globs dict).
+    module : module
+        The module to import the names from.
+    names : list of str, optional
+        The names to import. If None, everything in `dir(module)`
+        is imported.
+    error_on_overwrite : bool, optional
+        If True, raise an error when a name being imported already
+        exists in `globs`; otherwise only log a warning.
+    """
+    mod_names = dir(module)
+    if names is not None:
+        for name in names:
+            assert name in mod_names, '%s not found in %s' % (
+                name, module)
+        mod_names = names
+
+    for name in mod_names:
+        if name in globs:
+            error_msg = 'Attempting to overwrite definition of %s' % name
+            if error_on_overwrite:
+                raise RuntimeError(error_msg)
+            logging.warning('%s', error_msg)
+        globs[name] = getattr(module, name)
+
+    return globs
+
+
+def test_symbols():
+    globs = {'mxnet': mxnet, 'test_utils': mxnet.test_utils}
+
+    # make sure all the operators are available
+    import_into(globs, mxnet.symbol)
+
+    doctest.testmod(mxnet.symbol_doc, globs=globs)
+
+
+if __name__ == '__main__':
+    test_symbols()
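To exercise the new doctests locally, the runner can be invoked directly (assuming the `mxnet` package is importable):

    python tests/python/doctest/run.py

`doctest.testmod` is silent on success; a hypothetical tweak for per-example output would be passing `verbose=True`, e.g. `doctest.testmod(mxnet.symbol_doc, globs=globs, verbose=True)`.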