diff --git a/python/mxnet/module/base_module.py b/python/mxnet/module/base_module.py
index d6f7ce92..704000d6 100644
--- a/python/mxnet/module/base_module.py
+++ b/python/mxnet/module/base_module.py
@@ -117,6 +117,7 @@ def __init__(self, logger=logging):
         self.params_initialized = False
         self.optimizer_initialized = False
         self._symbol = None
+        self.layout_mapper = None
 
     ################################################################################
     # High Level API
@@ -329,9 +330,10 @@ def fit(self, train_data, eval_data=None, eval_metric='acc',
         """
         assert num_epoch is not None, 'please specify number of epochs'
 
-        layout_mapper = train_data.layout_mapper if hasattr(train_data, 'layout_mapper') else None
+        if hasattr(train_data, 'layout_mapper'):
+            self.layout_mapper = train_data.layout_mapper
         self.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label,
-                  for_training=True, force_rebind=force_rebind, layout_mapper=layout_mapper)
+                  for_training=True, force_rebind=force_rebind)
         if monitor is not None:
             self.install_monitor(monitor)
         self.init_params(initializer=initializer, arg_params=arg_params, aux_params=aux_params,
@@ -604,8 +606,7 @@ def update_metric(self, eval_metric, labels):
     # module setup
     ################################################################################
     def bind(self, data_shapes, label_shapes=None, for_training=True,
-             inputs_need_grad=False, force_rebind=False, shared_module=None,
-             layout_mapper=None):
+             inputs_need_grad=False, force_rebind=False, shared_module=None):
         """Bind the symbols to construct executors. This is necessary before one
         can perform computation with the module.
 
@@ -628,9 +629,6 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
             Default is `None`. This is used in bucketing. When not `None`, the shared module
             essentially corresponds to a different bucket -- a module with different symbol
             but with the same sets of parameters (e.g. unrolled RNNs with different lengths).
-        layout_mapper: LayoutMapper
-            Default None. A helper that decide the layout of data, label and outputs
-            (time-major? batch-major?).
         """
         raise NotImplementedError()
 
diff --git a/python/mxnet/module/bucketing_module.py b/python/mxnet/module/bucketing_module.py
index 4a8986f3..8f204147 100644
--- a/python/mxnet/module/bucketing_module.py
+++ b/python/mxnet/module/bucketing_module.py
@@ -42,7 +42,6 @@ def __init__(self, sym_gen, default_bucket_key=None,
 
         self._buckets = {}
         self._curr_module = None
-        self._layout_mapper = None
 
     def _reset_bind(self):
         """Internal utility function to reset binding."""
@@ -136,8 +135,7 @@ def init_params(self, initializer=Uniform(0.01), arg_params=None, aux_params=Non
         self.params_initialized = True
 
     def bind(self, data_shapes, label_shapes=None, for_training=True,
-             inputs_need_grad=False, force_rebind=False, shared_module=None,
-             layout_mapper=None):
+             inputs_need_grad=False, force_rebind=False, shared_module=None):
         """Binding for a `BucketingModule` means setting up the buckets and bind the
         executor for the default bucket key. Executors corresponding to other keys are
         binded afterwards with `switch_bucket`.
@@ -156,9 +154,6 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
             Default is `False`.
         shared_module : BucketingModule
             Default is `None`. This value is currently not used.
-        layout_mapper: LayoutMapper
-            Default None. A helper that decide the layout of data, label and outputs
-            (time-major? batch-major?).
         """
         # in case we already initialized params, keep it
         if self.params_initialized:
@@ -182,8 +177,9 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
         symbol, data_names, label_names = self._sym_gen(self._default_bucket_key)
         module = Module(symbol, data_names, label_names, logger=self.logger,
                         context=self._context, work_load_list=self._work_load_list)
+        module.layout_mapper = self.layout_mapper
         module.bind(data_shapes, label_shapes, for_training, inputs_need_grad,
-                    force_rebind=False, shared_module=None, layout_mapper=layout_mapper)
+                    force_rebind=False, shared_module=None)
         self._curr_module = module
         self._buckets[self._default_bucket_key] = module
 
@@ -191,8 +187,6 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
         if self.params_initialized:
             self.set_params(arg_params, aux_params)
 
-        self._layout_mapper = layout_mapper
-
     def switch_bucket(self, bucket_key, data_shapes, label_shapes=None):
         """Switch to a different bucket. This will change `self.curr_module`.
 
@@ -211,10 +205,10 @@ def switch_bucket(self, bucket_key, data_shapes, label_shapes=None):
             module = Module(symbol, data_names, label_names, logger=self.logger,
                             context=self._context,
                             work_load_list=self._work_load_list)
+            module.layout_mapper = self.layout_mapper
             module.bind(data_shapes, label_shapes, self._curr_module.for_training,
                         self._curr_module.inputs_need_grad,
-                        force_rebind=False, shared_module=self._curr_module,
-                        layout_mapper=self._layout_mapper)
+                        force_rebind=False, shared_module=self._curr_module)
             self._buckets[bucket_key] = module
 
         self._curr_module = self._buckets[bucket_key]
diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py
index bc053754..57a7ab64 100644
--- a/python/mxnet/module/module.py
+++ b/python/mxnet/module/module.py
@@ -199,8 +199,7 @@ def _impl(name, arr, cache):
         self._exec_group.set_params(self._arg_params, self._aux_params)
 
     def bind(self, data_shapes, label_shapes=None, for_training=True,
-             inputs_need_grad=False, force_rebind=False, shared_module=None,
-             layout_mapper=None):
+             inputs_need_grad=False, force_rebind=False, shared_module=None):
         """Bind the symbols to construct executors. This is necessary before one
         can perform computation with the module.
 
@@ -223,9 +222,6 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
             Default is `None`. This is used in bucketing. When not `None`, the shared module
             essentially corresponds to a different bucket -- a module with different symbol
             but with the same sets of parameters (e.g. unrolled RNNs with different lengths).
-        layout_mapper: LayoutMapper
-            Default None. A helper that decide the layout of data, label and outputs
-            (time-major? batch-major?).
         """
         # force rebinding is typically used when one want to switch from
         # training to prediction phase.
@@ -264,7 +260,7 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
                                                      for_training, inputs_need_grad,
                                                      shared_group, logger=self.logger,
                                                      fixed_param_names=self._fixed_param_names,
-                                                     layout_mapper=layout_mapper)
+                                                     layout_mapper=self.layout_mapper)
         if shared_module is not None:
             self.params_initialized = True
             self._arg_params = shared_module._arg_params
diff --git a/python/mxnet/module/python_module.py b/python/mxnet/module/python_module.py
index f3909c7e..141849f1 100644
--- a/python/mxnet/module/python_module.py
+++ b/python/mxnet/module/python_module.py
@@ -38,8 +38,6 @@ def __init__(self, data_names, label_names, output_names, logger=logging):
         self._label_shapes = None
         self._output_shapes = None
 
-        self.layout_mapper = None
-
     ################################################################################
     # Symbol information
     ################################################################################
@@ -141,8 +139,7 @@ def update_metric(self, eval_metric, labels):
     # module setup
     ################################################################################
     def bind(self, data_shapes, label_shapes=None, for_training=True,
-             inputs_need_grad=False, force_rebind=False, shared_module=None,
-             layout_mapper=None):
+             inputs_need_grad=False, force_rebind=False, shared_module=None):
         """Bind the symbols to construct executors. This is necessary before one
         can perform computation with the module.
 
@@ -165,16 +162,11 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
             Default is `None`. This is used in bucketing. When not `None`, the shared module
             essentially corresponds to a different bucket -- a module with different symbol
             but with the same sets of parameters (e.g. unrolled RNNs with different lengths).
-        layout_mapper: LayoutMapper
-            Default None. A helper that decide the layout of data, label and outputs
-            (time-major? batch-major?).
         """
         if self.binded and not force_rebind:
             self.logger.warning('Already binded, ignoring bind()')
             return
 
-        self.layout_mapper = layout_mapper
-
         self.for_training = for_training
         self.inputs_need_grad = inputs_need_grad
 
diff --git a/python/mxnet/module/sequential_module.py b/python/mxnet/module/sequential_module.py
index 7acbe9d5..3e9ac3d4 100644
--- a/python/mxnet/module/sequential_module.py
+++ b/python/mxnet/module/sequential_module.py
@@ -181,8 +181,7 @@ def _check_name(known_names, new_names, modules, i):
         self.params_initialized = True
 
     def bind(self, data_shapes, label_shapes=None, for_training=True,
-             inputs_need_grad=False, force_rebind=False, shared_module=None,
-             layout_mapper=None):
+             inputs_need_grad=False, force_rebind=False, shared_module=None):
         """Bind the symbols to construct executors. This is necessary before one
         can perform computation with the module.
 
@@ -203,9 +202,6 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
             binded. But with this `True`, the executors will be forced to rebind.
         shared_module : Module
             Default is `None`. Currently shared module is not supported for `SequentialModule`.
-        layout_mapper: LayoutMapper
-            Default None. A helper that decide the layout of data, label and outputs
-            (time-major? batch-major?).
         """
         if self.binded and not force_rebind:
             self.logger.warning('Already binded, ignoring bind()')
@@ -243,7 +239,7 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
             module.bind(data_shapes=my_data_shapes, label_shapes=my_label_shapes,
                         for_training=for_training,
                         inputs_need_grad=my_inputs_need_grad,
-                        force_rebind=force_rebind, shared_module=None, layout_mapper=layout_mapper)
+                        force_rebind=force_rebind, shared_module=None)
 
             # the output of the previous module is the data of the next module
             my_data_shapes = module.output_shapes