Some fixes for gluon (apache#7013)
* fix

* fix

* fix

* fix

* more fixes & add ndarray slicing

* fix

* fix

* fix

* fix

* fix
piiswrong authored Jul 13, 2017
1 parent 19d2a3d commit 255a794
Showing 17 changed files with 374 additions and 156 deletions.
2 changes: 1 addition & 1 deletion docs/api/python/gluon.md
@@ -63,7 +63,7 @@ in Python and then deploy with symbolic graph in C++ and Scala.
.. automethod:: __call__
.. autoclass:: mxnet.gluon.nn.Sequential
:members:
-.. autoclass:: mxnet.gluon.nn.HSequential
+.. autoclass:: mxnet.gluon.nn.HybridSequential
:members:
```

8 changes: 4 additions & 4 deletions example/gluon/resnet.py
@@ -134,7 +134,7 @@ def __init__(self, block, classes, layers, filters, thumbnail=False, **kwargs):
self.bn0 = nn.BatchNorm(in_channels=filters[0])
self.pool0 = nn.MaxPool2D(3, 2, 1)

-self.body = nn.HSequential()
+self.body = nn.HybridSequential()
in_channels = filters[0]
for i in range(len(layers)):
stride = 1 if i == 0 else 2
@@ -146,7 +146,7 @@ def __init__(self, block, classes, layers, filters, thumbnail=False, **kwargs):
self.dense1 = nn.Dense(classes, in_units=filters[-1])

def _make_layer(self, block, layers, filters, stride, in_channels=0):
-layer = nn.HSequential()
+layer = nn.HybridSequential()
layer.add(block(filters, stride, True, in_channels=in_channels))
for i in range(layers-1):
layer.add(block(filters, 1, False, in_channels=filters))
@@ -248,7 +248,7 @@ def __init__(self, block, classes, layers, filters, thumbnail=False, **kwargs):
self.bn0 = nn.BatchNorm(in_channels=filters[0])
self.pool0 = nn.MaxPool2D(3, 2, 1)

-self.body = nn.HSequential()
+self.body = nn.HybridSequential()
in_channels = filters[0]
for i in range(len(layers)):
stride = 1 if i == 0 else 2
@@ -261,7 +261,7 @@ def __init__(self, block, classes, layers, filters, thumbnail=False, **kwargs):
self.dense1 = nn.Dense(classes, in_units=in_channels)

def _make_layer(self, block, layers, filters, stride, in_channels=0):
-layer = nn.HSequential()
+layer = nn.HybridSequential()
layer.add(block(filters, stride, True, in_channels=in_channels))
for i in range(layers-1):
layer.add(block(filters, 1, False, in_channels=filters))
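For reference, the `_make_layer` pattern above boils down to stacking blocks in the renamed `nn.HybridSequential` container. Below is a minimal standalone sketch; the free-function form, the argument names, and the `block` placeholder for the ResNet block classes are illustrative, not part of the diff.

```python
from mxnet.gluon import nn

def make_layer(block, num_blocks, channels, stride, in_channels=0):
    # Stack `num_blocks` residual blocks into one stage; only the first
    # block downsamples (stride) and changes the channel count.
    layer = nn.HybridSequential()
    layer.add(block(channels, stride, True, in_channels=in_channels))
    for _ in range(num_blocks - 1):
        layer.add(block(channels, 1, False, in_channels=channels))
    return layer
```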
20 changes: 16 additions & 4 deletions python/mxnet/gluon/block.py
@@ -2,7 +2,7 @@
# pylint: disable= arguments-differ
"""Base container class for all neural network models."""

-from .. import symbol, ndarray
+from .. import symbol, ndarray, initializer
from ..symbol import Symbol
from ..ndarray import NDArray
from .. import name as _name
@@ -99,7 +99,7 @@ class Block(object):
class Model(Block):
def __init__(self, **kwargs):
-super(Net, self).__init__(**kwargs)
+super(Model, self).__init__(**kwargs)
# use name_scope to give child Blocks appropriate names.
# It also allows sharing Parameters between Blocks recursively.
with self.name_scope():
@@ -110,6 +110,11 @@ def forward(self, x):
x = F.relu(self.dense0(x))
return F.relu(self.dense1(x))
model = Model()
+model.initialize(ctx=mx.cpu(0))
+model(F.zeros((10, 10), ctx=mx.cpu(0)))
Child `Block`s assigned this way will be registered and `collect_params`
will collect their Parameters recursively.
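For context, here is a self-contained version of the docstring example above. It is a sketch against the 2017-era Gluon API; the ReLU is folded into the layers via `Dense`'s `activation` argument rather than the docstring's bare `F.relu` call.

```python
import mxnet as mx
from mxnet import nd
from mxnet.gluon import Block, nn

class Model(Block):
    def __init__(self, **kwargs):
        super(Model, self).__init__(**kwargs)
        # name_scope gives child Blocks appropriate prefixes and allows
        # sharing Parameters between Blocks recursively.
        with self.name_scope():
            self.dense0 = nn.Dense(20, activation='relu')
            self.dense1 = nn.Dense(20, activation='relu')

    def forward(self, x):
        return self.dense1(self.dense0(x))

model = Model()
model.initialize(ctx=mx.cpu(0))
out = model(nd.zeros((10, 10), ctx=mx.cpu(0)))
```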
@@ -124,7 +129,7 @@ def forward(self, x):
if you want `dense1` to share `dense0`'s weights, you can do::
dense0 = nn.Dense(20)
-dense1 = nn.Dense(20, params=dense1.collect_params())
+dense1 = nn.Dense(20, params=dense0.collect_params())
"""
def __init__(self, prefix=None, params=None):
self._prefix, self._params = _BlockScope.create(prefix, params, self._alias())
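Spelled out, the corrected sharing idiom from the docstring looks like the runnable sketch below; it adds nothing beyond what the fixed line states.

```python
from mxnet.gluon import nn

dense0 = nn.Dense(20)
# dense1 creates no parameters of its own; it reuses dense0's weight and bias.
dense1 = nn.Dense(20, params=dense0.collect_params())
```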
@@ -181,6 +186,13 @@ def register_child(self, block):
attributes will be registered automatically."""
self._children.append(block)

+    def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False):
+        """Initialize `Parameter`s of this Block and its children.
+        Equivalent to `block.collect_params().initialize(...)`
+        """
+        self.collect_params().initialize(init, ctx, verbose)

def hybridize(self, active=True):
"""Activates or deactivates `HybridBlock`s recursively. Has no effect on
non-hybrid children.
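Per the new method's docstring, the convenience call and the explicit `ParameterDict` call are interchangeable. A minimal sketch; the `Dense(10, in_units=5)` layer and the `Uniform` initializer are illustrative:

```python
import mxnet as mx
from mxnet.gluon import nn

net = nn.Dense(10, in_units=5)

# New shorthand added in this commit ...
net.initialize(init=mx.init.Uniform(), ctx=mx.cpu(0))
# ... equivalent to the pre-existing ParameterDict call:
# net.collect_params().initialize(init=mx.init.Uniform(), ctx=mx.cpu(0))
```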
@@ -250,7 +262,7 @@ def register_child(self, block):
if isinstance(block, Sequential):
raise ValueError(
"Children of HybridBlock must also be HybridBlock. " \
"Please use HSequential instead of Sequential.")
"Please use HybridSequential instead of Sequential.")
raise ValueError(
"Children of HybridBlock must also be HybridBlock, " \
"but %s has type %s."%(str(block), str(type(block))))
58 changes: 48 additions & 10 deletions python/mxnet/gluon/nn/basic_layers.py
@@ -29,7 +29,7 @@ def forward(self, x):
return x


-class HSequential(HybridBlock):
+class HybridSequential(HybridBlock):
"""Stack `HybridBlock`s sequentially.
Example::
@@ -41,7 +41,7 @@ class HSequential(HybridBlock):
net.add(Dense(20))
"""
def __init__(self, prefix=None, params=None):
-super(HSequential, self).__init__(prefix=prefix, params=params)
+super(HybridSequential, self).__init__(prefix=prefix, params=params)

def add(self, block):
"""Add block on top of the stack."""
@@ -97,16 +97,18 @@ class Dense(HybridBlock):
the output would have shape `(batch_size, units)`.
"""
def __init__(self, units, activation=None, use_bias=True,
-weight_initializer=None, bias_initializer=None,
+weight_initializer=None, bias_initializer='zeros',
in_units=0, **kwargs):
super(Dense, self).__init__(**kwargs)
with self.name_scope():
self._units = units
self.weight = self.params.get('weight', shape=(units, in_units),
-init=weight_initializer)
+init=weight_initializer,
+allow_deferred_init=True)
if use_bias:
self.bias = self.params.get('bias', shape=(units,),
-init=bias_initializer)
+init=bias_initializer,
+allow_deferred_init=True)
else:
self.bias = None
if activation is not None:
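The practical effect of `allow_deferred_init=True` here, sketched under the assumption that deferred initialization completes on the first forward call; the shapes below are illustrative.

```python
import mxnet as mx
from mxnet import nd
from mxnet.gluon import nn

layer = nn.Dense(64, activation='relu')   # in_units omitted: weight shape unknown
layer.initialize(ctx=mx.cpu(0))           # deferred; nothing is allocated yet
out = layer(nd.zeros((8, 32)))            # first forward infers weight shape (64, 32)
print(layer.weight.shape)
```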
@@ -133,6 +135,7 @@ class Activation(HybridBlock):
name of activation function to use.
See :func:`~mxnet.ndarray.Activation` for available choices.
Input shape:
Arbitrary.
@@ -210,6 +213,13 @@ class BatchNorm(HybridBlock):
Number of channels (feature maps) in input data. If not specified,
initialization will be defered to the first time `forward` is called
and `in_channels` will be inferred from the shape of input data.
+Input shape:
+Arbitrary.
+Output shape:
+Same shape as input.
"""
def __init__(self, axis=1, momentum=0.9, epsilon=1e-3, center=True, scale=True,
beta_initializer='zeros', gamma_initializer='ones',
@@ -220,15 +230,19 @@ def __init__(self, axis=1, momentum=0.9, epsilon=1e-3, center=True, scale=True,
'fix_gamma': not center}

self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
-shape=(in_channels,), init=gamma_initializer)
+shape=(in_channels,), init=gamma_initializer,
+allow_deferred_init=True)
self.beta = self.params.get('beta', grad_req='write' if center else 'null',
-shape=(in_channels,), init=beta_initializer)
+shape=(in_channels,), init=beta_initializer,
+allow_deferred_init=True)
self.running_mean = self.params.get('running_mean', grad_req='null',
shape=(in_channels,),
-init=running_mean_initializer)
+init=running_mean_initializer,
+allow_deferred_init=True)
self.running_var = self.params.get('running_var', grad_req='null',
shape=(in_channels,),
-init=running_variance_initializer)
+init=running_variance_initializer,
+allow_deferred_init=True)

def hybrid_forward(self, F, x, gamma, beta, running_mean, running_var):
return F.BatchNorm(x, gamma, beta, running_mean, running_var, **self._kwargs)
@@ -246,6 +260,13 @@ class LeakyReLU(HybridBlock):
----------
alpha : float
slope coefficient for the negative half axis. Must be >= 0.
+Input shape:
+Arbitrary.
+Output shape:
+Same shape as input.
"""
def __init__(self, alpha, **kwargs):
super(LeakyReLU, self).__init__(**kwargs)
@@ -284,7 +305,24 @@ def __init__(self, input_dim, output_dim, dtype='float32',
self._kwargs = {'input_dim': input_dim, 'output_dim': output_dim,
'dtype': dtype}
self.weight = self.params.get('weight', shape=(input_dim, output_dim),
-init=weight_initializer)
+init=weight_initializer,
+allow_deferred_init=True)

def hybrid_forward(self, F, x, weight):
return F.Embedding(x, weight, **self._kwargs)


+class Flatten(HybridBlock):
+    """Flattens the input to two dimensional.
+    Input shape:
+        Arbitrary shape `(N, a, b, c, ...)`
+    Output shape:
+        2D tensor with shape: `(N, a*b*c...)`
+    """
+    def __init__(self, **kwargs):
+        super(Flatten, self).__init__(**kwargs)
+
+    def hybrid_forward(self, F, x):
+        return x.reshape((0, -1))
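A tiny usage sketch of the new `Flatten` block, assuming it is exported as `mxnet.gluon.nn.Flatten`:

```python
from mxnet import nd
from mxnet.gluon import nn

flatten = nn.Flatten()            # parameter-free, so no initialize() is needed
x = nd.zeros((2, 3, 4, 5))
print(flatten(x).shape)           # (2, 60): batch axis kept, remaining dims collapsed
```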
