fix #3542: move layout_mapper argument out of existing API (#3565)
pluskid authored Oct 19, 2016
1 parent f128e4e commit fbef78d
Showing 5 changed files with 15 additions and 39 deletions.
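
In short: `layout_mapper` is no longer a keyword argument of `bind()`; it becomes an attribute on the module itself, initialized to `None` in `BaseModule.__init__` and read wherever the old argument used to flow. A minimal before/after sketch (the toy symbol and shapes are illustrative, not part of this commit):

    import mxnet as mx

    # a toy network, just to have something to bind
    data = mx.sym.Variable('data')
    net = mx.sym.SoftmaxOutput(mx.sym.FullyConnected(data, num_hidden=10), name='softmax')
    mod = mx.mod.Module(net)

    mapper = None  # stand-in for a real LayoutMapper helper

    # old API, removed by this commit:
    #   mod.bind(data_shapes=[('data', (32, 100))], layout_mapper=mapper)

    # new API: attach the helper as an attribute, then bind as usual
    mod.layout_mapper = mapper
    mod.bind(data_shapes=[('data', (32, 100))],
             label_shapes=[('softmax_label', (32,))])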
12 changes: 5 additions & 7 deletions python/mxnet/module/base_module.py
@@ -117,6 +117,7 @@ def __init__(self, logger=logging):
         self.params_initialized = False
         self.optimizer_initialized = False
         self._symbol = None
+        self.layout_mapper = None
 
     ################################################################################
     # High Level API
@@ -329,9 +330,10 @@ def fit(self, train_data, eval_data=None, eval_metric='acc',
         """
         assert num_epoch is not None, 'please specify number of epochs'
 
-        layout_mapper = train_data.layout_mapper if hasattr(train_data, 'layout_mapper') else None
+        if hasattr(train_data, 'layout_mapper'):
+            self.layout_mapper = train_data.layout_mapper
         self.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label,
-                  for_training=True, force_rebind=force_rebind, layout_mapper=layout_mapper)
+                  for_training=True, force_rebind=force_rebind)
         if monitor is not None:
             self.install_monitor(monitor)
         self.init_params(initializer=initializer, arg_params=arg_params, aux_params=aux_params,
@@ -604,8 +606,7 @@ def update_metric(self, eval_metric, labels):
     # module setup
     ################################################################################
     def bind(self, data_shapes, label_shapes=None, for_training=True,
-             inputs_need_grad=False, force_rebind=False, shared_module=None,
-             layout_mapper=None):
+             inputs_need_grad=False, force_rebind=False, shared_module=None):
        """Bind the symbols to construct executors. This is necessary before one
         can perform computation with the module.
@@ -628,9 +629,6 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
             Default is `None`. This is used in bucketing. When not `None`, the shared module
             essentially corresponds to a different bucket -- a module with different symbol
             but with the same sets of parameters (e.g. unrolled RNNs with different lengths).
-        layout_mapper: LayoutMapper
-            Default None. A helper that decide the layout of data, label and outputs
-            (time-major? batch-major?).
         """
         raise NotImplementedError()
 
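`fit()` now picks the helper up from the training iterator by duck typing: any iterator exposing a `layout_mapper` attribute has it copied onto the module before `bind()` is called. A sketch under that assumption (`LayoutAwareIter` is a hypothetical subclass, not part of MXNet):

    import numpy as np
    import mxnet as mx

    class LayoutAwareIter(mx.io.NDArrayIter):
        """Illustrative iterator advertising its own layout helper."""
        def __init__(self, layout_mapper, *args, **kwargs):
            super(LayoutAwareIter, self).__init__(*args, **kwargs)
            self.layout_mapper = layout_mapper

    train_iter = LayoutAwareIter(None,  # stand-in for a real LayoutMapper
                                 data=np.zeros((32, 100)),
                                 label=np.zeros((32,)),
                                 batch_size=32)
    # mod.fit(train_iter, num_epoch=1) would copy train_iter.layout_mapper
    # onto mod.layout_mapper before binding.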
16 changes: 5 additions & 11 deletions python/mxnet/module/bucketing_module.py
@@ -42,7 +42,6 @@ def __init__(self, sym_gen, default_bucket_key=None,
 
         self._buckets = {}
         self._curr_module = None
-        self._layout_mapper = None
 
     def _reset_bind(self):
         """Internal utility function to reset binding."""
@@ -136,8 +135,7 @@ def init_params(self, initializer=Uniform(0.01), arg_params=None, aux_params=None
         self.params_initialized = True
 
     def bind(self, data_shapes, label_shapes=None, for_training=True,
-             inputs_need_grad=False, force_rebind=False, shared_module=None,
-             layout_mapper=None):
+             inputs_need_grad=False, force_rebind=False, shared_module=None):
         """Binding for a `BucketingModule` means setting up the buckets and bind the
         executor for the default bucket key. Executors corresponding to other keys are
         binded afterwards with `switch_bucket`.
@@ -156,9 +154,6 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
             Default is `False`.
         shared_module : BucketingModule
             Default is `None`. This value is currently not used.
-        layout_mapper: LayoutMapper
-            Default None. A helper that decide the layout of data, label and outputs
-            (time-major? batch-major?).
         """
         # in case we already initialized params, keep it
         if self.params_initialized:
@@ -182,17 +177,16 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
         symbol, data_names, label_names = self._sym_gen(self._default_bucket_key)
         module = Module(symbol, data_names, label_names, logger=self.logger,
                         context=self._context, work_load_list=self._work_load_list)
+        module.layout_mapper = self.layout_mapper
         module.bind(data_shapes, label_shapes, for_training, inputs_need_grad,
-                    force_rebind=False, shared_module=None, layout_mapper=layout_mapper)
+                    force_rebind=False, shared_module=None)
         self._curr_module = module
         self._buckets[self._default_bucket_key] = module
 
         # copy back saved params, if already initialized
         if self.params_initialized:
             self.set_params(arg_params, aux_params)
 
-        self._layout_mapper = layout_mapper
-
     def switch_bucket(self, bucket_key, data_shapes, label_shapes=None):
         """Switch to a different bucket. This will change `self.curr_module`.
@@ -211,10 +205,10 @@ def switch_bucket(self, bucket_key, data_shapes, label_shapes=None):
             module = Module(symbol, data_names, label_names,
                             logger=self.logger, context=self._context,
                             work_load_list=self._work_load_list)
+            module.layout_mapper = self.layout_mapper
             module.bind(data_shapes, label_shapes, self._curr_module.for_training,
                         self._curr_module.inputs_need_grad,
-                        force_rebind=False, shared_module=self._curr_module,
-                        layout_mapper=self._layout_mapper)
+                        force_rebind=False, shared_module=self._curr_module)
             self._buckets[bucket_key] = module
 
         self._curr_module = self._buckets[bucket_key]
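
`BucketingModule` now copies the attribute onto every per-bucket `Module` (the new `module.layout_mapper = self.layout_mapper` lines in both `bind` and `switch_bucket`), so setting it once on the parent module is enough. A hedged sketch (the `sym_gen` body is illustrative):

    import mxnet as mx

    def sym_gen(bucket_key):
        data = mx.sym.Variable('data')
        net = mx.sym.SoftmaxOutput(mx.sym.FullyConnected(data, num_hidden=10),
                                   name='softmax')
        return net, ('data',), ('softmax_label',)

    bucket_mod = mx.mod.BucketingModule(sym_gen, default_bucket_key=10)
    bucket_mod.layout_mapper = None  # stand-in; inherited by every bucket's Module
    bucket_mod.bind(data_shapes=[('data', (32, 10))],
                    label_shapes=[('softmax_label', (32,))])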
8 changes: 2 additions & 6 deletions python/mxnet/module/module.py
@@ -199,8 +199,7 @@ def _impl(name, arr, cache):
         self._exec_group.set_params(self._arg_params, self._aux_params)
 
     def bind(self, data_shapes, label_shapes=None, for_training=True,
-             inputs_need_grad=False, force_rebind=False, shared_module=None,
-             layout_mapper=None):
+             inputs_need_grad=False, force_rebind=False, shared_module=None):
         """Bind the symbols to construct executors. This is necessary before one
         can perform computation with the module.
@@ -223,9 +222,6 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
             Default is `None`. This is used in bucketing. When not `None`, the shared module
             essentially corresponds to a different bucket -- a module with different symbol
             but with the same sets of parameters (e.g. unrolled RNNs with different lengths).
-        layout_mapper: LayoutMapper
-            Default None. A helper that decide the layout of data, label and outputs
-            (time-major? batch-major?).
         """
         # force rebinding is typically used when one want to switch from
         # training to prediction phase.
@@ -264,7 +260,7 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
                                                      for_training, inputs_need_grad,
                                                      shared_group, logger=self.logger,
                                                      fixed_param_names=self._fixed_param_names,
-                                                     layout_mapper=layout_mapper)
+                                                     layout_mapper=self.layout_mapper)
         if shared_module is not None:
             self.params_initialized = True
             self._arg_params = shared_module._arg_params
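
In `Module.bind` the helper is still forwarded to the executor group (`DataParallelExecutorGroup`); only its source changed, from the removed parameter to `self.layout_mapper`. The deleted docstring describes a `LayoutMapper` as a helper that decides whether data, labels, and outputs are time-major or batch-major; a hypothetical minimal shape is sketched below (the class and method names are assumptions, not defined by this commit):

    class TimeMajorLayoutMapper(object):
        """Hypothetical helper in the spirit of the removed docstring:
        report which axis of a named array is the batch axis."""
        def get_batch_axis(self, name):
            # time-major layout: axis 0 is time, axis 1 is batch
            return 1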
10 changes: 1 addition & 9 deletions python/mxnet/module/python_module.py
@@ -38,8 +38,6 @@ def __init__(self, data_names, label_names, output_names, logger=logging):
         self._label_shapes = None
         self._output_shapes = None
 
-        self.layout_mapper = None
-
     ################################################################################
     # Symbol information
     ################################################################################
@@ -141,8 +139,7 @@ def update_metric(self, eval_metric, labels):
     # module setup
     ################################################################################
     def bind(self, data_shapes, label_shapes=None, for_training=True,
-             inputs_need_grad=False, force_rebind=False, shared_module=None,
-             layout_mapper=None):
+             inputs_need_grad=False, force_rebind=False, shared_module=None):
         """Bind the symbols to construct executors. This is necessary before one
         can perform computation with the module.
@@ -165,16 +162,11 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
             Default is `None`. This is used in bucketing. When not `None`, the shared module
             essentially corresponds to a different bucket -- a module with different symbol
             but with the same sets of parameters (e.g. unrolled RNNs with different lengths).
-        layout_mapper: LayoutMapper
-            Default None. A helper that decide the layout of data, label and outputs
-            (time-major? batch-major?).
         """
         if self.binded and not force_rebind:
             self.logger.warning('Already binded, ignoring bind()')
             return
 
-        self.layout_mapper = layout_mapper
-
         self.for_training = for_training
         self.inputs_need_grad = inputs_need_grad
 
8 changes: 2 additions & 6 deletions python/mxnet/module/sequential_module.py
@@ -181,8 +181,7 @@ def _check_name(known_names, new_names, modules, i):
         self.params_initialized = True
 
     def bind(self, data_shapes, label_shapes=None, for_training=True,
-             inputs_need_grad=False, force_rebind=False, shared_module=None,
-             layout_mapper=None):
+             inputs_need_grad=False, force_rebind=False, shared_module=None):
         """Bind the symbols to construct executors. This is necessary before one
         can perform computation with the module.
@@ -203,9 +202,6 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
             binded. But with this `True`, the executors will be forced to rebind.
         shared_module : Module
             Default is `None`. Currently shared module is not supported for `SequentialModule`.
-        layout_mapper: LayoutMapper
-            Default None. A helper that decide the layout of data, label and outputs
-            (time-major? batch-major?).
         """
         if self.binded and not force_rebind:
             self.logger.warning('Already binded, ignoring bind()')
@@ -243,7 +239,7 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
 
             module.bind(data_shapes=my_data_shapes, label_shapes=my_label_shapes,
                         for_training=for_training, inputs_need_grad=my_inputs_need_grad,
-                        force_rebind=force_rebind, shared_module=None, layout_mapper=layout_mapper)
+                        force_rebind=force_rebind, shared_module=None)
 
             # the output of the previous module is the data of the next module
             my_data_shapes = module.output_shapes
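
For `SequentialModule` the chained `bind` simply stops forwarding the argument; each child `Module` relies on its own `layout_mapper` attribute if one is needed. A usage sketch (the two sub-networks are illustrative):

    import mxnet as mx

    net1 = mx.sym.FullyConnected(mx.sym.Variable('data'), num_hidden=64)
    mod1 = mx.mod.Module(net1, label_names=[])
    net2 = mx.sym.SoftmaxOutput(mx.sym.Variable('data'), name='softmax')
    mod2 = mx.mod.Module(net2)

    seq = mx.mod.SequentialModule()
    seq.add(mod1).add(mod2, take_labels=True, auto_wiring=True)
    # no layout_mapper argument anymore; set it on the children if needed
    seq.bind(data_shapes=[('data', (32, 100))],
             label_shapes=[('softmax_label', (32,))])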
