diff --git a/docs/api/python/gluon.md b/docs/api/python/gluon.md index 3149deb50d53..4748a1a64bfb 100644 --- a/docs/api/python/gluon.md +++ b/docs/api/python/gluon.md @@ -63,7 +63,7 @@ in Python and then deploy with symbolic graph in C++ and Scala. .. automethod:: __call__ .. autoclass:: mxnet.gluon.nn.Sequential :members: -.. autoclass:: mxnet.gluon.nn.HSequential +.. autoclass:: mxnet.gluon.nn.HybridSequential :members: ``` diff --git a/example/gluon/resnet.py b/example/gluon/resnet.py index 06ec21dfd224..44517eaf15f2 100644 --- a/example/gluon/resnet.py +++ b/example/gluon/resnet.py @@ -134,7 +134,7 @@ def __init__(self, block, classes, layers, filters, thumbnail=False, **kwargs): self.bn0 = nn.BatchNorm(in_channels=filters[0]) self.pool0 = nn.MaxPool2D(3, 2, 1) - self.body = nn.HSequential() + self.body = nn.HybridSequential() in_channels = filters[0] for i in range(len(layers)): stride = 1 if i == 0 else 2 @@ -146,7 +146,7 @@ def __init__(self, block, classes, layers, filters, thumbnail=False, **kwargs): self.dense1 = nn.Dense(classes, in_units=filters[-1]) def _make_layer(self, block, layers, filters, stride, in_channels=0): - layer = nn.HSequential() + layer = nn.HybridSequential() layer.add(block(filters, stride, True, in_channels=in_channels)) for i in range(layers-1): layer.add(block(filters, 1, False, in_channels=filters)) @@ -248,7 +248,7 @@ def __init__(self, block, classes, layers, filters, thumbnail=False, **kwargs): self.bn0 = nn.BatchNorm(in_channels=filters[0]) self.pool0 = nn.MaxPool2D(3, 2, 1) - self.body = nn.HSequential() + self.body = nn.HybridSequential() in_channels = filters[0] for i in range(len(layers)): stride = 1 if i == 0 else 2 @@ -261,7 +261,7 @@ def __init__(self, block, classes, layers, filters, thumbnail=False, **kwargs): self.dense1 = nn.Dense(classes, in_units=in_channels) def _make_layer(self, block, layers, filters, stride, in_channels=0): - layer = nn.HSequential() + layer = nn.HybridSequential() layer.add(block(filters, stride, True, in_channels=in_channels)) for i in range(layers-1): layer.add(block(filters, 1, False, in_channels=filters)) diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index 5d13aa09029d..bd072e7f60f2 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -2,7 +2,7 @@ # pylint: disable= arguments-differ """Base container class for all neural network models.""" -from .. import symbol, ndarray +from .. import symbol, ndarray, initializer from ..symbol import Symbol from ..ndarray import NDArray from .. import name as _name @@ -99,7 +99,7 @@ class Block(object): class Model(Block): def __init__(self, **kwargs): - super(Net, self).__init__(**kwargs) + super(Model, self).__init__(**kwargs) # use name_scope to give child Blocks appropriate names. # It also allows sharing Parameters between Blocks recursively. with self.name_scope(): @@ -110,6 +110,11 @@ def forward(self, x): x = F.relu(self.dense0(x)) return F.relu(self.dense1(x)) + model = Model() + model.initialize(ctx=mx.cpu(0)) + model(F.zeros((10, 10), ctx=mx.cpu(0))) + + Child `Block`s assigned this way will be registered and `collect_params` will collect their Parameters recursively. 
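A minimal runnable sketch of the `Block` workflow documented above (it assumes the `mx.gluon` namespace and calls the layers with `mx.nd` arrays, since the `F` placeholder used in the docstring is not defined at call time):

import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn

class Model(gluon.Block):
    def __init__(self, **kwargs):
        super(Model, self).__init__(**kwargs)
        with self.name_scope():
            # in_units is left unspecified, so the weights are shaped on first use
            self.dense0 = nn.Dense(20, activation='relu')
            self.dense1 = nn.Dense(20)

    def forward(self, x):
        return self.dense1(self.dense0(x))

model = Model()
model.initialize(ctx=mx.cpu(0))   # shorthand for model.collect_params().initialize(...)
out = model(mx.nd.zeros((10, 10), ctx=mx.cpu(0)))
print(out.shape)                  # (10, 20)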
@@ -124,7 +129,7 @@ def forward(self, x): if you want `dense1` to share `dense0`'s weights, you can do:: dense0 = nn.Dense(20) - dense1 = nn.Dense(20, params=dense1.collect_params()) + dense1 = nn.Dense(20, params=dense0.collect_params()) """ def __init__(self, prefix=None, params=None): self._prefix, self._params = _BlockScope.create(prefix, params, self._alias()) @@ -181,6 +186,13 @@ def register_child(self, block): attributes will be registered automatically.""" self._children.append(block) + def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False): + """Initialize `Parameter`s of this Block and its children. + + Equivalent to `block.collect_params().initialize(...)` + """ + self.collect_params().initialize(init, ctx, verbose) + def hybridize(self, active=True): """Activates or deactivates `HybridBlock`s recursively. Has no effect on non-hybrid children. @@ -250,7 +262,7 @@ def register_child(self, block): if isinstance(block, Sequential): raise ValueError( "Children of HybridBlock must also be HybridBlock. " \ - "Please use HSequential instead of Sequential.") + "Please use HybridSequential instead of Sequential.") raise ValueError( "Children of HybridBlock must also be HybridBlock, " \ "but %s has type %s."%(str(block), str(type(block)))) diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py index b1234f9d9e16..3bd590f1d5de 100644 --- a/python/mxnet/gluon/nn/basic_layers.py +++ b/python/mxnet/gluon/nn/basic_layers.py @@ -29,7 +29,7 @@ def forward(self, x): return x -class HSequential(HybridBlock): +class HybridSequential(HybridBlock): """Stack `HybridBlock`s sequentially. Example:: @@ -41,7 +41,7 @@ class HSequential(HybridBlock): net.add(Dense(20)) """ def __init__(self, prefix=None, params=None): - super(HSequential, self).__init__(prefix=prefix, params=params) + super(HybridSequential, self).__init__(prefix=prefix, params=params) def add(self, block): """Add block on top of the stack.""" @@ -97,16 +97,18 @@ class Dense(HybridBlock): the output would have shape `(batch_size, units)`. """ def __init__(self, units, activation=None, use_bias=True, - weight_initializer=None, bias_initializer=None, + weight_initializer=None, bias_initializer='zeros', in_units=0, **kwargs): super(Dense, self).__init__(**kwargs) with self.name_scope(): self._units = units self.weight = self.params.get('weight', shape=(units, in_units), - init=weight_initializer) + init=weight_initializer, + allow_deferred_init=True) if use_bias: self.bias = self.params.get('bias', shape=(units,), - init=bias_initializer) + init=bias_initializer, + allow_deferred_init=True) else: self.bias = None if activation is not None: @@ -133,6 +135,7 @@ class Activation(HybridBlock): name of activation function to use. See :func:`~mxnet.ndarray.Activation` for available choices. + Input shape: Arbitrary. @@ -210,6 +213,13 @@ class BatchNorm(HybridBlock): Number of channels (feature maps) in input data. If not specified, initialization will be defered to the first time `forward` is called and `in_channels` will be inferred from the shape of input data. + + + Input shape: + Arbitrary. + + Output shape: + Same shape as input. 
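A small sketch of the deferred initialization enabled by `allow_deferred_init` above: none of the layers below are given `in_units`/`in_channels`, so parameter shapes are only inferred on the first forward call (`nn.Flatten` is the helper added later in this patch):

import mxnet as mx
from mxnet.gluon import nn

net = nn.HybridSequential()
conv = nn.Conv2D(16, kernel_size=3, activation='relu')   # in_channels deferred
net.add(conv)
net.add(nn.BatchNorm())                                  # in_channels deferred
net.add(nn.Flatten())
net.add(nn.Dense(10))                                    # in_units deferred

net.initialize(ctx=mx.cpu(0))      # parameter shapes are still unknown here
x = mx.nd.random_uniform(shape=(2, 3, 32, 32))
y = net(x)                         # the first forward pass triggers shape inference

print(y.shape)            # (2, 10)
print(conv.weight.shape)  # (16, 3, 3, 3), inferred from the input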
""" def __init__(self, axis=1, momentum=0.9, epsilon=1e-3, center=True, scale=True, beta_initializer='zeros', gamma_initializer='ones', @@ -220,15 +230,19 @@ def __init__(self, axis=1, momentum=0.9, epsilon=1e-3, center=True, scale=True, 'fix_gamma': not center} self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null', - shape=(in_channels,), init=gamma_initializer) + shape=(in_channels,), init=gamma_initializer, + allow_deferred_init=True) self.beta = self.params.get('beta', grad_req='write' if center else 'null', - shape=(in_channels,), init=beta_initializer) + shape=(in_channels,), init=beta_initializer, + allow_deferred_init=True) self.running_mean = self.params.get('running_mean', grad_req='null', shape=(in_channels,), - init=running_mean_initializer) + init=running_mean_initializer, + allow_deferred_init=True) self.running_var = self.params.get('running_var', grad_req='null', shape=(in_channels,), - init=running_variance_initializer) + init=running_variance_initializer, + allow_deferred_init=True) def hybrid_forward(self, F, x, gamma, beta, running_mean, running_var): return F.BatchNorm(x, gamma, beta, running_mean, running_var, **self._kwargs) @@ -246,6 +260,13 @@ class LeakyReLU(HybridBlock): ---------- alpha : float slope coefficient for the negative half axis. Must be >= 0. + + + Input shape: + Arbitrary. + + Output shape: + Same shape as input. """ def __init__(self, alpha, **kwargs): super(LeakyReLU, self).__init__(**kwargs) @@ -284,7 +305,24 @@ def __init__(self, input_dim, output_dim, dtype='float32', self._kwargs = {'input_dim': input_dim, 'output_dim': output_dim, 'dtype': dtype} self.weight = self.params.get('weight', shape=(input_dim, output_dim), - init=weight_initializer) + init=weight_initializer, + allow_deferred_init=True) def hybrid_forward(self, F, x, weight): return F.Embedding(x, weight, **self._kwargs) + + +class Flatten(HybridBlock): + """Flattens the input to two dimensional. + + Input shape: + Arbitrary shape `(N, a, b, c, ...)` + + Output shape: + 2D tensor with shape: `(N, a*b*c...)` + """ + def __init__(self, **kwargs): + super(Flatten, self).__init__(**kwargs) + + def hybrid_forward(self, F, x): + return x.reshape((0, -1)) diff --git a/python/mxnet/gluon/nn/conv_layers.py b/python/mxnet/gluon/nn/conv_layers.py index 3449a160cee8..86ae302f9e31 100644 --- a/python/mxnet/gluon/nn/conv_layers.py +++ b/python/mxnet/gluon/nn/conv_layers.py @@ -4,6 +4,8 @@ from ..block import HybridBlock from ... 
import symbol from ...base import numeric_types +from .basic_layers import Activation + def _infer_weight_shape(op_name, data_shape, kwargs): op = getattr(symbol, op_name) @@ -62,7 +64,7 @@ class _Conv(HybridBlock): """ def __init__(self, channels, kernel_size, strides, padding, dilation, groups, layout, in_channels=0, activation=None, use_bias=True, - weight_initializer=None, bias_initializer=None, + weight_initializer=None, bias_initializer='zeros', op_name='Convolution', prefix=None, params=None, **kwargs): super(_Conv, self).__init__(prefix=prefix, params=params) with self.name_scope(): @@ -86,10 +88,12 @@ def __init__(self, channels, kernel_size, strides, padding, dilation, dshape[layout.find('C')] = in_channels wshapes = _infer_weight_shape(op_name, dshape, self._kwargs) self.weight = self.params.get('weight', shape=wshapes[1], - init=weight_initializer) + init=weight_initializer, + allow_deferred_init=True) if use_bias: self.bias = self.params.get('bias', shape=wshapes[2], - init=bias_initializer) + init=bias_initializer, + allow_deferred_init=True) else: self.bias = None @@ -163,11 +167,11 @@ class Conv1D(_Conv): Initializer for the bias vector. - Input Shape: + Input shape: This depends on the `layout` parameter. Input is 3D array of shape (batch_size, in_channels, width) if `layout` is `NCW`. - Output Shape: + Output shape: This depends on the `layout` parameter. Output is 3D array of shape (batch_size, channels, out_width) if `layout` is `NCW`. out_width is calculated as:: @@ -176,7 +180,7 @@ class Conv1D(_Conv): """ def __init__(self, channels, kernel_size, strides=1, padding=0, dilation=1, groups=1, layout='NCW', activation=None, use_bias=True, - weight_initializer=None, bias_initializer=None, + weight_initializer=None, bias_initializer='zeros', in_channels=0, **kwargs): if isinstance(kernel_size, numeric_types): kernel_size = (kernel_size,) @@ -240,11 +244,11 @@ class Conv2D(_Conv): Initializer for the bias vector. - Input Shape: + Input shape: This depends on the `layout` parameter. Input is 4D array of shape (batch_size, in_channels, height, width) if `layout` is `NCHW`. - Output Shape: + Output shape: This depends on the `layout` parameter. Output is 4D array of shape (batch_size, channels, out_height, out_width) if `layout` is `NCHW`. @@ -256,7 +260,7 @@ class Conv2D(_Conv): def __init__(self, channels, kernel_size, strides=(1, 1), padding=(0, 0), dilation=(1, 1), groups=1, layout='NCHW', activation=None, use_bias=True, weight_initializer=None, - bias_initializer=None, in_channels=0, **kwargs): + bias_initializer='zeros', in_channels=0, **kwargs): if isinstance(kernel_size, numeric_types): kernel_size = (kernel_size,)*2 assert len(kernel_size) == 2, "kernel_size must be a number or a list of 2 ints" @@ -319,11 +323,11 @@ class Conv3D(_Conv): Initializer for the bias vector. - Input Shape: + Input shape: This depends on the `layout` parameter. Input is 5D array of shape (batch_size, in_channels, depth, height, width) if `layout` is `NCDHW`. - Output Shape: + Output shape: This depends on the `layout` parameter. Output is 5D array of shape (batch_size, channels, out_depth, out_height, out_width) if `layout` is `NCDHW`. 
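The output-shape formulas referenced in these convolution docstrings can be checked numerically; a quick sketch (the helper function below is illustrative, not part of the library):

import mxnet as mx
from mxnet.gluon import nn

def conv_out_size(size, kernel, stride=1, padding=0, dilation=1):
    # out = floor((size + 2*padding - dilation*(kernel-1) - 1)/stride) + 1
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

conv = nn.Conv2D(channels=8, kernel_size=3, strides=2, padding=1)   # layout 'NCHW'
conv.initialize()
y = conv(mx.nd.zeros((1, 3, 32, 32)))

print(y.shape)                                            # (1, 8, 16, 16)
print(conv_out_size(32, kernel=3, stride=2, padding=1))   # 16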
@@ -336,7 +340,7 @@ class Conv3D(_Conv): """ def __init__(self, channels, kernel_size, strides=(1, 1, 1), padding=(0, 0, 0), dilation=(1, 1, 1), groups=1, layout='NCDHW', activation=None, - use_bias=True, weight_initializer=None, bias_initializer=None, + use_bias=True, weight_initializer=None, bias_initializer='zeros', in_channels=0, **kwargs): if isinstance(kernel_size, numeric_types): kernel_size = (kernel_size,)*3 @@ -400,11 +404,11 @@ class Conv1DTranspose(_Conv): Initializer for the bias vector. - Input Shape: + Input shape: This depends on the `layout` parameter. Input is 3D array of shape (batch_size, in_channels, width) if `layout` is `NCW`. - Output Shape: + Output shape: This depends on the `layout` parameter. Output is 3D array of shape (batch_size, channels, out_width) if `layout` is `NCW`. @@ -414,7 +418,7 @@ class Conv1DTranspose(_Conv): """ def __init__(self, channels, kernel_size, strides=1, padding=0, output_padding=0, dilation=1, groups=1, layout='NCW', activation=None, use_bias=True, - weight_initializer=None, bias_initializer=None, + weight_initializer=None, bias_initializer='zeros', in_channels=0, **kwargs): if isinstance(kernel_size, numeric_types): kernel_size = (kernel_size,) @@ -484,11 +488,11 @@ class Conv2DTranspose(_Conv): Initializer for the bias vector. - Input Shape: + Input shape: This depends on the `layout` parameter. Input is 4D array of shape (batch_size, in_channels, height, width) if `layout` is `NCHW`. - Output Shape: + Output shape: This depends on the `layout` parameter. Output is 4D array of shape (batch_size, channels, out_height, out_width) if `layout` is `NCHW`. @@ -500,7 +504,7 @@ class Conv2DTranspose(_Conv): def __init__(self, channels, kernel_size, strides=(1, 1), padding=(0, 0), output_padding=(0, 0), dilation=(1, 1), groups=1, layout='NCHW', activation=None, use_bias=True, weight_initializer=None, - bias_initializer=None, in_channels=0, **kwargs): + bias_initializer='zeros', in_channels=0, **kwargs): if isinstance(kernel_size, numeric_types): kernel_size = (kernel_size,)*2 if isinstance(output_padding, numeric_types): @@ -569,11 +573,11 @@ class Conv3DTranspose(_Conv): Initializer for the bias vector. - Input Shape: + Input shape: This depends on the `layout` parameter. Input is 5D array of shape (batch_size, in_channels, depth, height, width) if `layout` is `NCDHW`. - Output Shape: + Output shape: This depends on the `layout` parameter. Output is 5D array of shape (batch_size, channels, out_depth, out_height, out_width) if `layout` is `NCDHW`. out_depth, out_height and out_width are calculated as:: @@ -585,7 +589,7 @@ class Conv3DTranspose(_Conv): def __init__(self, channels, kernel_size, strides=(1, 1, 1), padding=(0, 0, 0), output_padding=(0, 0, 0), dilation=(1, 1, 1), groups=1, layout='NCDHW', activation=None, use_bias=True, weight_initializer=None, - bias_initializer=None, in_channels=0, **kwargs): + bias_initializer='zeros', in_channels=0, **kwargs): if isinstance(kernel_size, numeric_types): kernel_size = (kernel_size,)*3 if isinstance(output_padding, numeric_types): @@ -640,11 +644,11 @@ class MaxPool1D(_Pooling): When True, will use ceil instead of floor to compute the output shape. - Input Shape: + Input shape: This depends on the `layout` parameter. Input is 3D array of shape (batch_size, channels, width) if `layout` is `NCW`. - Output Shape: + Output shape: This depends on the `layout` parameter. Output is 3D array of shape (batch_size, channels, out_width) if `layout` is `NCW`. 
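The `ceil_mode` flag mentioned in the pooling docstrings only changes how the output size is rounded; a minimal sketch:

import mxnet as mx
from mxnet.gluon import nn

x = mx.nd.zeros((1, 1, 7, 7))

floor_pool = nn.MaxPool2D(pool_size=2, strides=2)                 # default: floor
ceil_pool = nn.MaxPool2D(pool_size=2, strides=2, ceil_mode=True)

print(floor_pool(x).shape)   # (1, 1, 3, 3): floor((7 - 2)/2) + 1
print(ceil_pool(x).shape)    # (1, 1, 4, 4): ceil((7 - 2)/2) + 1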
@@ -687,11 +691,11 @@ class MaxPool2D(_Pooling): When True, will use ceil instead of floor to compute the output shape. - Input Shape: + Input shape: This depends on the `layout` parameter. Input is 4D array of shape (batch_size, channels, height, width) if `layout` is `NCHW`. - Output Shape: + Output shape: This depends on the `layout` parameter. Output is 4D array of shape (batch_size, channels, out_height, out_width) if `layout` is `NCHW`. @@ -736,11 +740,11 @@ class MaxPool3D(_Pooling): When True, will use ceil instead of floor to compute the output shape. - Input Shape: + Input shape: This depends on the `layout` parameter. Input is 5D array of shape (batch_size, channels, depth, height, width) if `layout` is `NCDHW`. - Output Shape: + Output shape: This depends on the `layout` parameter. Output is 5D array of shape (batch_size, channels, out_depth, out_height, out_width) if `layout` is `NCDHW`. @@ -785,11 +789,11 @@ class AvgPool1D(_Pooling): When True, will use ceil instead of floor to compute the output shape. - Input Shape: + Input shape: This depends on the `layout` parameter. Input is 3D array of shape (batch_size, channels, width) if `layout` is `NCW`. - Output Shape: + Output shape: This depends on the `layout` parameter. Output is 3D array of shape (batch_size, channels, out_width) if `layout` is `NCW`. @@ -831,11 +835,11 @@ class AvgPool2D(_Pooling): When True, will use ceil instead of floor to compute the output shape. - Input Shape: + Input shape: This depends on the `layout` parameter. Input is 4D array of shape (batch_size, channels, height, width) if `layout` is `NCHW`. - Output Shape: + Output shape: This depends on the `layout` parameter. Output is 4D array of shape (batch_size, channels, out_height, out_width) if `layout` is `NCHW`. @@ -879,11 +883,11 @@ class AvgPool3D(_Pooling): When True, will use ceil instead of floor to compute the output shape. - Input Shape: + Input shape: This depends on the `layout` parameter. Input is 5D array of shape (batch_size, channels, depth, height, width) if `layout` is `NCDHW`. - Output Shape: + Output shape: This depends on the `layout` parameter. Output is 5D array of shape (batch_size, channels, out_depth, out_height, out_width) if `layout` is `NCDHW`. diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py index 083db134a65b..af51f399c018 100644 --- a/python/mxnet/gluon/parameter.py +++ b/python/mxnet/gluon/parameter.py @@ -21,21 +21,14 @@ class DeferredInitializationError(MXNetError): class Parameter(object): """A Container holding parameters (weights) of `Block`s. - `Parameter` can be used with both `Symbol` and `NDArray` API. For `Symbol` API, - `Parameter.var()` will return a `Symbol` representing this parameter. It - can then be used for composing networks:: - x = mx.sym.Variable('data') - w = mx.nn.Parameter('fc_weight', init=mx.init.Xavier()) - b = mx.nn.Parameter('fc_bias', init=mx.init.Zero()) - out = mx.sym.FullyConnected(x, w.var(), b.var(), num_hidden=64) - - For `NDArray` API, `Parameter` must be initialized with `Parameter.init`. It - will then hold a copy of the the parameter on each `Context`. If `grad_req` is + `Parameter` holds a copy of the parameter on each `Context` after + it is initialized with `Parameter.initialize(...)`.
If `grad_req` is not `null`, it will also hold a gradient array on each `Context`:: + ctx = mx.gpu(0) x = mx.nd.zeros((16, 100), ctx=ctx) - w = mx.nn.Parameter('fc_weight', shape=(64, 100), init=mx.init.Xavier()) - b = mx.nn.Parameter('fc_bias', shape(64,), init=mx.init.Zero()) + w = mx.gluon.Parameter('fc_weight', shape=(64, 100), init=mx.init.Xavier()) + b = mx.gluon.Parameter('fc_bias', shape=(64,), init=mx.init.Zero()) w.initialize(ctx=ctx) b.initialize(ctx=ctx) out = mx.nd.FullyConnected(x, w.data(ctx), b.data(ctx), num_hidden=64) @@ -66,9 +59,10 @@ class Parameter(object): Weight decay multiplier (L2 regulerizer coefficient). Works similarly to lr_mult. init : Initializer, default None Initializer of this parameter. Will use the global initializer by default. + """ def __init__(self, name, grad_req='write', shape=None, dtype=mx_real_t, - lr_mult=1.0, wd_mult=1.0, init=None): + lr_mult=1.0, wd_mult=1.0, init=None, allow_deferred_init=False): self.name = name self.shape = shape self.dtype = dtype @@ -76,13 +70,13 @@ def __init__(self, name, grad_req='write', shape=None, dtype=mx_real_t, self.wd_mult = wd_mult self.grad_req = grad_req self.init = init + self.allow_deferred_init = allow_deferred_init self._var = None self._data = None self._grad = None self._defered_init = () - def initialize(self, init=None, ctx=None, default_init=initializer.Xavier(), - allow_deferring=True): + def initialize(self, init=None, ctx=None, default_init=initializer.Uniform()): """Intialize parameter and gradient arrays. Only used for `NDArray` API. Parameters @@ -97,21 +91,41 @@ def initialize(self, init=None, ctx=None, default_init=initializer.Xavier(), their values consistent when updating. Normally nn.Trainer does this for you. default_init : Initializer Default initializer is used when both `init` and `Parameter.init` are None. + + Examples + -------- + >>> weight = mx.gluon.Parameter('weight', shape=(2, 2)) + >>> weight.initialize(ctx=mx.cpu(0)) + >>> weight.data() + [[-0.01068833 0.01729892] + [ 0.02042518 -0.01618656]] + + >>> weight.grad() + [[ 0. 0.] + [ 0. 0.]] + + >>> weight.initialize(ctx=[mx.gpu(0), mx.gpu(1)]) + >>> weight.data(mx.gpu(0)) + [[-0.00873779 -0.02834515] + [ 0.05484822 -0.06206018]] + + >>> weight.data(mx.gpu(1)) + [[-0.00873779 -0.02834515] + [ 0.05484822 -0.06206018]] + """ if ctx is None: ctx = [context.current_context()] if isinstance(ctx, Context): ctx = [ctx] - - if self.shape is None or np.prod(self.shape) <= 0: - if allow_deferring: + if init is None: + init = default_init if self.init is None else self.init + if not self.shape or np.prod(self.shape) <= 0: + if self.allow_deferred_init: self._defered_init = (init, ctx, default_init) return raise ValueError("Cannot initialize Parameter %s because it has " \ - "invalid shape: %s. 
Please specify in_units, " \ - "in_channels, etc for `Block`s or " \ - "set allow_deferring to True to defer initialization " \ - "to first forward pass."%(self.name, str(self.shape))) + "invalid shape: %s."%(self.name, str(self.shape))) self._defered_init = (init, ctx, default_init) self._finish_deferred_init() @@ -161,8 +175,6 @@ def _finish_deferred_init(self): with autograd.pause(): data = ndarray.zeros(shape=self.shape, dtype=self.dtype, ctx=context.cpu()) - if init is None: - init = self.init initializer.create(default_init)( initializer.InitDesc(self.name, {'__init__': init}), data) @@ -222,7 +234,11 @@ def data(self, ctx=None): NDArray on ctx """ if ctx is None: - ctx = context.current_context() + list_ctx = self.list_ctx() + if len(list_ctx) == 1: + ctx = list_ctx[0] + else: + ctx = context.current_context() self._check_initialized(ctx) return self._data[ctx] @@ -241,7 +257,11 @@ def grad(self, ctx=None): Desired context. """ if ctx is None: - ctx = context.current_context() + list_ctx = self.list_ctx() + if len(list_ctx) == 1: + ctx = list_ctx[0] + else: + ctx = context.current_context() self._check_initialized(ctx) if self._grad is None: raise RuntimeError( @@ -371,7 +391,7 @@ def update(self, other): else: self._params[k] = v - def initialize(self, init=initializer.Xavier(), ctx=None, verbose=False): + def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False): """Intialize all Parameters manage by this dictionary to be used for `NDArray` API. Has no effect when using `Symbol` API. diff --git a/python/mxnet/gluon/rnn/rnn_cell.py b/python/mxnet/gluon/rnn/rnn_cell.py index 986d3cce363a..7333892da4f5 100644 --- a/python/mxnet/gluon/rnn/rnn_cell.py +++ b/python/mxnet/gluon/rnn/rnn_cell.py @@ -237,16 +237,16 @@ def forward(self, inputs, states): return super(RecurrentCell, self).forward(inputs, states) -class HRecurrentCell(RecurrentCell, HybridBlock): - """HRecurrentCell supports both Symbol and NDArray forwarding.""" +class HybridRecurrentCell(RecurrentCell, HybridBlock): + """HybridRecurrentCell supports hybridize.""" def __init__(self, prefix=None, params=None): - super(HRecurrentCell, self).__init__(prefix=prefix, params=params) + super(HybridRecurrentCell, self).__init__(prefix=prefix, params=params) def hybrid_forward(self, F, x, *args, **kwargs): raise NotImplementedError -class RNNCell(HRecurrentCell): +class RNNCell(HybridRecurrentCell): """Simple recurrent neural network cell. 
Parameters @@ -274,20 +274,24 @@ class RNNCell(HRecurrentCell): """ def __init__(self, hidden_size, activation='tanh', i2h_weight_initializer=None, h2h_weight_initializer=None, - i2h_bias_initializer=None, h2h_bias_initializer=None, + i2h_bias_initializer='zeros', h2h_bias_initializer='zeros', input_size=0, prefix=None, params=None): super(RNNCell, self).__init__(prefix=prefix, params=params) self._hidden_size = hidden_size self._activation = activation self._input_size = input_size self.i2h_weight = self.params.get('i2h_weight', shape=(hidden_size, input_size), - init=i2h_weight_initializer) + init=i2h_weight_initializer, + allow_deferred_init=True) self.h2h_weight = self.params.get('h2h_weight', shape=(hidden_size, hidden_size), - init=h2h_weight_initializer) + init=h2h_weight_initializer, + allow_deferred_init=True) self.i2h_bias = self.params.get('i2h_bias', shape=(hidden_size,), - init=i2h_bias_initializer) + init=i2h_bias_initializer, + allow_deferred_init=True) self.h2h_bias = self.params.get('h2h_bias', shape=(hidden_size,), - init=h2h_bias_initializer) + init=h2h_bias_initializer, + allow_deferred_init=True) def state_info(self, batch_size=0): return [{'shape': (batch_size, self._hidden_size), '__layout__': 'NC'}] @@ -310,7 +314,7 @@ def hybrid_forward(self, F, inputs, states, i2h_weight, return output, [output] -class LSTMCell(HRecurrentCell): +class LSTMCell(HybridRecurrentCell): """Long-Short Term Memory (LSTM) network cell. Parameters @@ -338,20 +342,24 @@ class LSTMCell(HRecurrentCell): """ def __init__(self, hidden_size, i2h_weight_initializer=None, h2h_weight_initializer=None, - i2h_bias_initializer='lstmbias', h2h_bias_initializer=None, + i2h_bias_initializer='zeros', h2h_bias_initializer='zeros', input_size=0, prefix=None, params=None): super(LSTMCell, self).__init__(prefix=prefix, params=params) self._hidden_size = hidden_size self._input_size = input_size self.i2h_weight = self.params.get('i2h_weight', shape=(4*hidden_size, input_size), - init=i2h_weight_initializer) + init=i2h_weight_initializer, + allow_deferred_init=True) self.h2h_weight = self.params.get('h2h_weight', shape=(4*hidden_size, hidden_size), - init=h2h_weight_initializer) + init=h2h_weight_initializer, + allow_deferred_init=True) self.i2h_bias = self.params.get('i2h_bias', shape=(4*hidden_size,), - init=i2h_bias_initializer) + init=i2h_bias_initializer, + allow_deferred_init=True) self.h2h_bias = self.params.get('h2h_bias', shape=(4*hidden_size,), - init=h2h_bias_initializer) + init=h2h_bias_initializer, + allow_deferred_init=True) def state_info(self, batch_size=0): return [{'shape': (batch_size, self._hidden_size), '__layout__': 'NC'}, @@ -388,7 +396,7 @@ def hybrid_forward(self, F, inputs, states, i2h_weight, return next_h, [next_h, next_c] -class GRUCell(HRecurrentCell): +class GRUCell(HybridRecurrentCell): """Gated Rectified Unit (GRU) network cell. Note: this is an implementation of the cuDNN version of GRUs (slight modification compared to Cho et al. 2014). 
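With the 'zeros' bias defaults and deferred weight shapes introduced above, any of these recurrent cells can be built without specifying `input_size`; a sketch of a single step using `LSTMCell` as an example (the shapes below are illustrative):

import mxnet as mx
from mxnet.gluon import rnn

cell = rnn.LSTMCell(hidden_size=100)   # input_size=0, so i2h_weight is shape-deferred
cell.initialize()

x = mx.nd.random_uniform(shape=(32, 50))                   # (batch_size, input_size)
states = [mx.nd.zeros((32, 100)), mx.nd.zeros((32, 100))]  # hidden and cell state
output, new_states = cell(x, states)

print(output.shape)           # (32, 100)
print(cell.i2h_weight.shape)  # (400, 50), inferred on the first call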
@@ -416,18 +424,22 @@ class GRUCell(HRecurrentCell): """ def __init__(self, hidden_size, i2h_weight_initializer=None, h2h_weight_initializer=None, - i2h_bias_initializer=None, h2h_bias_initializer=None, + i2h_bias_initializer='zeros', h2h_bias_initializer='zeros', input_size=0, prefix=None, params=None): super(GRUCell, self).__init__(prefix=prefix, params=params) self._hidden_size = hidden_size self.i2h_weight = self.params.get('i2h_weight', shape=(3*hidden_size, input_size), - init=i2h_weight_initializer) + init=i2h_weight_initializer, + allow_deferred_init=True) self.h2h_weight = self.params.get('h2h_weight', shape=(3*hidden_size, hidden_size), - init=h2h_weight_initializer) + init=h2h_weight_initializer, + allow_deferred_init=True) self.i2h_bias = self.params.get('i2h_bias', shape=(3*hidden_size,), - init=i2h_bias_initializer) + init=i2h_bias_initializer, + allow_deferred_init=True) self.h2h_bias = self.params.get('h2h_bias', shape=(3*hidden_size,), - init=h2h_bias_initializer) + init=h2h_bias_initializer, + allow_deferred_init=True) def state_info(self, batch_size=0): return [{'shape': (batch_size, self._hidden_size), '__layout__': 'NC'}] @@ -527,7 +539,7 @@ def hybrid_forward(self, *args, **kwargs): raise NotImplementedError -class DropoutCell(HRecurrentCell): +class DropoutCell(HybridRecurrentCell): """Apply dropout on input. Parameters @@ -564,7 +576,7 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N merge_outputs=merge_outputs) -class ModifierCell(HRecurrentCell): +class ModifierCell(HybridRecurrentCell): """Base class for modifier cells. A modifier cell takes a base cell, apply modifications on it (e.g. Zoneout), and returns a new cell. @@ -673,7 +685,7 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N return outputs, states -class BidirectionalCell(HRecurrentCell): +class BidirectionalCell(HybridRecurrentCell): """Bidirectional RNN cell. Parameters diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py index 166bbc4bb63b..8a2309841e12 100644 --- a/python/mxnet/gluon/rnn/rnn_layer.py +++ b/python/mxnet/gluon/rnn/rnn_layer.py @@ -44,16 +44,20 @@ def __init__(self, hidden_size, num_layers, layout, for j in (['l', 'r'] if self._dir == 2 else ['l']): self.i2h_weight.append( self.params.get('%s%d_i2h_weight'%(j, i), shape=(ng*nh, ni), - init=i2h_weight_initializer)) + init=i2h_weight_initializer, + allow_deferred_init=True)) self.h2h_weight.append( self.params.get('%s%d_h2h_weight'%(j, i), shape=(ng*nh, nh), - init=h2h_weight_initializer)) + init=h2h_weight_initializer, + allow_deferred_init=True)) self.i2h_bias.append( self.params.get('%s%d_i2h_bias'%(j, i), shape=(ng*nh,), - init=i2h_bias_initializer)) + init=i2h_bias_initializer, + allow_deferred_init=True)) self.h2h_bias.append( self.params.get('%s%d_h2h_bias'%(j, i), shape=(ng*nh,), - init=h2h_bias_initializer)) + init=h2h_bias_initializer, + allow_deferred_init=True)) ni = nh * self._dir self._unfused = self._unfuse() @@ -133,6 +137,14 @@ def begin_state(self, batch_size=0, func=ndarray.zeros, **kwargs): return states def forward(self, inputs, states): + if isinstance(states, ndarray.NDArray): + states = [states] + batch_size = states[0].shape[self._layout.find('N')] + for state, info in zip(states, self.state_info(batch_size)): + if state.shape != info['shape']: + raise ValueError( + "Invalid recurrent state shape. 
Expecting %s, got %s."%( + str(info['shape']), str(state.shape))) if self._input_size == 0: for i in range(self._dir): self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2]) @@ -226,17 +238,36 @@ class RNN(_RNNLayer): params : ParameterDict or None Shared Parameters for this `Block`. + + Input shapes: + The input shape depends on `layout`. For `layout='TNC'`, the + input has shape `(sequence_length, batch_size, input_size)` + + + Output shape: + The output shape depends on `layout`. For `layout='TNC'`, the + output has shape `(sequence_length, batch_size, num_hidden)`. + If `bidirectional` is True, output shape will instead be + `(sequence_length, batch_size, 2*num_hidden)` + + Recurrent state shape: + The recurrent state's shape is `(num_layers, batch_size, num_hidden)`. + If `bidirectional` is True, state shape will instead be + `(num_layers, batch_size, 2*num_hidden)` + + Examples -------- - >>> rnn = nn.RNN(100, 3) + >>> layer = mx.gluon.rnn.RNN(100, 3) + >>> layer.initialize() >>> input = mx.nd.random_uniform(shape=(5, 3, 10)) - >>> h0 = mx.nd.random_uniform(shape=(2, 3, 100)) - >>> output, hn = rnn(input, h0) + >>> h0 = mx.nd.random_uniform(shape=(3, 3, 100)) + >>> output, hn = layer(input, h0) """ def __init__(self, hidden_size, num_layers=1, activation='relu', layout='TNC', dropout=0, bidirectional=False, i2h_weight_initializer=None, h2h_weight_initializer=None, - i2h_bias_initializer=None, h2h_bias_initializer=None, + i2h_bias_initializer='zeros', h2h_bias_initializer='zeros', input_size=0, **kwargs): super(RNN, self).__init__(hidden_size, num_layers, layout, dropout, bidirectional, input_size, @@ -305,18 +336,37 @@ class LSTM(_RNNLayer): params : ParameterDict or None Shared Parameters for this `Block`. + + Input shapes: + The input shape depends on `layout`. For `layout='TNC'`, the + input has shape `(sequence_length, batch_size, input_size)` + + Output shape: + The output shape depends on `layout`. For `layout='TNC'`, the + output has shape `(sequence_length, batch_size, num_hidden)`. + If `bidirectional` is True, output shape will instead be + `(sequence_length, batch_size, 2*num_hidden)` + + Recurrent state shape: + The recurrent state is a list of two NDArrays. Both have shape + `(num_layers, batch_size, num_hidden)`. + If `bidirectional` is True, state shape will instead be + `(num_layers, batch_size, 2*num_hidden)` + + Examples -------- - >>> rnn = nn.LSTM(100, 3) + >>> layer = mx.gluon.rnn.LSTM(100, 3) + >>> layer.initialize() >>> input = mx.nd.random_uniform(shape=(5, 3, 10)) - >>> h0 = mx.nd.random_uniform(shape=(2, 3, 100)) - >>> c0 = mx.nd.random_uniform(shape=(2, 3, 100)) - >>> output, hn = rnn(input, (h0, c0)) + >>> h0 = mx.nd.random_uniform(shape=(3, 3, 100)) + >>> c0 = mx.nd.random_uniform(shape=(3, 3, 100)) + >>> output, hn = layer(input, [h0, c0]) """ def __init__(self, hidden_size, num_layers=1, layout='TNC', dropout=0, bidirectional=False, input_size=0, i2h_weight_initializer=None, h2h_weight_initializer=None, - i2h_bias_initializer='lstmbias', h2h_bias_initializer=None, + i2h_bias_initializer='zeros', h2h_bias_initializer='zeros', **kwargs): super(LSTM, self).__init__(hidden_size, num_layers, layout, dropout, bidirectional, input_size, @@ -381,17 +431,35 @@ class GRU(_RNNLayer): params : ParameterDict or None Shared Parameters for this `Block`. + + Input shapes: + The input shape depends on `layout`.
For `layout='TNC'`, the + input has shape `(sequence_length, batch_size, input_size)` + + Output shape: + The output shape depends on `layout`. For `layout='TNC'`, the + output has shape `(sequence_length, batch_size, num_hidden)`. + If `bidirectional` is True, output shape will instead be + `(sequence_length, batch_size, 2*num_hidden)` + + Recurrent state shape: + The recurrent state's shape is `(num_layers, batch_size, num_hidden)`. + If `bidirectional` is True, state shape will instead be + `(num_layers, batch_size, 2*num_hidden)` + + Examples -------- - >>> rnn = nn.GRU(100, 2) + >>> layer = mx.gluon.rnn.GRU(100, 3) + >>> layer.initialize() >>> input = mx.nd.random_uniform(shape=(5, 3, 10)) - >>> h0 = mx.nd.random_uniform(shape=(2, 3, 100)) - >>> output, hn = rnn(input, h0) + >>> h0 = mx.nd.random_uniform(shape=(3, 3, 100)) + >>> output, hn = layer(input, h0) """ def __init__(self, hidden_size, num_layers=1, layout='TNC', dropout=0, bidirectional=False, input_size=0, i2h_weight_initializer=None, h2h_weight_initializer=None, - i2h_bias_initializer=None, h2h_bias_initializer=None, + i2h_bias_initializer='zeros', h2h_bias_initializer='zeros', **kwargs): super(GRU, self).__init__(hidden_size, num_layers, layout, dropout, bidirectional, input_size, diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py index 5d79f7342b1a..8f20bd1e698b 100644 --- a/python/mxnet/gluon/trainer.py +++ b/python/mxnet/gluon/trainer.py @@ -4,6 +4,7 @@ from .. import optimizer as opt from ..model import _create_kvstore +from .parameter import ParameterDict, Parameter class Trainer(object): """Applies an Optimizer on a set of Parameters. Trainer should @@ -16,15 +17,28 @@ class Trainer(object): optimizer : str or Optimizer The optimizer to use. optimizer_params : dict - key-word arguments to be passed to Optimizer.create_optimizer. For example, - {'learning_rate': 0.1} + key-word arguments to be passed to optimizer constructor. For example, + `{'learning_rate': 0.1}` kvstore : str or KVStore kvstore type for multi-gpu and distributed training. 
""" def __init__(self, params, optimizer, optimizer_params, kvstore='device'): - self._params = [param for param in params.values() if param.grad_req != 'null'] - self._scale = optimizer_params.get('rescale_grad', 1.0) + if isinstance(params, (dict, ParameterDict)): + params = list(params.values()) + if not isinstance(params, (list, tuple)): + raise ValueError( + "First argument must be a list or dict of Parameters, " \ + "got %s."%(type(params))) + self._params = [] + for param in params: + if not isinstance(param, Parameter): + raise ValueError( + "First argument must be a list or dict of Parameters, " \ + "got list of %s."%(type(param))) + if param.grad_req != 'null': + self._params.append(param) + self._scale = optimizer_params.get('rescale_grad', 1.0) self._contexts = self._check_contexts() self._init_optimizer(optimizer, optimizer_params) self._kv_initialized = False diff --git a/python/mxnet/ndarray.py b/python/mxnet/ndarray.py index 29f0f769ed63..537228a495fc 100644 --- a/python/mxnet/ndarray.py +++ b/python/mxnet/ndarray.py @@ -122,9 +122,9 @@ class NDArray(NDArrayBase): def __repr__(self): """Returns a string representation of the array.""" shape_info = 'x'.join(['%d' % x for x in self.shape]) - return '%s\n<%s %s @%s>' % (str(self.asnumpy()), - self.__class__.__name__, - shape_info, self.context) + return '\n%s\n<%s %s @%s>' % (str(self.asnumpy()), + self.__class__.__name__, + shape_info, self.context) def __add__(self, other): """x.__add__(y) <=> x+y <=> mx.nd.add(x, y) """ @@ -370,13 +370,18 @@ def __setitem__(self, key, value): assert slice_i < my_shape[i] begin[i] = slice_i end[i] = slice_i + 1 - if isinstance(slice_i, py_slice): + elif isinstance(slice_i, py_slice): # only support continuous slicing - assert slice_i.step is None + assert slice_i.step is None, \ + "NDArray only supports continuous slicing." begin[i] = slice_i.start or 0 end[i] = slice_i.stop or my_shape[i] assert begin[i] < end[i] assert end[i] <= my_shape[i] + else: + raise ValueError( + "NDArray does not support slicing with %s."%( + str(slice_i))) begin = tuple(begin) end = tuple(end) if isinstance(value, NDArray): @@ -434,8 +439,32 @@ def __getitem__(self, key): else: return self if isinstance(key, tuple): - raise ValueError('Multi-dimension indexing is not supported') - + shape = self.shape + oshape = [] + begin = [] + end = [] + assert len(shape) >= len(key), \ + "Slicing dimensions exceeds array dimensions, %d vs %d"%( + len(key), len(shape)) + i = -1 + for i, slice_i in enumerate(key): + if isinstance(slice_i, int): + begin.append(slice_i) + end.append(slice_i+1) + elif isinstance(slice_i, py_slice): + if slice_i.step is not None: + raise ValueError("NDArray only supports continuous slicing.") + begin.append(0 if slice_i.start is None else slice_i.start) + end.append(shape[i] if slice_i.stop is None else slice_i.stop) + oshape.append(end[i] - begin[i]) + else: + raise ValueError( + "NDArray does not support slicing with %s."%( + str(slice_i))) + oshape.extend(shape[i+1:]) + if len(oshape) == 0: + oshape.append(1) + return slice(self, begin, end).reshape(oshape) def _sync_copyfrom(self, source_array): """Performs a synchronized copy from the `source_array` to the current array. diff --git a/python/mxnet/operator.py b/python/mxnet/operator.py index d57ee717fcf6..884775d26317 100644 --- a/python/mxnet/operator.py +++ b/python/mxnet/operator.py @@ -11,7 +11,7 @@ from .base import _LIB, check_call from .base import c_array, c_str, mx_uint, mx_float, ctypes2numpy_shared, NDArrayHandle, py_str -from . 
import symbol +from . import symbol, context from .ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP c_int_p = POINTER(c_int) @@ -448,7 +448,7 @@ class CustomOpProp(object): The default declare_backward_dependency function. Use this value to determine whether this operator needs gradient input. """ - def __init__(self, need_top_grad=False): + def __init__(self, need_top_grad=True): self.need_top_grad_ = need_top_grad def infer_shape(self, in_shape): @@ -734,6 +734,9 @@ def declare_backward_dependency_entry(out_grad, in_data, out_data, num_dep, deps def create_operator_entry(ctx, num_inputs, shapes, ndims, dtypes, ret, _): """C Callback for CustomOpProp::CreateOperator""" try: + ctx = py_str(ctx) + sep = ctx.find('(') + ctx = context.Context(ctx[:sep], int(ctx[sep+1:-1])) ndims = [ndims[i] for i in range(num_inputs)] shapes = [[shapes[i][j] for j in range(ndims[i])] for i in range(num_inputs)] dtypes = [dtypes[i] for i in range(num_inputs)] @@ -753,9 +756,10 @@ def forward_entry(num_ndarray, ndarraies, tags, reqs, is_train, _): NDArrayHandle), writable=False)) reqs = [req_enum[reqs[i]] for i in range(len(tensors[1]))] - op.forward(is_train=is_train, req=reqs, - in_data=tensors[0], out_data=tensors[1], - aux=tensors[4]) + with ctx: + op.forward(is_train=is_train, req=reqs, + in_data=tensors[0], out_data=tensors[1], + aux=tensors[4]) except Exception: print('Error in CustomOp.forward: %s' % traceback.format_exc()) return False @@ -776,10 +780,11 @@ def backward_entry(num_ndarray, ndarraies, tags, reqs, is_train, _): NDArrayHandle), writable=False)) reqs = [req_enum[reqs[i]] for i in range(len(tensors[2]))] - op.backward(req=reqs, - in_data=tensors[0], out_data=tensors[1], - in_grad=tensors[2], out_grad=tensors[3], - aux=tensors[4]) + with ctx: + op.backward(req=reqs, + in_data=tensors[0], out_data=tensors[1], + in_grad=tensors[2], out_grad=tensors[3], + aux=tensors[4]) except Exception: print('Error in CustomOp.backward: %s' % traceback.format_exc()) return False diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc index 63ed6c482c6a..818f263cb3b7 100644 --- a/src/c_api/c_api_ndarray.cc +++ b/src/c_api/c_api_ndarray.cc @@ -110,9 +110,8 @@ void SetContext(Context* p_ctx, CHECK_EQ(ndinputs[i].ctx().dev_mask(), ctx.dev_mask()) << "All inputs must live on the same context. " << "But the first argument is on " - << (ctx.dev_mask() == gpu::kDevMask ? "GPU" : "CPU") - << " while the " << i+1 << "-th argument is on " - << (ndinputs[i].ctx().dev_mask() == gpu::kDevMask ? 
"GPU" : "CPU"); + << ctx << " while the " << i+1 << "-th argument is on " + << ndinputs[i].ctx(); } } else if (ndoutputs.size() && !ndoutputs[0].is_none()) { ctx = ndoutputs[0].ctx(); diff --git a/src/operator/custom/custom.cc b/src/operator/custom/custom.cc index 1854bb7f05d0..ee420635f824 100644 --- a/src/operator/custom/custom.cc +++ b/src/operator/custom/custom.cc @@ -222,17 +222,13 @@ OpStatePtr CreateState(const NodeAttrs& attrs, Context ctx, } } - std::string str_ctx; - if (ctx.dev_mask() == cpu::kDevMask) { - str_ctx = "cpu"; - } else { - str_ctx = "gpu"; - } + std::ostringstream os; + os << ctx; MXCallbackList *op_info = new MXCallbackList; CHECK(reinterpret_cast( params.info->callbacks[kCustomOpPropCreateOperator])( - str_ctx.c_str(), shapes.size(), shapes.data(), ndims.data(), in_type.data(), + os.str().c_str(), shapes.size(), shapes.data(), ndims.data(), in_type.data(), op_info, params.info->contexts[kCustomOpPropCreateOperator])); CustomParam state = params; diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 88e9d3095e24..10273c0a8c68 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -722,9 +722,11 @@ inline TShape GetSliceShape(const SliceParam& param, const TShape& dshape) { << "Slicing axis exceeds data dimensions"; CHECK_LE(param.end.ndim(), dshape.ndim()) << "Slicing axis exceeds data dimensions"; + CHECK_EQ(param.begin.ndim(), param.end.ndim()) + << "begin and end must have the same length"; - TShape oshape(dshape.ndim()); - for (index_t i = 0; i < dshape.ndim(); ++i) { + TShape oshape = dshape; + for (index_t i = 0; i < param.begin.ndim(); ++i) { int s = 0, e = dshape[i]; if (e != 0) { if (param.begin[i]) { diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index 8c58d3b47a69..8b7f8d6d7bf3 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -249,6 +249,14 @@ def test_ndarray_slice(): A[3:8] = A2[3:8] assert same(A[3:8].asnumpy(), A2[3:8]) + shape = (3,4,5,6,7) + A = mx.nd.random_uniform(shape=shape) + A2 = A.asnumpy() + + assert same(A[1,3:4,:,1:5].asnumpy(), A2[1,3:4,:,1:5]) + + assert A[1,2,3,4,5].asscalar() == A2[1,2,3,4,5] + def test_ndarray_crop(): # get crop @@ -653,6 +661,7 @@ def test_output(): mx.nd.full(shape, 2, out=out) assert_almost_equal(out.asnumpy(), ones.asnumpy() * 2) + if __name__ == '__main__': import nose nose.runmodule() diff --git a/tests/python/unittest/test_nn.py b/tests/python/unittest/test_nn.py index 6dc38b4b0ce9..cc1b2dd48553 100644 --- a/tests/python/unittest/test_nn.py +++ b/tests/python/unittest/test_nn.py @@ -54,7 +54,7 @@ def test_basic(): assert len(y.list_arguments()) == 7 # ndarray - model.collect_params().initialize() + model.collect_params().initialize(mx.init.Xavier(magnitude=2.24)) x = model(mx.nd.zeros((32, 10))) assert x.shape == (32, 32) x.wait_to_read() @@ -95,7 +95,7 @@ def test_conv(): layers3d = [ - nn.Conv3D(16, (1, 8, 4), in_channels=4), + nn.Conv3D(16, (1, 8, 4), in_channels=4, activation='relu'), nn.Conv3D(16, (5, 4, 3), in_channels=4), nn.Conv3D(16, (3, 3, 3), groups=2, in_channels=4), nn.Conv3D(16, 4, strides=4, in_channels=4), @@ -263,6 +263,16 @@ def test_split_data(): assert False, "Should have failed" +def test_flatten(): + flatten = nn.Flatten() + x = mx.nd.zeros((3,4,5,6)) + assert flatten(x).shape == (3, 4*5*6) + x = mx.nd.zeros((3,6)) + assert flatten(x).shape == (3, 6) + x = mx.nd.zeros((3,)) + assert flatten(x).shape == (3, 1) + + if 
__name__ == '__main__': import nose nose.runmodule() diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 14593f6ce5b2..e13c3c07f2fd 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1972,7 +1972,7 @@ def check_instance_norm_with_shape(shape, xpu): exec1 = Y.bind(xpu, args = {'X':x, 'G':gamma, 'B':beta}) exec1.forward(is_train=False) out = exec1.outputs[0].asnumpy() - assert_almost_equal(out, np_out, rtol=1e-4) + assert_almost_equal(out, np_out, rtol=1e-4, atol=1e-5) check_numeric_gradient(Y, {'X':x.asnumpy(), 'G':gamma.asnumpy(), 'B':beta.asnumpy()}, numeric_eps=1e-2, rtol=1e-2, atol=1e-2) @@ -2010,7 +2010,7 @@ def check_l2_normalization(in_shape, mode, ctx=default_context(), norm_eps=1e-10 exe = out.simple_bind(ctx=ctx, data=in_data.shape) output = exe.forward(is_train=True, data=in_data) # compare numpy + mxnet - assert_almost_equal(exe.outputs[0].asnumpy(), np_out, rtol=1e-5) + assert_almost_equal(exe.outputs[0].asnumpy(), np_out, rtol=1e-4, atol=1e-5) # check gradient check_numeric_gradient(out, [in_data], numeric_eps=1e-3, rtol=1e-2, atol=1e-3)
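Since `Trainer` now accepts either a `ParameterDict` (e.g. `net.collect_params()`) or a plain list of `Parameter`s, a typical training step looks like the sketch below; the squared-error loss is only a stand-in to keep the example self-contained:

import mxnet as mx
from mxnet import autograd, gluon
from mxnet.gluon import nn

net = nn.Sequential()
net.add(nn.Dense(10, activation='relu'))
net.add(nn.Dense(2))
net.initialize(mx.init.Xavier(), ctx=mx.cpu(0))

# A list such as list(net.collect_params().values()) would also be accepted.
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})

data = mx.nd.random_uniform(shape=(8, 20))
target = mx.nd.zeros((8, 2))

with autograd.record():
    out = net(data)
    loss = ((out - target) ** 2).mean()
loss.backward()
trainer.step(batch_size=8)   # gradients are rescaled by 1/batch_size before the update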