From 94a2c60e079505164a8928a5803ebe7106920e9f Mon Sep 17 00:00:00 2001 From: reminisce Date: Tue, 29 Aug 2017 10:34:56 -0700 Subject: [PATCH] Refactor random linalg contrib namespaces (#7604) * Refactor namespaces contrib, linalg, random, and sparse for op registration Change examples in documentation Change namespace usage in examples Fix pylint Remove unused import Switch name and alias in linalg and random Change stype comparison from string to int for functions used internally Change documentation to use the right namespace Register ops under ndarray/op.py and symbol/op.py Remove unused import Change .cu op names * Add __all__ to ndarray and symbol modules * Revert "Add __all__ to ndarray and symbol modules" This reverts commit 8bc5de77bfdb40ff48dc570e2c6c49ec5d43ea64. * Add __all__ to ndarray and symbol modules --- benchmark/python/sparse/dot.py | 2 +- benchmark/python/sparse/sparse_op.py | 6 +- example/ctc/lstm.py | 2 +- example/gluon/dcgan.py | 2 +- example/gluon/lstm_crf.py | 6 +- example/rcnn/rcnn/symbol/symbol_resnet.py | 4 +- example/rcnn/rcnn/symbol/symbol_vgg.py | 6 +- example/ssd/symbol/common.py | 5 +- example/ssd/symbol/legacy_vgg16_ssd_300.py | 6 +- example/ssd/symbol/legacy_vgg16_ssd_512.py | 6 +- example/ssd/symbol/symbol_builder.py | 6 +- .../tools/caffe_converter/convert_symbol.py | 4 +- python/mxnet/_ctypes/ndarray.py | 23 +- python/mxnet/autograd.py | 3 +- python/mxnet/base.py | 83 ++++++++ python/mxnet/contrib/autograd.py | 3 +- python/mxnet/contrib/ndarray.py | 1 + python/mxnet/contrib/symbol.py | 1 + python/mxnet/gluon/block.py | 2 +- python/mxnet/gluon/rnn/rnn_layer.py | 14 +- python/mxnet/ndarray/__init__.py | 8 +- python/mxnet/ndarray/_internal.py | 1 + python/mxnet/ndarray/contrib.py | 19 ++ python/mxnet/ndarray/linalg.py | 19 ++ python/mxnet/ndarray/ndarray.py | 201 +++++++++--------- python/mxnet/ndarray/op.py | 49 +---- python/mxnet/ndarray/random.py | 19 ++ python/mxnet/ndarray/sparse.py | 40 ++-- python/mxnet/random.py | 8 +- python/mxnet/symbol/__init__.py | 7 +- python/mxnet/symbol/_internal.py | 1 + python/mxnet/symbol/contrib.py | 19 ++ python/mxnet/symbol/linalg.py | 19 ++ python/mxnet/symbol/op.py | 45 +--- python/mxnet/symbol/random.py | 19 ++ python/mxnet/symbol/sparse.py | 1 + python/mxnet/symbol/symbol.py | 96 ++++----- src/operator/random/sample_op.cc | 49 +++-- src/operator/random/sample_op.cu | 4 +- src/operator/tensor/la_op.cc | 49 +++-- src/operator/tensor/la_op.cu | 14 +- src/operator/tensor/matrix_op.cc | 1 + tests/python/gpu/test_operator_gpu.py | 20 +- tests/python/unittest/test_autograd.py | 4 +- tests/python/unittest/test_gluon.py | 2 +- tests/python/unittest/test_gluon_model_zoo.py | 4 +- tests/python/unittest/test_module.py | 28 +-- tests/python/unittest/test_ndarray.py | 2 +- tests/python/unittest/test_operator.py | 56 ++--- tests/python/unittest/test_random.py | 14 +- tests/python/unittest/test_sparse_ndarray.py | 4 +- 51 files changed, 579 insertions(+), 428 deletions(-) create mode 100644 python/mxnet/ndarray/contrib.py create mode 100644 python/mxnet/ndarray/linalg.py create mode 100644 python/mxnet/ndarray/random.py create mode 100644 python/mxnet/symbol/contrib.py create mode 100644 python/mxnet/symbol/linalg.py create mode 100644 python/mxnet/symbol/random.py diff --git a/benchmark/python/sparse/dot.py b/benchmark/python/sparse/dot.py index fe322821a09f..aab34dbec49d 100644 --- a/benchmark/python/sparse/dot.py +++ b/benchmark/python/sparse/dot.py @@ -161,7 +161,7 @@ def run_benchmark(mini_path): weight_row_dim = batch_size if transpose else feature_dim weight_shape = (weight_row_dim, output_dim) if not rsp: - weight = mx.nd.random_uniform(low=0, high=1, shape=weight_shape) + weight = mx.nd.random.uniform(low=0, high=1, shape=weight_shape) else: weight = rand_ndarray(weight_shape, "row_sparse", density=0.05, distribution="uniform") total_cost = {} diff --git a/benchmark/python/sparse/sparse_op.py b/benchmark/python/sparse/sparse_op.py index 0683aa84eacb..ebe62af05da6 100644 --- a/benchmark/python/sparse/sparse_op.py +++ b/benchmark/python/sparse/sparse_op.py @@ -101,7 +101,7 @@ def get_iter(path, data_shape, batch_size): # model data_shape = (k, ) train_iter = get_iter(mini_path, data_shape, batch_size) - weight = mx.nd.random_uniform(low=0, high=1, shape=(k, m)) + weight = mx.nd.random.uniform(low=0, high=1, shape=(k, m)) csr_data = [] dns_data = [] @@ -154,7 +154,7 @@ def measure_cost_backward_baseline(repeat, dot, transpose, lhs, rhs): def bench_dot_forward(m, k, n, density, ctx, repeat): set_default_context(ctx) - dns = mx.nd.random_uniform(shape=(k, n)).copyto(ctx) + dns = mx.nd.random.uniform(shape=(k, n)).copyto(ctx) data_shape = (m, k) csr_data = rand_ndarray(data_shape, 'csr', density) dns_data = csr_data.tostype('default') @@ -183,7 +183,7 @@ def bench_dot_forward(m, k, n, density, ctx, repeat): def bench_dot_backward(m, k, n, density, ctx, repeat): set_default_context(ctx) - dns = mx.nd.random_uniform(shape=(m, n)).copyto(ctx) + dns = mx.nd.random.uniform(shape=(m, n)).copyto(ctx) data_shape = (m, k) csr_data = rand_ndarray(data_shape, 'csr', density) dns_data = csr_data.tostype('default') diff --git a/example/ctc/lstm.py b/example/ctc/lstm.py index 7e18c8699492..326daa1d9f3a 100644 --- a/example/ctc/lstm.py +++ b/example/ctc/lstm.py @@ -96,7 +96,7 @@ def lstm_unroll(num_lstm_layer, seq_len, pred_fc = mx.sym.FullyConnected(data=hidden_concat, num_hidden=11) pred_ctc = mx.sym.Reshape(data=pred_fc, shape=(-4, seq_len, -1, 0)) - loss = mx.contrib.sym.ctc_loss(data=pred_ctc, label=label) + loss = mx.sym.contrib.ctc_loss(data=pred_ctc, label=label) ctc_loss = mx.sym.MakeLoss(loss) softmax_class = mx.symbol.SoftmaxActivation(data=pred_fc) diff --git a/example/gluon/dcgan.py b/example/gluon/dcgan.py index ed814df61e99..3233f430eeac 100644 --- a/example/gluon/dcgan.py +++ b/example/gluon/dcgan.py @@ -184,7 +184,7 @@ def transformer(data, label): ########################### # train with real_t data = data.as_in_context(ctx) - noise = mx.nd.random_normal(0, 1, shape=(opt.batch_size, nz, 1, 1), ctx=ctx) + noise = mx.nd.random.normal(0, 1, shape=(opt.batch_size, nz, 1, 1), ctx=ctx) with autograd.record(): output = netD(data) diff --git a/example/gluon/lstm_crf.py b/example/gluon/lstm_crf.py index 40c8c2be2784..857bfca56186 100644 --- a/example/gluon/lstm_crf.py +++ b/example/gluon/lstm_crf.py @@ -62,13 +62,13 @@ def __init__(self, vocab_size, tag2idx, embedding_dim, hidden_dim): # Matrix of transition parameters. Entry i,j is the score of # transitioning *to* i *from* j. - self.transitions = nd.random_normal(shape=(self.tagset_size, self.tagset_size)) + self.transitions = nd.random.normal(shape=(self.tagset_size, self.tagset_size)) self.hidden = self.init_hidden() def init_hidden(self): - return [nd.random_normal(shape=(2, 1, self.hidden_dim // 2)), - nd.random_normal(shape=(2, 1, self.hidden_dim // 2))] + return [nd.random.normal(shape=(2, 1, self.hidden_dim // 2)), + nd.random.normal(shape=(2, 1, self.hidden_dim // 2))] def _forward_alg(self, feats): # Do the forward algorithm to compute the partition function diff --git a/example/rcnn/rcnn/symbol/symbol_resnet.py b/example/rcnn/rcnn/symbol/symbol_resnet.py index f914d117eb18..4a9677d44099 100644 --- a/example/rcnn/rcnn/symbol/symbol_resnet.py +++ b/example/rcnn/rcnn/symbol/symbol_resnet.py @@ -113,7 +113,7 @@ def get_resnet_train(num_classes=config.NUM_CLASSES, num_anchors=config.NUM_ANCH rpn_cls_act_reshape = mx.symbol.Reshape( data=rpn_cls_act, shape=(0, 2 * num_anchors, -1, 0), name='rpn_cls_act_reshape') if config.TRAIN.CXX_PROPOSAL: - rois = mx.contrib.symbol.Proposal( + rois = mx.symbol.contrib.Proposal( cls_prob=rpn_cls_act_reshape, bbox_pred=rpn_bbox_pred, im_info=im_info, name='rois', feature_stride=config.RPN_FEAT_STRIDE, scales=tuple(config.ANCHOR_SCALES), ratios=tuple(config.ANCHOR_RATIOS), rpn_pre_nms_top_n=config.TRAIN.RPN_PRE_NMS_TOP_N, rpn_post_nms_top_n=config.TRAIN.RPN_POST_NMS_TOP_N, @@ -189,7 +189,7 @@ def get_resnet_test(num_classes=config.NUM_CLASSES, num_anchors=config.NUM_ANCHO rpn_cls_prob_reshape = mx.symbol.Reshape( data=rpn_cls_prob, shape=(0, 2 * num_anchors, -1, 0), name='rpn_cls_prob_reshape') if config.TEST.CXX_PROPOSAL: - rois = mx.contrib.symbol.Proposal( + rois = mx.symbol.contrib.Proposal( cls_prob=rpn_cls_prob_reshape, bbox_pred=rpn_bbox_pred, im_info=im_info, name='rois', feature_stride=config.RPN_FEAT_STRIDE, scales=tuple(config.ANCHOR_SCALES), ratios=tuple(config.ANCHOR_RATIOS), rpn_pre_nms_top_n=config.TEST.RPN_PRE_NMS_TOP_N, rpn_post_nms_top_n=config.TEST.RPN_POST_NMS_TOP_N, diff --git a/example/rcnn/rcnn/symbol/symbol_vgg.py b/example/rcnn/rcnn/symbol/symbol_vgg.py index f04ba89dc1d4..00ba15ed8e60 100644 --- a/example/rcnn/rcnn/symbol/symbol_vgg.py +++ b/example/rcnn/rcnn/symbol/symbol_vgg.py @@ -242,7 +242,7 @@ def get_vgg_rpn_test(num_anchors=config.NUM_ANCHORS): rpn_cls_prob_reshape = mx.symbol.Reshape( data=rpn_cls_prob, shape=(0, 2 * num_anchors, -1, 0), name='rpn_cls_prob_reshape') if config.TEST.CXX_PROPOSAL: - group = mx.contrib.symbol.Proposal( + group = mx.symbol.contrib.Proposal( cls_prob=rpn_cls_prob_reshape, bbox_pred=rpn_bbox_pred, im_info=im_info, name='rois', output_score=True, feature_stride=config.RPN_FEAT_STRIDE, scales=tuple(config.ANCHOR_SCALES), ratios=tuple(config.ANCHOR_RATIOS), rpn_pre_nms_top_n=config.TEST.PROPOSAL_PRE_NMS_TOP_N, rpn_post_nms_top_n=config.TEST.PROPOSAL_POST_NMS_TOP_N, @@ -290,7 +290,7 @@ def get_vgg_test(num_classes=config.NUM_CLASSES, num_anchors=config.NUM_ANCHORS) rpn_cls_prob_reshape = mx.symbol.Reshape( data=rpn_cls_prob, shape=(0, 2 * num_anchors, -1, 0), name='rpn_cls_prob_reshape') if config.TEST.CXX_PROPOSAL: - rois = mx.contrib.symbol.Proposal( + rois = mx.symbol.contrib.Proposal( cls_prob=rpn_cls_prob_reshape, bbox_pred=rpn_bbox_pred, im_info=im_info, name='rois', feature_stride=config.RPN_FEAT_STRIDE, scales=tuple(config.ANCHOR_SCALES), ratios=tuple(config.ANCHOR_RATIOS), rpn_pre_nms_top_n=config.TEST.RPN_PRE_NMS_TOP_N, rpn_post_nms_top_n=config.TEST.RPN_POST_NMS_TOP_N, @@ -373,7 +373,7 @@ def get_vgg_train(num_classes=config.NUM_CLASSES, num_anchors=config.NUM_ANCHORS rpn_cls_act_reshape = mx.symbol.Reshape( data=rpn_cls_act, shape=(0, 2 * num_anchors, -1, 0), name='rpn_cls_act_reshape') if config.TRAIN.CXX_PROPOSAL: - rois = mx.contrib.symbol.Proposal( + rois = mx.symbol.contrib.Proposal( cls_prob=rpn_cls_act_reshape, bbox_pred=rpn_bbox_pred, im_info=im_info, name='rois', feature_stride=config.RPN_FEAT_STRIDE, scales=tuple(config.ANCHOR_SCALES), ratios=tuple(config.ANCHOR_RATIOS), rpn_pre_nms_top_n=config.TRAIN.RPN_PRE_NMS_TOP_N, rpn_post_nms_top_n=config.TRAIN.RPN_POST_NMS_TOP_N, diff --git a/example/ssd/symbol/common.py b/example/ssd/symbol/common.py index ea58c1599add..4a0458f87288 100644 --- a/example/ssd/symbol/common.py +++ b/example/ssd/symbol/common.py @@ -283,8 +283,9 @@ def multibox_layer(from_layers, num_classes, sizes=[.2, .95], step = (steps[k], steps[k]) else: step = '(-1.0, -1.0)' - anchors = mx.contrib.symbol.MultiBoxPrior(from_layer, sizes=size_str, ratios=ratio_str, \ - clip=clip, name="{}_anchors".format(from_name), steps=step) + anchors = mx.symbol.contrib.MultiBoxPrior(from_layer, sizes=size_str, ratios=ratio_str, + clip=clip, name="{}_anchors".format(from_name), + steps=step) anchors = mx.symbol.Flatten(data=anchors) anchor_layers.append(anchors) diff --git a/example/ssd/symbol/legacy_vgg16_ssd_300.py b/example/ssd/symbol/legacy_vgg16_ssd_300.py index c1f8ea7cb88e..29fc30be65d4 100644 --- a/example/ssd/symbol/legacy_vgg16_ssd_300.py +++ b/example/ssd/symbol/legacy_vgg16_ssd_300.py @@ -144,7 +144,7 @@ def get_symbol_train(num_classes=20, nms_thresh=0.5, force_suppress=False, num_classes, sizes=sizes, ratios=ratios, normalization=normalizations, \ num_channels=num_channels, clip=False, interm_layer=0, steps=steps) - tmp = mx.contrib.symbol.MultiBoxTarget( + tmp = mx.symbol.contrib.MultiBoxTarget( *[anchor_boxes, label, cls_preds], overlap_threshold=.5, \ ignore_label=-1, negative_mining_ratio=3, minimum_negative_samples=0, \ negative_mining_thresh=.5, variances=(0.1, 0.1, 0.2, 0.2), @@ -163,7 +163,7 @@ def get_symbol_train(num_classes=20, nms_thresh=0.5, force_suppress=False, # monitoring training status cls_label = mx.symbol.MakeLoss(data=cls_target, grad_scale=0, name="cls_label") - det = mx.contrib.symbol.MultiBoxDetection(*[cls_prob, loc_preds, anchor_boxes], \ + det = mx.symbol.contrib.MultiBoxDetection(*[cls_prob, loc_preds, anchor_boxes], \ name="detection", nms_threshold=nms_thresh, force_suppress=force_suppress, variances=(0.1, 0.1, 0.2, 0.2), nms_topk=nms_topk) det = mx.symbol.MakeLoss(data=det, grad_scale=0, name="det_out") @@ -202,7 +202,7 @@ def get_symbol(num_classes=20, nms_thresh=0.5, force_suppress=False, cls_prob = mx.symbol.SoftmaxActivation(data=cls_preds, mode='channel', \ name='cls_prob') - out = mx.contrib.symbol.MultiBoxDetection(*[cls_prob, loc_preds, anchor_boxes], \ + out = mx.symbol.contrib.MultiBoxDetection(*[cls_prob, loc_preds, anchor_boxes], \ name="detection", nms_threshold=nms_thresh, force_suppress=force_suppress, variances=(0.1, 0.1, 0.2, 0.2), nms_topk=nms_topk) return out diff --git a/example/ssd/symbol/legacy_vgg16_ssd_512.py b/example/ssd/symbol/legacy_vgg16_ssd_512.py index 6cc3aa274a73..c5c3095dfd77 100644 --- a/example/ssd/symbol/legacy_vgg16_ssd_512.py +++ b/example/ssd/symbol/legacy_vgg16_ssd_512.py @@ -148,7 +148,7 @@ def get_symbol_train(num_classes=20, nms_thresh=0.5, force_suppress=False, nms_t num_classes, sizes=sizes, ratios=ratios, normalization=normalizations, \ num_channels=num_channels, clip=False, interm_layer=0, steps=steps) - tmp = mx.contrib.symbol.MultiBoxTarget( + tmp = mx.symbol.contrib.MultiBoxTarget( *[anchor_boxes, label, cls_preds], overlap_threshold=.5, \ ignore_label=-1, negative_mining_ratio=3, minimum_negative_samples=0, \ negative_mining_thresh=.5, variances=(0.1, 0.1, 0.2, 0.2), @@ -167,7 +167,7 @@ def get_symbol_train(num_classes=20, nms_thresh=0.5, force_suppress=False, nms_t # monitoring training status cls_label = mx.symbol.MakeLoss(data=cls_target, grad_scale=0, name="cls_label") - det = mx.contrib.symbol.MultiBoxDetection(*[cls_prob, loc_preds, anchor_boxes], \ + det = mx.symbol.contrib.MultiBoxDetection(*[cls_prob, loc_preds, anchor_boxes], \ name="detection", nms_threshold=nms_thresh, force_suppress=force_suppress, variances=(0.1, 0.1, 0.2, 0.2), nms_topk=nms_topk) det = mx.symbol.MakeLoss(data=det, grad_scale=0, name="det_out") @@ -205,7 +205,7 @@ def get_symbol(num_classes=20, nms_thresh=0.5, force_suppress=False, nms_topk=40 cls_prob = mx.symbol.SoftmaxActivation(data=cls_preds, mode='channel', \ name='cls_prob') - out = mx.contrib.symbol.MultiBoxDetection(*[cls_prob, loc_preds, anchor_boxes], \ + out = mx.symbol.contrib.MultiBoxDetection(*[cls_prob, loc_preds, anchor_boxes], \ name="detection", nms_threshold=nms_thresh, force_suppress=force_suppress, variances=(0.1, 0.1, 0.2, 0.2), nms_topk=nms_topk) return out diff --git a/example/ssd/symbol/symbol_builder.py b/example/ssd/symbol/symbol_builder.py index 4cd7f88ea312..0c7b5c1b14bc 100644 --- a/example/ssd/symbol/symbol_builder.py +++ b/example/ssd/symbol/symbol_builder.py @@ -87,7 +87,7 @@ def get_symbol_train(network, num_classes, from_layers, num_filters, strides, pa num_classes, sizes=sizes, ratios=ratios, normalization=normalizations, \ num_channels=num_filters, clip=False, interm_layer=0, steps=steps) - tmp = mx.contrib.symbol.MultiBoxTarget( + tmp = mx.symbol.contrib.MultiBoxTarget( *[anchor_boxes, label, cls_preds], overlap_threshold=.5, \ ignore_label=-1, negative_mining_ratio=3, minimum_negative_samples=0, \ negative_mining_thresh=.5, variances=(0.1, 0.1, 0.2, 0.2), @@ -106,7 +106,7 @@ def get_symbol_train(network, num_classes, from_layers, num_filters, strides, pa # monitoring training status cls_label = mx.symbol.MakeLoss(data=cls_target, grad_scale=0, name="cls_label") - det = mx.contrib.symbol.MultiBoxDetection(*[cls_prob, loc_preds, anchor_boxes], \ + det = mx.symbol.contrib.MultiBoxDetection(*[cls_prob, loc_preds, anchor_boxes], \ name="detection", nms_threshold=nms_thresh, force_suppress=force_suppress, variances=(0.1, 0.1, 0.2, 0.2), nms_topk=nms_topk) det = mx.symbol.MakeLoss(data=det, grad_scale=0, name="det_out") @@ -177,7 +177,7 @@ def get_symbol(network, num_classes, from_layers, num_filters, sizes, ratios, cls_prob = mx.symbol.SoftmaxActivation(data=cls_preds, mode='channel', \ name='cls_prob') - out = mx.contrib.symbol.MultiBoxDetection(*[cls_prob, loc_preds, anchor_boxes], \ + out = mx.symbol.contrib.MultiBoxDetection(*[cls_prob, loc_preds, anchor_boxes], \ name="detection", nms_threshold=nms_thresh, force_suppress=force_suppress, variances=(0.1, 0.1, 0.2, 0.2), nms_topk=nms_topk) return out diff --git a/example/ssd/tools/caffe_converter/convert_symbol.py b/example/ssd/tools/caffe_converter/convert_symbol.py index 10510aa92569..5ce70230a9a2 100644 --- a/example/ssd/tools/caffe_converter/convert_symbol.py +++ b/example/ssd/tools/caffe_converter/convert_symbol.py @@ -265,7 +265,7 @@ def proto2script(proto_file): finput_dim = float(input_dim[2]) step = '(%f, %f)' % (step_h / finput_dim, step_w / finput_dim) assert param.offset == 0.5, "currently only support offset = 0.5" - symbol_string += '%s = mx.contrib.symbol.MultiBoxPrior(%s, sizes=%s, ratios=%s, clip=%s, steps=%s, name="%s")\n' % \ + symbol_string += '%s = mx.symbol.contrib.MultiBoxPrior(%s, sizes=%s, ratios=%s, clip=%s, steps=%s, name="%s")\n' % \ (name, mapping[layer[i].bottom[0]], sizes, ratios_string, clip, step, name) symbol_string += '%s = mx.symbol.Flatten(data=%s)\n' % (name, name) type_string = 'split' @@ -281,7 +281,7 @@ def proto2script(proto_file): assert param.share_location == True assert param.background_label_id == 0 nms_param = param.nms_param - type_string = 'mx.contrib.symbol.MultiBoxDetection' + type_string = 'mx.symbol.contrib.MultiBoxDetection' param_string = "nms_threshold=%f, nms_topk=%d" % \ (nms_param.nms_threshold, nms_param.top_k) if type_string == '': diff --git a/python/mxnet/_ctypes/ndarray.py b/python/mxnet/_ctypes/ndarray.py index c2e6fce40de8..0d02c049e398 100644 --- a/python/mxnet/_ctypes/ndarray.py +++ b/python/mxnet/_ctypes/ndarray.py @@ -22,22 +22,11 @@ from __future__ import absolute_import as _abs import ctypes -import sys as _sys -import numpy as np from ..base import _LIB -from ..base import c_array, py_str, c_str, mx_uint, _Null -from ..base import NDArrayHandle, OpHandle, CachedOpHandle +from ..base import c_array, c_str +from ..base import NDArrayHandle, CachedOpHandle from ..base import check_call -from ..ndarray_doc import _build_doc - - -_STORAGE_TYPE_ID_TO_STR = { - -1 : 'undefined', - 0 : 'default', - 1 : 'row_sparse', - 2 : 'csr', -} class NDArrayBase(object): @@ -106,10 +95,10 @@ def _imperative_invoke(handle, ndargs, keys, vals, out): return original_output if num_output.value == 1: return _ndarray_cls(ctypes.cast(output_vars[0], NDArrayHandle), - stype=_STORAGE_TYPE_ID_TO_STR[out_stypes[0]]) + stype=out_stypes[0]) else: return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle), - stype=_STORAGE_TYPE_ID_TO_STR[out_stypes[i]]) + stype=out_stypes[i]) for i in range(num_output.value)] @@ -160,8 +149,8 @@ def __call__(self, *args, **kwargs): return original_output if num_output.value == 1: return _ndarray_cls(ctypes.cast(output_vars[0], NDArrayHandle), - stype=_STORAGE_TYPE_ID_TO_STR[out_stypes[0]]) + stype=out_stypes[0]) else: return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle), - stype=_STORAGE_TYPE_ID_TO_STR[out_stypes[i]]) + stype=out_stypes[i]) for i in range(num_output.value)] diff --git a/python/mxnet/autograd.py b/python/mxnet/autograd.py index 292bcc2308fc..bca1dc2a96fd 100644 --- a/python/mxnet/autograd.py +++ b/python/mxnet/autograd.py @@ -27,7 +27,8 @@ from .base import _LIB, check_call, string_types from .base import mx_uint, NDArrayHandle, c_array, MXCallbackList, SymbolHandle from .ndarray import NDArray -from .symbol import _GRAD_REQ_MAP, Symbol +from .ndarray import _GRAD_REQ_MAP +from .symbol import Symbol def set_recording(is_recording): #pylint: disable=redefined-outer-name diff --git a/python/mxnet/base.py b/python/mxnet/base.py index d446355da0b5..e422dade6596 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -307,6 +307,7 @@ def _notify_shutdown(): atexit.register(_notify_shutdown) + def add_fileline_to_docstring(module, incursive=True): """Append the definition position to each function contained in module. @@ -342,6 +343,7 @@ def _add_fileline(obj): if inspect.isclass(obj) and incursive: add_fileline_to_docstring(obj, False) + def _as_list(obj): """A utility function that converts the argument to a list if it is not already. @@ -359,3 +361,84 @@ def _as_list(obj): return obj else: return [obj] + + +_OP_NAME_PREFIX_LIST = ['_contrib_', '_linalg_', '_random_', '_sparse_'] + + +def _get_op_name_prefix(op_name): + """ + Check whether the given op_name starts with any words in `_OP_NAME_PREFIX_LIST`. + If found, return the prefix; else, return an empty string. + """ + for prefix in _OP_NAME_PREFIX_LIST: + if op_name.startswith(prefix): + return prefix + return "" + + +# pylint: enable=too-many-locals, invalid-name +def _init_op_module(root_namespace, module_name, make_op_func): + """ + Registers op functions created by `make_op_func` under + `root_namespace.module_name.[submodule_name]`, + where `submodule_name` is one of `_OP_SUBMODULE_NAME_LIST`. + + Parameters + ---------- + root_namespace : str + Top level module name, `mxnet` in the current cases. + module_name : str + Second level module name, `ndarray` and `symbol` in the current cases. + make_op_func : str + Function for creating op functions for `ndarray` and `symbol` modules. + """ + plist = ctypes.POINTER(ctypes.c_char_p)() + size = ctypes.c_uint() + + check_call(_LIB.MXListAllOpNames(ctypes.byref(size), + ctypes.byref(plist))) + op_names = [] + for i in range(size.value): + op_names.append(py_str(plist[i])) + + module_op = sys.modules["%s.%s.op" % (root_namespace, module_name)] + module_internal = sys.modules["%s.%s._internal" % (root_namespace, module_name)] + # contrib module in the old format (deprecated) + # kept here for backward compatibility + # use mx.nd.contrib or mx.sym.contrib from now on + contrib_module_name_old = "%s.contrib.%s" % (root_namespace, module_name) + contrib_module_old = sys.modules[contrib_module_name_old] + submodule_dict = {} + for op_name_prefix in _OP_NAME_PREFIX_LIST: + submodule_dict[op_name_prefix] =\ + sys.modules["%s.%s.%s" % (root_namespace, module_name, op_name_prefix[1:-1])] + for name in op_names: + hdl = OpHandle() + check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl))) + function = make_op_func(hdl, name) + op_name_prefix = _get_op_name_prefix(function.__name__) + if len(op_name_prefix) > 0: + # register op under mxnet.module_name.op_name_prefix[1:-1] + # e.g. mxnet.ndarray.sparse.dot, mxnet.symbol.linalg.gemm + function.__name__ = function.__name__[len(op_name_prefix):] + function.__module__ = "%s.%s.%s" % (root_namespace, module_name, op_name_prefix[1:-1]) + cur_module = submodule_dict[op_name_prefix] + setattr(cur_module, function.__name__, function) + cur_module.__all__.append(function.__name__) + # if op_name_prefix is '_contrib_', also need to register + # the op under mxnet.contrib.module_name for backward compatibility + if op_name_prefix == '_contrib_': + hdl = OpHandle() + check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl))) + function = make_op_func(hdl, name) + function.__name__ = function.__name__[len(op_name_prefix):] + function.__module__ = contrib_module_name_old + setattr(contrib_module_old, function.__name__, function) + contrib_module_old.__all__.append(function.__name__) + elif function.__name__.startswith('_'): + setattr(module_internal, function.__name__, function) + module_internal.__all__.append(function.__name__) + else: + setattr(module_op, function.__name__, function) + module_op.__all__.append(function.__name__) diff --git a/python/mxnet/contrib/autograd.py b/python/mxnet/contrib/autograd.py index 2d2500e7a217..68ce31bb0506 100644 --- a/python/mxnet/contrib/autograd.py +++ b/python/mxnet/contrib/autograd.py @@ -25,8 +25,7 @@ from ..base import _LIB, check_call, string_types from ..base import mx_uint, NDArrayHandle, c_array # pylint: disable= unused-import -from ..ndarray import NDArray, zeros_like -from ..symbol import _GRAD_REQ_MAP +from ..ndarray import NDArray, zeros_like, _GRAD_REQ_MAP def set_is_training(is_train): diff --git a/python/mxnet/contrib/ndarray.py b/python/mxnet/contrib/ndarray.py index 3c86fe7ba3fb..f65c75ef0fef 100644 --- a/python/mxnet/contrib/ndarray.py +++ b/python/mxnet/contrib/ndarray.py @@ -17,3 +17,4 @@ # coding: utf-8 """NDArray namespace used to register contrib functions""" +__all__ = [] diff --git a/python/mxnet/contrib/symbol.py b/python/mxnet/contrib/symbol.py index 1d5334595f27..90f6dae070d2 100644 --- a/python/mxnet/contrib/symbol.py +++ b/python/mxnet/contrib/symbol.py @@ -17,3 +17,4 @@ # coding: utf-8 """Symbol namespace used to register contrib functions""" +__all__ = [] diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index 74a9058e98e0..d6114814b9f6 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -459,7 +459,7 @@ class SymbolBlock(HybridBlock): internals['model_dense1_relu_fwd_output']] >>> # Create SymbolBlock that shares parameters with alexnet >>> feat_model = gluon.SymbolBlock(outputs, inputs, params=alexnet.collect_params()) - >>> x = mx.nd.random_normal(shape=(16, 3, 224, 224)) + >>> x = mx.nd.random.normal(shape=(16, 3, 224, 224)) >>> print(feat_model(x)) """ def __init__(self, outputs, inputs, params=None): diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py index 063d56654f9f..b280752811d1 100644 --- a/python/mxnet/gluon/rnn/rnn_layer.py +++ b/python/mxnet/gluon/rnn/rnn_layer.py @@ -301,11 +301,11 @@ class RNN(_RNNLayer): -------- >>> layer = mx.gluon.rnn.RNN(100, 3) >>> layer.initialize() - >>> input = mx.nd.random_uniform(shape=(5, 3, 10)) + >>> input = mx.nd.random.uniform(shape=(5, 3, 10)) >>> # by default zeros are used as begin state >>> output = layer(input) >>> # manually specify begin state. - >>> h0 = mx.nd.random_uniform(shape=(3, 3, 100)) + >>> h0 = mx.nd.random.uniform(shape=(3, 3, 100)) >>> output, hn = layer(input, h0) """ def __init__(self, hidden_size, num_layers=1, activation='relu', @@ -404,12 +404,12 @@ class LSTM(_RNNLayer): -------- >>> layer = mx.gluon.rnn.LSTM(100, 3) >>> layer.initialize() - >>> input = mx.nd.random_uniform(shape=(5, 3, 10)) + >>> input = mx.nd.random.uniform(shape=(5, 3, 10)) >>> # by default zeros are used as begin state >>> output = layer(input) >>> # manually specify begin state. - >>> h0 = mx.nd.random_uniform(shape=(3, 3, 100)) - >>> c0 = mx.nd.random_uniform(shape=(3, 3, 100)) + >>> h0 = mx.nd.random.uniform(shape=(3, 3, 100)) + >>> c0 = mx.nd.random.uniform(shape=(3, 3, 100)) >>> output, hn = layer(input, [h0, c0]) """ def __init__(self, hidden_size, num_layers=1, layout='TNC', @@ -503,11 +503,11 @@ class GRU(_RNNLayer): -------- >>> layer = mx.gluon.rnn.GRU(100, 3) >>> layer.initialize() - >>> input = mx.nd.random_uniform(shape=(5, 3, 10)) + >>> input = mx.nd.random.uniform(shape=(5, 3, 10)) >>> # by default zeros are used as begin state >>> output = layer(input) >>> # manually specify begin state. - >>> h0 = mx.nd.random_uniform(shape=(3, 3, 100)) + >>> h0 = mx.nd.random.uniform(shape=(3, 3, 100)) >>> output, hn = layer(input, h0) """ def __init__(self, hidden_size, num_layers=1, layout='TNC', diff --git a/python/mxnet/ndarray/__init__.py b/python/mxnet/ndarray/__init__.py index 63220787a43c..43ec961afa39 100644 --- a/python/mxnet/ndarray/__init__.py +++ b/python/mxnet/ndarray/__init__.py @@ -17,9 +17,13 @@ """NDArray API of MXNet.""" -from . import _internal, sparse, op -from .op import CachedOp +from . import _internal, contrib, linalg, random, sparse # pylint: disable=wildcard-import, redefined-builtin +from .op import * from .ndarray import * +# pylint: enable=wildcard-import from .utils import load, save, zeros, empty, array from .sparse import _ndarray_cls +from .ndarray import _GRAD_REQ_MAP + +__all__ = op.__all__ + ndarray.__all__ + ['contrib', 'linalg', 'random', 'sparse'] diff --git a/python/mxnet/ndarray/_internal.py b/python/mxnet/ndarray/_internal.py index 8f151f1b5b64..96eff31b72dc 100644 --- a/python/mxnet/ndarray/_internal.py +++ b/python/mxnet/ndarray/_internal.py @@ -16,3 +16,4 @@ # under the License. """NDArray namespace used to register internal functions.""" +__all__ = [] diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py new file mode 100644 index 000000000000..f21d144c4b63 --- /dev/null +++ b/python/mxnet/ndarray/contrib.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Contrib NDArray API of MXNet.""" +__all__ = [] diff --git a/python/mxnet/ndarray/linalg.py b/python/mxnet/ndarray/linalg.py new file mode 100644 index 000000000000..0c8e7fd57a58 --- /dev/null +++ b/python/mxnet/ndarray/linalg.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Linear Algebra NDArray API of MXNet.""" +__all__ = [] diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index a85ccb5b6076..b0500d31b323 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -37,16 +37,8 @@ from ..base import ctypes2buffer from ..context import Context from . import _internal -from .op import NDArrayBase, _STORAGE_TYPE_ID_TO_STR -from . import cast_storage -from . import broadcast_add, broadcast_mul, transpose, broadcast_not_equal, broadcast_power -from . import broadcast_sub, broadcast_div, broadcast_to, broadcast_axes, broadcast_equal -from . import broadcast_greater, broadcast_greater_equal, broadcast_lesser, broadcast_lesser_equal -from . import zeros_like, ones_like, broadcast_minimum, broadcast_maximum, broadcast_mod -from . import flatten, norm, rint, fix, floor, ceil, split, slice_axis, one_hot, pick, take -from . import trunc, expand_dims, flip, tile, repeat, pad, clip, sign -from . import nansum, prod, nanprod, mean, sort, topk, argsort, argmax, argmin -from . import sum, round, max, min, slice, abs # pylint: disable=redefined-builtin +from . import op +from .op import NDArrayBase __all__ = ["NDArray", "concatenate", "_DTYPE_NP_TO_MX", "_DTYPE_MX_TO_NP", "_GRAD_REQ_MAP", "ones", "add", "arange", "divide", "equal", "full", "greater", "greater_equal", @@ -54,33 +46,48 @@ "multiply", "negative", "not_equal", "onehot_encode", "power", "subtract", "true_divide", "waitall", "_new_empty_handle"] +_STORAGE_TYPE_UNDEFINED = -1 +_STORAGE_TYPE_DEFAULT = 0 +_STORAGE_TYPE_ROW_SPARSE = 1 +_STORAGE_TYPE_CSR = 2 + # pylint: disable= no-member _DTYPE_NP_TO_MX = { - None : -1, - np.float32 : 0, - np.float64 : 1, - np.float16 : 2, - np.uint8 : 3, - np.int32 : 4, - np.int8 : 5, - np.int64 : 6, + None: -1, + np.float32: 0, + np.float64: 1, + np.float16: 2, + np.uint8: 3, + np.int32: 4, + np.int8: 5, + np.int64: 6, } + _DTYPE_MX_TO_NP = { - -1 : None, - 0 : np.float32, - 1 : np.float64, - 2 : np.float16, - 3 : np.uint8, - 4 : np.int32, - 5 : np.int8, - 6 : np.int64, + -1: None, + 0: np.float32, + 1: np.float64, + 2: np.float16, + 3: np.uint8, + 4: np.int32, + 5: np.int8, + 6: np.int64, } + _STORAGE_TYPE_STR_TO_ID = { - 'undefined' : -1, - 'default' : 0, - 'row_sparse' : 1, - 'csr' : 2, + 'undefined': _STORAGE_TYPE_UNDEFINED, + 'default': _STORAGE_TYPE_DEFAULT, + 'row_sparse': _STORAGE_TYPE_ROW_SPARSE, + 'csr': _STORAGE_TYPE_CSR, } + +_STORAGE_TYPE_ID_TO_STR = { + _STORAGE_TYPE_UNDEFINED: 'undefined', + _STORAGE_TYPE_DEFAULT: 'default', + _STORAGE_TYPE_ROW_SPARSE: 'row_sparse', + _STORAGE_TYPE_CSR: 'csr', +} + _GRAD_REQ_MAP = { 'null': 0, 'write': 1, @@ -137,7 +144,7 @@ def waitall(): def _storage_type(handle): storage_type = ctypes.c_int(0) check_call(_LIB.MXNDArrayGetStorageType(handle, ctypes.byref(storage_type))) - return _STORAGE_TYPE_ID_TO_STR[storage_type.value] + return storage_type.value class NDArray(NDArrayBase): @@ -169,7 +176,7 @@ def __iadd__(self, other): if not self.writable: raise ValueError('trying to add to a readonly NDArray') if isinstance(other, NDArray): - return broadcast_add(self, other, out=self) + return op.broadcast_add(self, other, out=self) elif isinstance(other, numeric_types): return _internal._plus_scalar(self, float(other), out=self) else: @@ -187,7 +194,7 @@ def __isub__(self, other): if not self.writable: raise ValueError('trying to subtract from a readonly NDArray') if isinstance(other, NDArray): - return broadcast_sub(self, other, out=self) + return op.broadcast_sub(self, other, out=self) elif isinstance(other, numeric_types): return _internal._minus_scalar(self, float(other), out=self) else: @@ -210,7 +217,7 @@ def __imul__(self, other): if not self.writable: raise ValueError('trying to multiply to a readonly NDArray') if isinstance(other, NDArray): - return broadcast_mul(self, other, out=self) + return op.broadcast_mul(self, other, out=self) elif isinstance(other, numeric_types): return _internal._mul_scalar(self, float(other), out=self) else: @@ -232,7 +239,7 @@ def __idiv__(self, other): if not self.writable: raise ValueError('trying to divide from a readonly NDArray') if isinstance(other, NDArray): - return broadcast_div(self, other, out=self) + return op.broadcast_div(self, other, out=self) elif isinstance(other, numeric_types): return _internal._div_scalar(self, float(other), out=self) else: @@ -260,7 +267,7 @@ def __imod__(self, other): if not self.writable: raise ValueError('trying to take modulo from a readonly NDArray') if isinstance(other, NDArray): - return broadcast_mod(self, other, out=self) + return op.broadcast_mod(self, other, out=self) elif isinstance(other, numeric_types): return _internal._mod_scalar(self, float(other), out=self) else: @@ -525,7 +532,7 @@ def __getitem__(self, key): oshape.extend(shape[i+1:]) if len(oshape) == 0: oshape.append(1) - return slice(self, begin, end).reshape(oshape) + return op.slice(self, begin, end).reshape(oshape) else: raise ValueError( "NDArray does not support slicing with key %s of type %s."%( @@ -684,7 +691,7 @@ def zeros_like(self, *args, **kwargs): The arguments are the same as for :py:func:`zeros_like`, with this array as data. """ - return zeros_like(self, *args, **kwargs) + return op.zeros_like(self, *args, **kwargs) def ones_like(self, *args, **kwargs): """Convenience fluent method for :py:func:`ones_like`. @@ -692,7 +699,7 @@ def ones_like(self, *args, **kwargs): The arguments are the same as for :py:func:`ones_like`, with this array as data. """ - return ones_like(self, *args, **kwargs) + return op.ones_like(self, *args, **kwargs) def broadcast_axes(self, *args, **kwargs): """Convenience fluent method for :py:func:`broadcast_axes`. @@ -700,7 +707,7 @@ def broadcast_axes(self, *args, **kwargs): The arguments are the same as for :py:func:`broadcast_axes`, with this array as data. """ - return broadcast_axes(self, *args, **kwargs) + return op.broadcast_axes(self, *args, **kwargs) def repeat(self, *args, **kwargs): """Convenience fluent method for :py:func:`repeat`. @@ -708,7 +715,7 @@ def repeat(self, *args, **kwargs): The arguments are the same as for :py:func:`repeat`, with this array as data. """ - return repeat(self, *args, **kwargs) + return op.repeat(self, *args, **kwargs) def pad(self, *args, **kwargs): """Convenience fluent method for :py:func:`pad`. @@ -716,7 +723,7 @@ def pad(self, *args, **kwargs): The arguments are the same as for :py:func:`pad`, with this array as data. """ - return pad(self, *args, **kwargs) + return op.pad(self, *args, **kwargs) def swapaxes(self, *args, **kwargs): """Convenience fluent method for :py:func:`swapaxes`. @@ -724,7 +731,7 @@ def swapaxes(self, *args, **kwargs): The arguments are the same as for :py:func:`swapaxes`, with this array as data. """ - return swapaxes(self, *args, **kwargs) + return op.swapaxes(self, *args, **kwargs) def split(self, *args, **kwargs): """Convenience fluent method for :py:func:`split`. @@ -732,7 +739,7 @@ def split(self, *args, **kwargs): The arguments are the same as for :py:func:`split`, with this array as data. """ - return split(self, *args, **kwargs) + return op.split(self, *args, **kwargs) def slice(self, *args, **kwargs): """Convenience fluent method for :py:func:`slice`. @@ -740,7 +747,7 @@ def slice(self, *args, **kwargs): The arguments are the same as for :py:func:`slice`, with this array as data. """ - return slice(self, *args, **kwargs) + return op.slice(self, *args, **kwargs) def slice_axis(self, *args, **kwargs): """Convenience fluent method for :py:func:`slice_axis`. @@ -748,7 +755,7 @@ def slice_axis(self, *args, **kwargs): The arguments are the same as for :py:func:`slice_axis`, with this array as data. """ - return slice_axis(self, *args, **kwargs) + return op.slice_axis(self, *args, **kwargs) def take(self, *args, **kwargs): """Convenience fluent method for :py:func:`take`. @@ -756,7 +763,7 @@ def take(self, *args, **kwargs): The arguments are the same as for :py:func:`take`, with this array as data. """ - return take(self, *args, **kwargs) + return op.take(self, *args, **kwargs) def one_hot(self, *args, **kwargs): """Convenience fluent method for :py:func:`one_hot`. @@ -764,7 +771,7 @@ def one_hot(self, *args, **kwargs): The arguments are the same as for :py:func:`one_hot`, with this array as data. """ - return one_hot(self, *args, **kwargs) + return op.one_hot(self, *args, **kwargs) def pick(self, *args, **kwargs): """Convenience fluent method for :py:func:`pick`. @@ -772,7 +779,7 @@ def pick(self, *args, **kwargs): The arguments are the same as for :py:func:`pick`, with this array as data. """ - return pick(self, *args, **kwargs) + return op.pick(self, *args, **kwargs) def sort(self, *args, **kwargs): """Convenience fluent method for :py:func:`sort`. @@ -780,7 +787,7 @@ def sort(self, *args, **kwargs): The arguments are the same as for :py:func:`sort`, with this array as data. """ - return sort(self, *args, **kwargs) + return op.sort(self, *args, **kwargs) def topk(self, *args, **kwargs): """Convenience fluent method for :py:func:`topk`. @@ -788,7 +795,7 @@ def topk(self, *args, **kwargs): The arguments are the same as for :py:func:`topk`, with this array as data. """ - return topk(self, *args, **kwargs) + return op.topk(self, *args, **kwargs) def argsort(self, *args, **kwargs): """Convenience fluent method for :py:func:`argsort`. @@ -796,7 +803,7 @@ def argsort(self, *args, **kwargs): The arguments are the same as for :py:func:`argsort`, with this array as data. """ - return argsort(self, *args, **kwargs) + return op.argsort(self, *args, **kwargs) def argmax(self, *args, **kwargs): """Convenience fluent method for :py:func:`argmax`. @@ -804,7 +811,7 @@ def argmax(self, *args, **kwargs): The arguments are the same as for :py:func:`argmax`, with this array as data. """ - return argmax(self, *args, **kwargs) + return op.argmax(self, *args, **kwargs) def argmin(self, *args, **kwargs): """Convenience fluent method for :py:func:`argmin`. @@ -812,7 +819,7 @@ def argmin(self, *args, **kwargs): The arguments are the same as for :py:func:`argmin`, with this array as data. """ - return argmin(self, *args, **kwargs) + return op.argmin(self, *args, **kwargs) def clip(self, *args, **kwargs): """Convenience fluent method for :py:func:`clip`. @@ -820,7 +827,7 @@ def clip(self, *args, **kwargs): The arguments are the same as for :py:func:`clip`, with this array as data. """ - return clip(self, *args, **kwargs) + return op.clip(self, *args, **kwargs) def abs(self, *args, **kwargs): """Convenience fluent method for :py:func:`abs`. @@ -828,7 +835,7 @@ def abs(self, *args, **kwargs): The arguments are the same as for :py:func:`abs`, with this array as data. """ - return abs(self, *args, **kwargs) + return op.abs(self, *args, **kwargs) def sign(self, *args, **kwargs): """Convenience fluent method for :py:func:`sign`. @@ -836,7 +843,7 @@ def sign(self, *args, **kwargs): The arguments are the same as for :py:func:`sign`, with this array as data. """ - return sign(self, *args, **kwargs) + return op.sign(self, *args, **kwargs) def flatten(self, *args, **kwargs): """Convenience fluent method for :py:func:`flatten`. @@ -844,7 +851,7 @@ def flatten(self, *args, **kwargs): The arguments are the same as for :py:func:`flatten`, with this array as data. """ - return flatten(self, *args, **kwargs) + return op.flatten(self, *args, **kwargs) def expand_dims(self, *args, **kwargs): """Convenience fluent method for :py:func:`expand_dims`. @@ -852,7 +859,7 @@ def expand_dims(self, *args, **kwargs): The arguments are the same as for :py:func:`expand_dims`, with this array as data. """ - return expand_dims(self, *args, **kwargs) + return op.expand_dims(self, *args, **kwargs) def tile(self, *args, **kwargs): """Convenience fluent method for :py:func:`tile`. @@ -860,7 +867,7 @@ def tile(self, *args, **kwargs): The arguments are the same as for :py:func:`tile`, with this array as data. """ - return tile(self, *args, **kwargs) + return op.tile(self, *args, **kwargs) def transpose(self, *args, **kwargs): """Convenience fluent method for :py:func:`transpose`. @@ -868,7 +875,7 @@ def transpose(self, *args, **kwargs): The arguments are the same as for :py:func:`transpose`, with this array as data. """ - return transpose(self, *args, **kwargs) + return op.transpose(self, *args, **kwargs) def flip(self, *args, **kwargs): """Convenience fluent method for :py:func:`flip`. @@ -876,7 +883,7 @@ def flip(self, *args, **kwargs): The arguments are the same as for :py:func:`flip`, with this array as data. """ - return flip(self, *args, **kwargs) + return op.flip(self, *args, **kwargs) def sum(self, *args, **kwargs): """Convenience fluent method for :py:func:`sum`. @@ -884,7 +891,7 @@ def sum(self, *args, **kwargs): The arguments are the same as for :py:func:`sum`, with this array as data. """ - return sum(self, *args, **kwargs) + return op.sum(self, *args, **kwargs) def nansum(self, *args, **kwargs): """Convenience fluent method for :py:func:`nansum`. @@ -892,7 +899,7 @@ def nansum(self, *args, **kwargs): The arguments are the same as for :py:func:`nansum`, with this array as data. """ - return nansum(self, *args, **kwargs) + return op.nansum(self, *args, **kwargs) def prod(self, *args, **kwargs): """Convenience fluent method for :py:func:`prod`. @@ -900,7 +907,7 @@ def prod(self, *args, **kwargs): The arguments are the same as for :py:func:`prod`, with this array as data. """ - return prod(self, *args, **kwargs) + return op.prod(self, *args, **kwargs) def nanprod(self, *args, **kwargs): """Convenience fluent method for :py:func:`nanprod`. @@ -908,7 +915,7 @@ def nanprod(self, *args, **kwargs): The arguments are the same as for :py:func:`nanprod`, with this array as data. """ - return nanprod(self, *args, **kwargs) + return op.nanprod(self, *args, **kwargs) def mean(self, *args, **kwargs): """Convenience fluent method for :py:func:`mean`. @@ -916,7 +923,7 @@ def mean(self, *args, **kwargs): The arguments are the same as for :py:func:`mean`, with this array as data. """ - return mean(self, *args, **kwargs) + return op.mean(self, *args, **kwargs) def max(self, *args, **kwargs): """Convenience fluent method for :py:func:`max`. @@ -924,7 +931,7 @@ def max(self, *args, **kwargs): The arguments are the same as for :py:func:`max`, with this array as data. """ - return max(self, *args, **kwargs) + return op.max(self, *args, **kwargs) def min(self, *args, **kwargs): """Convenience fluent method for :py:func:`min`. @@ -932,7 +939,7 @@ def min(self, *args, **kwargs): The arguments are the same as for :py:func:`min`, with this array as data. """ - return min(self, *args, **kwargs) + return op.min(self, *args, **kwargs) def norm(self, *args, **kwargs): """Convenience fluent method for :py:func:`norm`. @@ -940,7 +947,7 @@ def norm(self, *args, **kwargs): The arguments are the same as for :py:func:`norm`, with this array as data. """ - return norm(self, *args, **kwargs) + return op.norm(self, *args, **kwargs) def round(self, *args, **kwargs): """Convenience fluent method for :py:func:`round`. @@ -948,7 +955,7 @@ def round(self, *args, **kwargs): The arguments are the same as for :py:func:`round`, with this array as data. """ - return round(self, *args, **kwargs) + return op.round(self, *args, **kwargs) def rint(self, *args, **kwargs): """Convenience fluent method for :py:func:`rint`. @@ -956,7 +963,7 @@ def rint(self, *args, **kwargs): The arguments are the same as for :py:func:`rint`, with this array as data. """ - return rint(self, *args, **kwargs) + return op.rint(self, *args, **kwargs) def fix(self, *args, **kwargs): """Convenience fluent method for :py:func:`fix`. @@ -964,7 +971,7 @@ def fix(self, *args, **kwargs): The arguments are the same as for :py:func:`fix`, with this array as data. """ - return fix(self, *args, **kwargs) + return op.fix(self, *args, **kwargs) def floor(self, *args, **kwargs): """Convenience fluent method for :py:func:`floor`. @@ -972,7 +979,7 @@ def floor(self, *args, **kwargs): The arguments are the same as for :py:func:`floor`, with this array as data. """ - return floor(self, *args, **kwargs) + return op.floor(self, *args, **kwargs) def ceil(self, *args, **kwargs): """Convenience fluent method for :py:func:`ceil`. @@ -980,7 +987,7 @@ def ceil(self, *args, **kwargs): The arguments are the same as for :py:func:`ceil`, with this array as data. """ - return ceil(self, *args, **kwargs) + return op.ceil(self, *args, **kwargs) def trunc(self, *args, **kwargs): """Convenience fluent method for :py:func:`trunc`. @@ -988,7 +995,7 @@ def trunc(self, *args, **kwargs): The arguments are the same as for :py:func:`trunc`, with this array as data. """ - return trunc(self, *args, **kwargs) + return op.trunc(self, *args, **kwargs) # pylint: disable= undefined-variable def broadcast_to(self, shape): @@ -1038,9 +1045,9 @@ def broadcast_to(self, shape): if (cur_shape_arr[broadcasting_axes] != 1).any(): raise ValueError(err_str) if cur_shape != self.shape: - return broadcast_to(self.reshape(cur_shape), shape=shape) + return op.broadcast_to(self.reshape(cur_shape), shape=shape) else: - return broadcast_to(self, shape=tuple(shape)) + return op.broadcast_to(self, shape=tuple(shape)) # pylint: enable= undefined-variable def wait_to_read(self): @@ -1166,7 +1173,7 @@ def dtype(self): def stype(self): """Storage-type of the array. """ - return _storage_type(self.handle) + return _STORAGE_TYPE_ID_TO_STR[_storage_type(self.handle)] @property # pylint: disable= invalid-name, undefined-variable @@ -1193,7 +1200,7 @@ def T(self): """ if len(self.shape) < 2: return self - return transpose(self) + return op.transpose(self) # pylint: enable= invalid-name, undefined-variable @property @@ -1381,7 +1388,7 @@ def attach_grad(self, grad_req='write'): - 'add': gradient will be added to existing value on every backward. - 'null': do not compute gradient for this NDArray. """ - grad = zeros_like(self) # pylint: disable=undefined-variable + grad = op.zeros_like(self) # pylint: disable=undefined-variable grad_req = _GRAD_REQ_MAP[grad_req] check_call(_LIB.MXAutogradMarkVariables( 1, ctypes.pointer(self.handle), @@ -1440,7 +1447,7 @@ def tostype(self, stype): NDArray, CSRNDArray or RowSparseNDArray A copy of the array with the chosen storage stype """ - return cast_storage(self, stype=stype) + return op.cast_storage(self, stype=stype) def onehot_encode(indices, out): @@ -1595,7 +1602,7 @@ def moveaxis(tensor, source, destination): except IndexError: raise ValueError('Destination should verify 0 <= destination < tensor.ndim' 'Got %d' % destination) - return transpose(tensor, axes) + return op.transpose(tensor, axes) # pylint: disable= no-member, protected-access, too-many-arguments, redefined-outer-name @@ -1751,7 +1758,7 @@ def add(lhs, rhs): return _ufunc_helper( lhs, rhs, - broadcast_add, + op.broadcast_add, operator.add, _internal._plus_scalar, None) @@ -1813,7 +1820,7 @@ def subtract(lhs, rhs): return _ufunc_helper( lhs, rhs, - broadcast_sub, + op.broadcast_sub, operator.sub, _internal._minus_scalar, _internal._rminus_scalar) @@ -1874,7 +1881,7 @@ def multiply(lhs, rhs): return _ufunc_helper( lhs, rhs, - broadcast_mul, + op.broadcast_mul, operator.mul, _internal._mul_scalar, None) @@ -1931,7 +1938,7 @@ def divide(lhs, rhs): return _ufunc_helper( lhs, rhs, - broadcast_div, + op.broadcast_div, operator.truediv, _internal._div_scalar, _internal._rdiv_scalar) @@ -1988,7 +1995,7 @@ def modulo(lhs, rhs): return _ufunc_helper( lhs, rhs, - broadcast_mod, + op.broadcast_mod, operator.mod, _internal._mod_scalar, _internal._rmod_scalar) @@ -2050,7 +2057,7 @@ def power(base, exp): return _ufunc_helper( base, exp, - broadcast_power, + op.broadcast_power, operator.pow, _internal._power_scalar, _internal._rpower_scalar) @@ -2107,7 +2114,7 @@ def maximum(lhs, rhs): return _ufunc_helper( lhs, rhs, - broadcast_maximum, + op.broadcast_maximum, lambda x, y: x if x > y else y, _internal._maximum_scalar, None) @@ -2164,7 +2171,7 @@ def minimum(lhs, rhs): return _ufunc_helper( lhs, rhs, - broadcast_minimum, + op.broadcast_minimum, lambda x, y: x if x < y else y, _internal._minimum_scalar, None) @@ -2228,7 +2235,7 @@ def equal(lhs, rhs): return _ufunc_helper( lhs, rhs, - broadcast_equal, + op.broadcast_equal, lambda x, y: 1 if x == y else 0, _internal._equal_scalar, None) @@ -2295,7 +2302,7 @@ def not_equal(lhs, rhs): return _ufunc_helper( lhs, rhs, - broadcast_not_equal, + op.broadcast_not_equal, lambda x, y: 1 if x != y else 0, _internal._not_equal_scalar, None) @@ -2359,7 +2366,7 @@ def greater(lhs, rhs): return _ufunc_helper( lhs, rhs, - broadcast_greater, + op.broadcast_greater, lambda x, y: 1 if x > y else 0, _internal._greater_scalar, _internal._lesser_scalar) @@ -2423,7 +2430,7 @@ def greater_equal(lhs, rhs): return _ufunc_helper( lhs, rhs, - broadcast_greater_equal, + op.broadcast_greater_equal, lambda x, y: 1 if x >= y else 0, _internal._greater_equal_scalar, _internal._lesser_equal_scalar) @@ -2487,7 +2494,7 @@ def lesser(lhs, rhs): return _ufunc_helper( lhs, rhs, - broadcast_lesser, + op.broadcast_lesser, lambda x, y: 1 if x < y else 0, _internal._lesser_scalar, _internal._greater_scalar) @@ -2551,7 +2558,7 @@ def lesser_equal(lhs, rhs): return _ufunc_helper( lhs, rhs, - broadcast_lesser_equal, + op.broadcast_lesser_equal, lambda x, y: 1 if x <= y else 0, _internal._lesser_equal_scalar, _internal._greater_equal_scalar) diff --git a/python/mxnet/ndarray/op.py b/python/mxnet/ndarray/op.py index e4a1ab0df48b..27ce09b69b49 100644 --- a/python/mxnet/ndarray/op.py +++ b/python/mxnet/ndarray/op.py @@ -16,6 +16,7 @@ # under the License. """Register backend ops in mxnet.ndarray namespace""" +__all__ = ['CachedOp'] import sys as _sys import os as _os @@ -29,21 +30,21 @@ # pylint: disable=unused-import try: if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0: - from .._ctypes.ndarray import NDArrayBase, _STORAGE_TYPE_ID_TO_STR + from .._ctypes.ndarray import NDArrayBase from .._ctypes.ndarray import CachedOp, _imperative_invoke elif _sys.version_info >= (3, 0): - from .._cy3.ndarray import NDArrayBase, _imperative_invoke, _STORAGE_TYPE_ID_TO_STR + from .._cy3.ndarray import NDArrayBase, _imperative_invoke from .._cy3.ndarray import CachedOp, _imperative_invoke else: - from .._cy2.ndarray import NDArrayBase, _imperative_invoke, _STORAGE_TYPE_ID_TO_STR + from .._cy2.ndarray import NDArrayBase, _imperative_invoke from .._cy2.ndarray import CachedOp, _imperative_invoke except ImportError: if int(_os.environ.get("MXNET_ENFORCE_CYTHON", False)) != 0: raise ImportError("Cython Module cannot be loaded but MXNET_ENFORCE_CYTHON=1") - from .._ctypes.ndarray import NDArrayBase, _imperative_invoke, _STORAGE_TYPE_ID_TO_STR + from .._ctypes.ndarray import NDArrayBase, _imperative_invoke from .._ctypes.ndarray import CachedOp, _imperative_invoke -from ..base import mx_uint, check_call, _LIB, py_str, OpHandle, c_str, _Null +from ..base import mx_uint, check_call, _LIB, py_str, _init_op_module, _Null # pylint: enable=unused-import @@ -170,40 +171,4 @@ def %s(%s): return ndarray_function -# pylint: enable=too-many-locals, invalid-name -def _init_ndarray_module(root_namespace): - """List and add all the ndarray functions to current module.""" - plist = ctypes.POINTER(ctypes.c_char_p)() - size = ctypes.c_uint() - - check_call(_LIB.MXListAllOpNames(ctypes.byref(size), - ctypes.byref(plist))) - op_names = [] - for i in range(size.value): - op_names.append(py_str(plist[i])) - - module_obj = _sys.modules["%s.ndarray" % root_namespace] - module_sparse = _sys.modules["%s.ndarray.sparse" % root_namespace] - module_internal = _sys.modules["%s.ndarray._internal" % root_namespace] - module_contrib = _sys.modules["%s.contrib.ndarray" % root_namespace] - for name in op_names: - hdl = OpHandle() - check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl))) - function = _make_ndarray_function(hdl, name) - if function.__name__.startswith('_contrib_'): - function.__name__ = function.__name__[9:] - function.__module__ = 'mxnet.contrib.ndarray' - setattr(module_contrib, function.__name__, function) - elif function.__name__.startswith('_'): - setattr(module_internal, function.__name__, function) - else: - setattr(module_obj, function.__name__, function) - - # register sparse ops under mxnet.ndarray.sparse - if function.__name__.startswith('_sparse_'): - function.__name__ = function.__name__[8:] - function.__module__ = 'mxnet.ndarray.sparse' - setattr(module_sparse, function.__name__, function) - -# register backend operators in mx.nd -_init_ndarray_module("mxnet") +_init_op_module('mxnet', 'ndarray', _make_ndarray_function) diff --git a/python/mxnet/ndarray/random.py b/python/mxnet/ndarray/random.py new file mode 100644 index 000000000000..0ec4578ba3bf --- /dev/null +++ b/python/mxnet/ndarray/random.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Random distribution generator NDArray API of MXNet.""" +__all__ = [] diff --git a/python/mxnet/ndarray/sparse.py b/python/mxnet/ndarray/sparse.py index 97e43f5ebe79..806398ea3ec7 100644 --- a/python/mxnet/ndarray/sparse.py +++ b/python/mxnet/ndarray/sparse.py @@ -31,6 +31,9 @@ import os as _os import sys as _sys +__all__ = ["_ndarray_cls", "csr_matrix", "row_sparse_array", + "BaseSparseNDArray", "CSRNDArray", "RowSparseNDArray"] + # import operator import numpy as np from ..base import NotSupportedForSparseNDArray @@ -41,11 +44,12 @@ from . import _internal from .ndarray import _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP from .ndarray import _STORAGE_TYPE_STR_TO_ID +from .ndarray import _STORAGE_TYPE_UNDEFINED, _STORAGE_TYPE_DEFAULT +from .ndarray import _STORAGE_TYPE_ROW_SPARSE, _STORAGE_TYPE_CSR from .ndarray import NDArray, _storage_type from .ndarray import zeros as _zeros_ndarray from .ndarray import array as _array -from . import cast_storage -from . import slice as nd_slice +from . import op # Use different verison of SymbolBase # When possible, use cython to speedup part of computation. @@ -64,10 +68,6 @@ # pylint: enable=unused-import -__all__ = ["_ndarray_cls", "csr_matrix", "row_sparse_array", - "BaseSparseNDArray", "CSRNDArray", "RowSparseNDArray"] - - _STORAGE_AUX_TYPES = { 'row_sparse': [np.int64], 'csr': [np.int64, np.int64] @@ -300,7 +300,7 @@ def __getitem__(self, key): >>> indptr = np.array([0, 2, 3, 6]) >>> indices = np.array([0, 2, 2, 0, 1, 2]) >>> data = np.array([1, 2, 3, 4, 5, 6]) - >>> a = mx.nd.csr_matrix(data, indptr, indices, (3, 3)) + >>> a = mx.nd.sparse.csr_matrix(data, indptr, indices, (3, 3)) >>> a.asnumpy() array([[1, 0, 2], [0, 0, 3], @@ -316,7 +316,7 @@ def __getitem__(self, key): if key.start is not None or key.stop is not None: begin = key.start if key.start else 0 end = key.stop if key.stop else self.shape[0] - return nd_slice(self, begin=begin, end=end) + return op.slice(self, begin=begin, end=end) else: return self if isinstance(key, tuple): @@ -427,7 +427,7 @@ def tostype(self, stype): """ if stype == 'row_sparse': raise ValueError("cast_storage from csr to row_sparse is not supported") - return cast_storage(self, stype=stype) + return op.cast_storage(self, stype=stype) def copyto(self, other): """Copies the value of this array to another array. @@ -640,7 +640,7 @@ def tostype(self, stype): """ if stype == 'csr': raise ValueError("cast_storage from row_sparse to csr is not supported") - return cast_storage(self, stype=stype) + return op.cast_storage(self, stype=stype) def copyto(self, other): """Copies the value of this array to another array. @@ -725,7 +725,7 @@ def csr_matrix(data, indptr, indices, shape, ctx=None, dtype=None, indptr_type=N Example ------- >>> import mxnet as mx - >>> a = mx.nd.csr_matrix([1, 2, 3], [0, 1, 2, 2, 3], [1, 0, 2], (4, 3)) + >>> a = mx.nd.sparse.csr_matrix([1, 2, 3], [0, 1, 2, 2, 3], [1, 0, 2], (4, 3)) >>> a.asnumpy() array([[ 0., 1., 0.], [ 2., 0., 0.], @@ -794,7 +794,7 @@ def row_sparse_array(data, indices, shape, ctx=None, dtype=None, indices_type=No Example ------- - >>> a = mx.nd.row_sparse_array([[1, 2], [3, 4]], [1, 4], (6, 2)) + >>> a = mx.nd.sparse.row_sparse_array([[1, 2], [3, 4]], [1, 4], (6, 2)) >>> a.asnumpy() array([[ 0., 0.], [ 1., 2.], @@ -831,14 +831,14 @@ def row_sparse_array(data, indices, shape, ctx=None, dtype=None, indices_type=No return result -def _ndarray_cls(handle, writable=True, stype=None): - if stype is None: +def _ndarray_cls(handle, writable=True, stype=_STORAGE_TYPE_UNDEFINED): + if stype == _STORAGE_TYPE_UNDEFINED: stype = _storage_type(handle) - if stype == 'default': + if stype == _STORAGE_TYPE_DEFAULT: return NDArray(handle, writable=writable) - elif stype == 'csr': + elif stype == _STORAGE_TYPE_CSR: return CSRNDArray(handle, writable=writable) - elif stype == 'row_sparse': + elif stype == _STORAGE_TYPE_ROW_SPARSE: return RowSparseNDArray(handle, writable=writable) else: raise Exception("unknown storage type") @@ -910,13 +910,13 @@ def array(source_array, ctx=None, dtype=None, aux_types=None): """Creates a sparse array from any object exposing the array interface. """ if isinstance(source_array, NDArray): - assert(source_array.stype != 'default'), \ - "Please use `cast_storage` to create BaseSparseNDArray from an NDArray" + assert(source_array.stype != 'default'),\ + "Please use `cast_storage` to create BaseSparseNDArray from an NDArray" dtype = source_array.dtype if dtype is None else dtype aux_types = source_array._aux_types if aux_types is None else aux_types else: # TODO(haibin/anisub) support creation from scipy object when `_sync_copy_from` is ready - raise NotImplementedError('creating BaseSparseNDArray from ' \ + raise NotImplementedError('creating BaseSparseNDArray from ' ' a non-NDArray object is not implemented.') arr = empty(source_array.stype, source_array.shape, ctx, dtype, aux_types) arr[:] = source_array diff --git a/python/mxnet/random.py b/python/mxnet/random.py index 14bfc2731bd6..5754304bfc68 100644 --- a/python/mxnet/random.py +++ b/python/mxnet/random.py @@ -48,19 +48,19 @@ def seed(seed_state): Example ------- - >>> print(mx.nd.random_normal(shape=(2,2)).asnumpy()) + >>> print(mx.nd.random.normal(shape=(2,2)).asnumpy()) [[ 1.36481571 -0.62203991] [-1.4962182 -0.08511394]] - >>> print(mx.nd.random_normal(shape=(2,2)).asnumpy()) + >>> print(mx.nd.random.normal(shape=(2,2)).asnumpy()) [[ 1.09544981 -0.20014545] [-0.20808885 0.2527658 ]] >>> >>> mx.random.seed(128) - >>> print(mx.nd.random_normal(shape=(2,2)).asnumpy()) + >>> print(mx.nd.random.normal(shape=(2,2)).asnumpy()) [[ 0.47400656 -0.75213492] [ 0.20251541 0.95352972]] >>> mx.random.seed(128) - >>> print(mx.nd.random_normal(shape=(2,2)).asnumpy()) + >>> print(mx.nd.random.normal(shape=(2,2)).asnumpy()) [[ 0.47400656 -0.75213492] [ 0.20251541 0.95352972]] """ diff --git a/python/mxnet/symbol/__init__.py b/python/mxnet/symbol/__init__.py index d93a230f490d..2694b4e5d2fe 100644 --- a/python/mxnet/symbol/__init__.py +++ b/python/mxnet/symbol/__init__.py @@ -17,7 +17,10 @@ """Symbol API of MXNet.""" -from . import _internal, sparse, op +from . import _internal, contrib, linalg, random, sparse # pylint: disable=wildcard-import, redefined-builtin +from .op import * from .symbol import * -from ..ndarray import _GRAD_REQ_MAP +# pylint: enable=wildcard-import + +__all__ = op.__all__ + symbol.__all__ + ['contrib', 'linalg', 'random', 'sparse'] diff --git a/python/mxnet/symbol/_internal.py b/python/mxnet/symbol/_internal.py index cd6ae41c2a19..25b3b60d7fdb 100644 --- a/python/mxnet/symbol/_internal.py +++ b/python/mxnet/symbol/_internal.py @@ -16,3 +16,4 @@ # under the License. """Symbol namespace used to register internal functions.""" +__all__ = [] diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py new file mode 100644 index 000000000000..f21d144c4b63 --- /dev/null +++ b/python/mxnet/symbol/contrib.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Contrib NDArray API of MXNet.""" +__all__ = [] diff --git a/python/mxnet/symbol/linalg.py b/python/mxnet/symbol/linalg.py new file mode 100644 index 000000000000..0c8e7fd57a58 --- /dev/null +++ b/python/mxnet/symbol/linalg.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Linear Algebra NDArray API of MXNet.""" +__all__ = [] diff --git a/python/mxnet/symbol/op.py b/python/mxnet/symbol/op.py index 82884a5cc6a2..9353a73f2c4b 100644 --- a/python/mxnet/symbol/op.py +++ b/python/mxnet/symbol/op.py @@ -16,14 +16,15 @@ # under the License. """Register backend ops in mxnet.symbol namespace.""" +__all__ = [] import sys as _sys import os as _os import ctypes import numpy as _numpy # pylint: disable=unused-import -from mxnet.base import mx_uint, check_call, _LIB, py_str, OpHandle, c_str -from mxnet.symbol_doc import _build_doc +from ..base import mx_uint, check_call, _LIB, py_str +from ..symbol_doc import _build_doc # Use different version of SymbolBase # When possible, use cython to speedup part of computation. @@ -44,7 +45,7 @@ from .._ctypes.symbol import SymbolBase, _set_symbol_class from .._ctypes.symbol import _symbol_creator -from ..base import _Null +from ..base import _Null, _init_op_module from ..name import NameManager from ..attribute import AttrScope # pylint: enable=unused-import @@ -203,40 +204,4 @@ def %s(%s): return symbol_function -def _init_symbol_module(root_namespace): - """List and add all the atomic symbol functions to current module.""" - plist = ctypes.POINTER(ctypes.c_char_p)() - size = ctypes.c_uint() - - check_call(_LIB.MXListAllOpNames(ctypes.byref(size), - ctypes.byref(plist))) - op_names = [] - for i in range(size.value): - op_names.append(py_str(plist[i])) - - module_obj = _sys.modules["%s.symbol" % root_namespace] - module_sparse = _sys.modules["%s.symbol.sparse" % root_namespace] - module_internal = _sys.modules["%s.symbol._internal" % root_namespace] - module_contrib = _sys.modules["%s.contrib.symbol" % root_namespace] - for name in op_names: - hdl = OpHandle() - check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl))) - function = _make_atomic_symbol_function(hdl, name) - if function.__name__.startswith('_contrib_'): - function.__name__ = function.__name__[9:] - function.__module__ = 'mxnet.contrib.symbol' - setattr(module_contrib, function.__name__, function) - elif function.__name__.startswith('_'): - setattr(module_internal, function.__name__, function) - else: - setattr(module_obj, function.__name__, function) - - # register sparse ops under mxnet.symbol.sparse - if function.__name__.startswith('_sparse_'): - function.__name__ = function.__name__[8:] - function.__module__ = 'mxnet.symbol.sparse' - setattr(module_sparse, function.__name__, function) - - -# Initialize the atomic symbol in startups -_init_symbol_module("mxnet") +_init_op_module('mxnet', 'symbol', _make_atomic_symbol_function) diff --git a/python/mxnet/symbol/random.py b/python/mxnet/symbol/random.py new file mode 100644 index 000000000000..0ec4578ba3bf --- /dev/null +++ b/python/mxnet/symbol/random.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Random distribution generator NDArray API of MXNet.""" +__all__ = [] diff --git a/python/mxnet/symbol/sparse.py b/python/mxnet/symbol/sparse.py index 1d94f2b85bc7..3bfcf4ad3f64 100644 --- a/python/mxnet/symbol/sparse.py +++ b/python/mxnet/symbol/sparse.py @@ -16,3 +16,4 @@ # under the License. """Sparse Symbol API of MXNet.""" +__all__ = [] diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index 0038840540c2..cc128b26440f 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -40,13 +40,13 @@ from ..ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID from ..ndarray import _ndarray_cls from ..executor import Executor -from . import _internal, reshape, transpose, zeros_like, ones_like, broadcast_axes, broadcast_to -from . import flatten, norm, rint, fix, floor, ceil, split, slice_axis, one_hot, pick, take -from . import trunc, expand_dims, flip, tile, repeat, pad, clip, sign -from . import nansum, prod, nanprod, mean, sort, topk, argsort, argmax, argmin -from . import sum, round, max, min, slice, abs # pylint: disable=redefined-builtin +from . import _internal +from . import op from .op import SymbolBase, _set_symbol_class, AttrScope, _Null # pylint: disable=unused-import +__all__ = ["Symbol", "var", "Variable", "Group", "load", "load_json", + "pow", "maximum", "minimum", "hypot", "zeros", "ones", "full", "arange"] + class Symbol(SymbolBase): """Symbol is symbolic graph of the mxnet.""" @@ -1497,7 +1497,7 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, stype_dict=None, shared_buffer[k] = v # create in_args, arg_grads, and aux_states for the current executor - arg_arrays = [_ndarray_cls(NDArrayHandle(in_arg_handles[i])) \ + arg_arrays = [_ndarray_cls(NDArrayHandle(in_arg_handles[i])) for i in range(num_in_args.value)] grad_arrays = [_ndarray_cls(NDArrayHandle(arg_grad_handles[i])) if arg_grad_handles[i] is not None @@ -1742,7 +1742,7 @@ def reshape(self, *args, **kwargs): The arguments are the same as for :py:func:`reshape`, with this array as data. """ - return reshape(self, *args, **kwargs) + return op.reshape(self, *args, **kwargs) def astype(self, *args, **kwargs): """Convenience fluent method for :py:func:`cast`. @@ -1750,7 +1750,7 @@ def astype(self, *args, **kwargs): The arguments are the same as for :py:func:`cast`, with this array as data. """ - return cast(self, *args, **kwargs) + return op.cast(self, *args, **kwargs) def zeros_like(self, *args, **kwargs): """Convenience fluent method for :py:func:`zeros_like`. @@ -1758,7 +1758,7 @@ def zeros_like(self, *args, **kwargs): The arguments are the same as for :py:func:`zeros_like`, with this array as data. """ - return zeros_like(self, *args, **kwargs) + return op.zeros_like(self, *args, **kwargs) def ones_like(self, *args, **kwargs): """Convenience fluent method for :py:func:`ones_like`. @@ -1766,7 +1766,7 @@ def ones_like(self, *args, **kwargs): The arguments are the same as for :py:func:`ones_like`, with this array as data. """ - return ones_like(self, *args, **kwargs) + return op.ones_like(self, *args, **kwargs) def broadcast_axes(self, *args, **kwargs): """Convenience fluent method for :py:func:`broadcast_axes`. @@ -1774,7 +1774,7 @@ def broadcast_axes(self, *args, **kwargs): The arguments are the same as for :py:func:`broadcast_axes`, with this array as data. """ - return broadcast_axes(self, *args, **kwargs) + return op.broadcast_axes(self, *args, **kwargs) def repeat(self, *args, **kwargs): """Convenience fluent method for :py:func:`repeat`. @@ -1782,7 +1782,7 @@ def repeat(self, *args, **kwargs): The arguments are the same as for :py:func:`repeat`, with this array as data. """ - return repeat(self, *args, **kwargs) + return op.repeat(self, *args, **kwargs) def pad(self, *args, **kwargs): """Convenience fluent method for :py:func:`pad`. @@ -1790,7 +1790,7 @@ def pad(self, *args, **kwargs): The arguments are the same as for :py:func:`pad`, with this array as data. """ - return pad(self, *args, **kwargs) + return op.pad(self, *args, **kwargs) def swapaxes(self, *args, **kwargs): """Convenience fluent method for :py:func:`swapaxes`. @@ -1798,7 +1798,7 @@ def swapaxes(self, *args, **kwargs): The arguments are the same as for :py:func:`swapaxes`, with this array as data. """ - return swapaxes(self, *args, **kwargs) + return op.swapaxes(self, *args, **kwargs) def split(self, *args, **kwargs): """Convenience fluent method for :py:func:`split`. @@ -1806,7 +1806,7 @@ def split(self, *args, **kwargs): The arguments are the same as for :py:func:`split`, with this array as data. """ - return split(self, *args, **kwargs) + return op.split(self, *args, **kwargs) def slice(self, *args, **kwargs): """Convenience fluent method for :py:func:`slice`. @@ -1814,7 +1814,7 @@ def slice(self, *args, **kwargs): The arguments are the same as for :py:func:`slice`, with this array as data. """ - return slice(self, *args, **kwargs) + return op.slice(self, *args, **kwargs) def slice_axis(self, *args, **kwargs): """Convenience fluent method for :py:func:`slice_axis`. @@ -1822,7 +1822,7 @@ def slice_axis(self, *args, **kwargs): The arguments are the same as for :py:func:`slice_axis`, with this array as data. """ - return slice_axis(self, *args, **kwargs) + return op.slice_axis(self, *args, **kwargs) def take(self, *args, **kwargs): """Convenience fluent method for :py:func:`take`. @@ -1830,7 +1830,7 @@ def take(self, *args, **kwargs): The arguments are the same as for :py:func:`take`, with this array as data. """ - return take(self, *args, **kwargs) + return op.take(self, *args, **kwargs) def one_hot(self, *args, **kwargs): """Convenience fluent method for :py:func:`one_hot`. @@ -1838,7 +1838,7 @@ def one_hot(self, *args, **kwargs): The arguments are the same as for :py:func:`one_hot`, with this array as data. """ - return one_hot(self, *args, **kwargs) + return op.one_hot(self, *args, **kwargs) def pick(self, *args, **kwargs): """Convenience fluent method for :py:func:`pick`. @@ -1846,7 +1846,7 @@ def pick(self, *args, **kwargs): The arguments are the same as for :py:func:`pick`, with this array as data. """ - return pick(self, *args, **kwargs) + return op.pick(self, *args, **kwargs) def sort(self, *args, **kwargs): """Convenience fluent method for :py:func:`sort`. @@ -1854,7 +1854,7 @@ def sort(self, *args, **kwargs): The arguments are the same as for :py:func:`sort`, with this array as data. """ - return sort(self, *args, **kwargs) + return op.sort(self, *args, **kwargs) def topk(self, *args, **kwargs): """Convenience fluent method for :py:func:`topk`. @@ -1862,7 +1862,7 @@ def topk(self, *args, **kwargs): The arguments are the same as for :py:func:`topk`, with this array as data. """ - return topk(self, *args, **kwargs) + return op.topk(self, *args, **kwargs) def argsort(self, *args, **kwargs): """Convenience fluent method for :py:func:`argsort`. @@ -1870,7 +1870,7 @@ def argsort(self, *args, **kwargs): The arguments are the same as for :py:func:`argsort`, with this array as data. """ - return argsort(self, *args, **kwargs) + return op.argsort(self, *args, **kwargs) def argmax(self, *args, **kwargs): """Convenience fluent method for :py:func:`argmax`. @@ -1878,7 +1878,7 @@ def argmax(self, *args, **kwargs): The arguments are the same as for :py:func:`argmax`, with this array as data. """ - return argmax(self, *args, **kwargs) + return op.argmax(self, *args, **kwargs) def argmin(self, *args, **kwargs): """Convenience fluent method for :py:func:`argmin`. @@ -1886,7 +1886,7 @@ def argmin(self, *args, **kwargs): The arguments are the same as for :py:func:`argmin`, with this array as data. """ - return argmin(self, *args, **kwargs) + return op.argmin(self, *args, **kwargs) def clip(self, *args, **kwargs): """Convenience fluent method for :py:func:`clip`. @@ -1894,7 +1894,7 @@ def clip(self, *args, **kwargs): The arguments are the same as for :py:func:`clip`, with this array as data. """ - return clip(self, *args, **kwargs) + return op.clip(self, *args, **kwargs) def abs(self, *args, **kwargs): """Convenience fluent method for :py:func:`abs`. @@ -1902,7 +1902,7 @@ def abs(self, *args, **kwargs): The arguments are the same as for :py:func:`abs`, with this array as data. """ - return abs(self, *args, **kwargs) + return op.abs(self, *args, **kwargs) def sign(self, *args, **kwargs): """Convenience fluent method for :py:func:`sign`. @@ -1910,7 +1910,7 @@ def sign(self, *args, **kwargs): The arguments are the same as for :py:func:`sign`, with this array as data. """ - return sign(self, *args, **kwargs) + return op.sign(self, *args, **kwargs) def flatten(self, *args, **kwargs): """Convenience fluent method for :py:func:`flatten`. @@ -1918,7 +1918,7 @@ def flatten(self, *args, **kwargs): The arguments are the same as for :py:func:`flatten`, with this array as data. """ - return flatten(self, *args, **kwargs) + return op.flatten(self, *args, **kwargs) def expand_dims(self, *args, **kwargs): """Convenience fluent method for :py:func:`expand_dims`. @@ -1926,7 +1926,7 @@ def expand_dims(self, *args, **kwargs): The arguments are the same as for :py:func:`expand_dims`, with this array as data. """ - return expand_dims(self, *args, **kwargs) + return op.expand_dims(self, *args, **kwargs) def broadcast_to(self, *args, **kwargs): """Convenience fluent method for :py:func:`broadcast_to`. @@ -1934,7 +1934,7 @@ def broadcast_to(self, *args, **kwargs): The arguments are the same as for :py:func:`broadcast_to`, with this array as data. """ - return broadcast_to(self, *args, **kwargs) + return op.broadcast_to(self, *args, **kwargs) def tile(self, *args, **kwargs): """Convenience fluent method for :py:func:`tile`. @@ -1942,7 +1942,7 @@ def tile(self, *args, **kwargs): The arguments are the same as for :py:func:`tile`, with this array as data. """ - return tile(self, *args, **kwargs) + return op.tile(self, *args, **kwargs) def transpose(self, *args, **kwargs): """Convenience fluent method for :py:func:`transpose`. @@ -1950,7 +1950,7 @@ def transpose(self, *args, **kwargs): The arguments are the same as for :py:func:`transpose`, with this array as data. """ - return transpose(self, *args, **kwargs) + return op.transpose(self, *args, **kwargs) def flip(self, *args, **kwargs): """Convenience fluent method for :py:func:`flip`. @@ -1958,7 +1958,7 @@ def flip(self, *args, **kwargs): The arguments are the same as for :py:func:`flip`, with this array as data. """ - return flip(self, *args, **kwargs) + return op.flip(self, *args, **kwargs) def sum(self, *args, **kwargs): """Convenience fluent method for :py:func:`sum`. @@ -1966,7 +1966,7 @@ def sum(self, *args, **kwargs): The arguments are the same as for :py:func:`sum`, with this array as data. """ - return sum(self, *args, **kwargs) + return op.sum(self, *args, **kwargs) def nansum(self, *args, **kwargs): """Convenience fluent method for :py:func:`nansum`. @@ -1974,7 +1974,7 @@ def nansum(self, *args, **kwargs): The arguments are the same as for :py:func:`nansum`, with this array as data. """ - return nansum(self, *args, **kwargs) + return op.nansum(self, *args, **kwargs) def prod(self, *args, **kwargs): """Convenience fluent method for :py:func:`prod`. @@ -1982,7 +1982,7 @@ def prod(self, *args, **kwargs): The arguments are the same as for :py:func:`prod`, with this array as data. """ - return prod(self, *args, **kwargs) + return op.prod(self, *args, **kwargs) def nanprod(self, *args, **kwargs): """Convenience fluent method for :py:func:`nanprod`. @@ -1990,7 +1990,7 @@ def nanprod(self, *args, **kwargs): The arguments are the same as for :py:func:`nanprod`, with this array as data. """ - return nanprod(self, *args, **kwargs) + return op.nanprod(self, *args, **kwargs) def mean(self, *args, **kwargs): """Convenience fluent method for :py:func:`mean`. @@ -1998,7 +1998,7 @@ def mean(self, *args, **kwargs): The arguments are the same as for :py:func:`mean`, with this array as data. """ - return mean(self, *args, **kwargs) + return op.mean(self, *args, **kwargs) def max(self, *args, **kwargs): """Convenience fluent method for :py:func:`max`. @@ -2006,7 +2006,7 @@ def max(self, *args, **kwargs): The arguments are the same as for :py:func:`max`, with this array as data. """ - return max(self, *args, **kwargs) + return op.max(self, *args, **kwargs) def min(self, *args, **kwargs): """Convenience fluent method for :py:func:`min`. @@ -2014,7 +2014,7 @@ def min(self, *args, **kwargs): The arguments are the same as for :py:func:`min`, with this array as data. """ - return min(self, *args, **kwargs) + return op.min(self, *args, **kwargs) def norm(self, *args, **kwargs): """Convenience fluent method for :py:func:`norm`. @@ -2022,7 +2022,7 @@ def norm(self, *args, **kwargs): The arguments are the same as for :py:func:`norm`, with this array as data. """ - return norm(self, *args, **kwargs) + return op.norm(self, *args, **kwargs) def round(self, *args, **kwargs): """Convenience fluent method for :py:func:`round`. @@ -2030,7 +2030,7 @@ def round(self, *args, **kwargs): The arguments are the same as for :py:func:`round`, with this array as data. """ - return round(self, *args, **kwargs) + return op.round(self, *args, **kwargs) def rint(self, *args, **kwargs): """Convenience fluent method for :py:func:`rint`. @@ -2038,7 +2038,7 @@ def rint(self, *args, **kwargs): The arguments are the same as for :py:func:`rint`, with this array as data. """ - return rint(self, *args, **kwargs) + return op.rint(self, *args, **kwargs) def fix(self, *args, **kwargs): """Convenience fluent method for :py:func:`fix`. @@ -2046,7 +2046,7 @@ def fix(self, *args, **kwargs): The arguments are the same as for :py:func:`fix`, with this array as data. """ - return fix(self, *args, **kwargs) + return op.fix(self, *args, **kwargs) def floor(self, *args, **kwargs): """Convenience fluent method for :py:func:`floor`. @@ -2054,7 +2054,7 @@ def floor(self, *args, **kwargs): The arguments are the same as for :py:func:`floor`, with this array as data. """ - return floor(self, *args, **kwargs) + return op.floor(self, *args, **kwargs) def ceil(self, *args, **kwargs): """Convenience fluent method for :py:func:`ceil`. @@ -2062,7 +2062,7 @@ def ceil(self, *args, **kwargs): The arguments are the same as for :py:func:`ceil`, with this array as data. """ - return ceil(self, *args, **kwargs) + return op.ceil(self, *args, **kwargs) def trunc(self, *args, **kwargs): """Convenience fluent method for :py:func:`trunc`. @@ -2070,7 +2070,7 @@ def trunc(self, *args, **kwargs): The arguments are the same as for :py:func:`trunc`, with this array as data. """ - return trunc(self, *args, **kwargs) + return op.trunc(self, *args, **kwargs) def wait_to_read(self): raise NotImplementedForSymbol(self.wait_to_read, None) diff --git a/src/operator/random/sample_op.cc b/src/operator/random/sample_op.cc index 363163cbc697..ea6fdd54b925 100644 --- a/src/operator/random/sample_op.cc +++ b/src/operator/random/sample_op.cc @@ -45,9 +45,10 @@ DMLC_REGISTER_PARAMETER(SampleGenNegBinomialParam); .add_arguments(ParamType::__FIELDS__()) // Add "uniform" alias for backward compatibility -MXNET_OPERATOR_REGISTER_SAMPLE(random_uniform, SampleUniformParam) +MXNET_OPERATOR_REGISTER_SAMPLE(_random_uniform, SampleUniformParam) .add_alias("uniform") .add_alias("_sample_uniform") +.add_alias("random_uniform") .describe(R"code(Draw random samples from a uniform distribution. .. note:: The existing alias ``uniform`` is deprecated. @@ -57,17 +58,18 @@ Samples are uniformly distributed over the half-open interval *[low, high)* Example:: - random_uniform(low=0, high=1, shape=(2,2)) = [[ 0.60276335, 0.85794562], - [ 0.54488319, 0.84725171]] + uniform(low=0, high=1, shape=(2,2)) = [[ 0.60276335, 0.85794562], + [ 0.54488319, 0.84725171]] )code" ADD_FILELINE) .set_attr("FCompute", SampleUniform_) .set_attr("FComputeEx", SampleUniformEx_); // Add "normal" alias for backward compatibility -MXNET_OPERATOR_REGISTER_SAMPLE(random_normal, SampleNormalParam) +MXNET_OPERATOR_REGISTER_SAMPLE(_random_normal, SampleNormalParam) .add_alias("normal") .add_alias("_sample_normal") +.add_alias("random_normal") .describe(R"code(Draw random samples from a normal (Gaussian) distribution. .. note:: The existing alias ``normal`` is deprecated. @@ -76,41 +78,44 @@ Samples are distributed according to a normal distribution parametrized by *loc* Example:: - random_normal(loc=0, scale=1, shape=(2,2)) = [[ 1.89171135, -1.16881478], - [-1.23474145, 1.55807114]] + normal(loc=0, scale=1, shape=(2,2)) = [[ 1.89171135, -1.16881478], + [-1.23474145, 1.55807114]] )code" ADD_FILELINE) .set_attr("FCompute", SampleNormal_) .set_attr("FComputeEx", SampleNormalEx_); -MXNET_OPERATOR_REGISTER_SAMPLE(random_gamma, SampleGammaParam) +MXNET_OPERATOR_REGISTER_SAMPLE(_random_gamma, SampleGammaParam) .add_alias("_sample_gamma") +.add_alias("random_gamma") .describe(R"code(Draw random samples from a gamma distribution. Samples are distributed according to a gamma distribution parametrized by *alpha* (shape) and *beta* (scale). Example:: - random_gamma(alpha=9, beta=0.5, shape=(2,2)) = [[ 7.10486984, 3.37695289], - [ 3.91697288, 3.65933681]] + gamma(alpha=9, beta=0.5, shape=(2,2)) = [[ 7.10486984, 3.37695289], + [ 3.91697288, 3.65933681]] )code" ADD_FILELINE) .set_attr("FCompute", SampleGamma_) .set_attr("FComputeEx", SampleGammaEx_); -MXNET_OPERATOR_REGISTER_SAMPLE(random_exponential, SampleExponentialParam) +MXNET_OPERATOR_REGISTER_SAMPLE(_random_exponential, SampleExponentialParam) .add_alias("_sample_exponential") +.add_alias("random_exponential") .describe(R"code(Draw random samples from an exponential distribution. Samples are distributed according to an exponential distribution parametrized by *lambda* (rate). Example:: - random_exponential(lam=4, shape=(2,2)) = [[ 0.0097189 , 0.08999364], - [ 0.04146638, 0.31715935]] + exponential(lam=4, shape=(2,2)) = [[ 0.0097189 , 0.08999364], + [ 0.04146638, 0.31715935]] )code" ADD_FILELINE) .set_attr("FCompute", SampleExponential_); -MXNET_OPERATOR_REGISTER_SAMPLE(random_poisson, SamplePoissonParam) +MXNET_OPERATOR_REGISTER_SAMPLE(_random_poisson, SamplePoissonParam) .add_alias("_sample_poisson") +.add_alias("random_poisson") .describe(R"code(Draw random samples from a Poisson distribution. Samples are distributed according to a Poisson distribution parametrized by *lambda* (rate). @@ -118,13 +123,14 @@ Samples will always be returned as a floating point data type. Example:: - random_poisson(lam=4, shape=(2,2)) = [[ 5., 2.], - [ 4., 6.]] + poisson(lam=4, shape=(2,2)) = [[ 5., 2.], + [ 4., 6.]] )code" ADD_FILELINE) .set_attr("FCompute", SamplePoisson_); -MXNET_OPERATOR_REGISTER_SAMPLE(random_negative_binomial, SampleNegBinomialParam) +MXNET_OPERATOR_REGISTER_SAMPLE(_random_negative_binomial, SampleNegBinomialParam) .add_alias("_sample_negbinomial") +.add_alias("random_negative_binomial") .describe(R"code(Draw random samples from a negative binomial distribution. Samples are distributed according to a negative binomial distribution parametrized by @@ -133,13 +139,14 @@ Samples will always be returned as a floating point data type. Example:: - random_negative_binomial(k=3, p=0.4, shape=(2,2)) = [[ 4., 7.], - [ 2., 5.]] + negative_binomial(k=3, p=0.4, shape=(2,2)) = [[ 4., 7.], + [ 2., 5.]] )code" ADD_FILELINE) .set_attr("FCompute", SampleNegBinomial_); -MXNET_OPERATOR_REGISTER_SAMPLE(random_generalized_negative_binomial, SampleGenNegBinomialParam) +MXNET_OPERATOR_REGISTER_SAMPLE(_random_generalized_negative_binomial, SampleGenNegBinomialParam) .add_alias("_sample_gennegbinomial") +.add_alias("random_generalized_negative_binomial") .describe(R"code(Draw random samples from a generalized negative binomial distribution. Samples are distributed according to a generalized negative binomial distribution parametrized by @@ -149,8 +156,8 @@ Samples will always be returned as a floating point data type. Example:: - random_generalized_negative_binomial(mu=2.0, alpha=0.3, shape=(2,2)) = [[ 2., 1.], - [ 6., 4.]] + generalized_negative_binomial(mu=2.0, alpha=0.3, shape=(2,2)) = [[ 2., 1.], + [ 6., 4.]] )code" ADD_FILELINE) .set_attr("FCompute", SampleGenNegBinomial_); diff --git a/src/operator/random/sample_op.cu b/src/operator/random/sample_op.cu index 7bdb9faf334e..a26413d51cb8 100644 --- a/src/operator/random/sample_op.cu +++ b/src/operator/random/sample_op.cu @@ -103,11 +103,11 @@ void SampleNormal_(const nnvm::NodeAttrs& attrs, SampleNormalDnsImpl(attrs, ctx, req[0], &out); } -NNVM_REGISTER_OP(random_uniform) +NNVM_REGISTER_OP(_random_uniform) .set_attr("FCompute", SampleUniform_) .set_attr("FComputeEx", SampleUniformEx_); -NNVM_REGISTER_OP(random_normal) +NNVM_REGISTER_OP(_random_normal) .set_attr("FCompute", SampleNormal_) .set_attr("FComputeEx", SampleNormalEx_); diff --git a/src/operator/tensor/la_op.cc b/src/operator/tensor/la_op.cc index 70d4f9b766ad..9b94603b5fdc 100644 --- a/src/operator/tensor/la_op.cc +++ b/src/operator/tensor/la_op.cc @@ -31,7 +31,8 @@ DMLC_REGISTER_PARAMETER(LaMatrixMacParam); DMLC_REGISTER_PARAMETER(LaMatrixMultParam); DMLC_REGISTER_PARAMETER(LaTriangMatrixMultParam); -NNVM_REGISTER_OP(linalg_gemm) +NNVM_REGISTER_OP(_linalg_gemm) +.add_alias("linalg_gemm") .describe(R"code(Performs general matrix multiplication and accumulation. Input are three tensors *A*, *B*, *C* each of dimension *n >= 2* and each having the same shape on the leading *n-2* dimensions. For every *n-2* dimensional index *i* let @@ -54,14 +55,14 @@ Examples:: A = [[1.0, 1.0], [1.0, 1.0]] B = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]] C = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]] - linalg_gemm(A, B, C, transpose_b = 1, alpha = 2.0 , beta = 10.0) + gemm(A, B, C, transpose_b = 1, alpha = 2.0 , beta = 10.0) = [[14.0, 14.0, 14.0], [14.0, 14.0, 14.0]] // Batch matrix multiply-add A = [[[1.0, 1.0]], [[0.1, 0.1]]] B = [[[1.0, 1.0]], [[0.1, 0.1]]] C = [[[10.0]], [[0.01]]] - linalg_gemm(A, B, C, transpose_b = 1, alpha = 2.0 , beta = 10.0) + gemm(A, B, C, transpose_b = 1, alpha = 2.0 , beta = 10.0) = [[[104.0]], [[0.14]]] )code" ADD_FILELINE) .set_num_inputs(3) @@ -91,7 +92,8 @@ NNVM_REGISTER_OP(_backward_linalg_gemm) .set_attr("TIsBackward", true) .set_attr("FCompute", LaOpBackward); -NNVM_REGISTER_OP(linalg_gemm2) +NNVM_REGISTER_OP(_linalg_gemm2) +.add_alias("linalg_gemm2") .describe(R"code(Performs general matrix multiplication. Input are two tensors *A*, *B* each of dimension *n >= 2* and each having the same shape on the leading *n-2* dimensions. For every *n-2* dimensional index *i* let @@ -113,13 +115,13 @@ Examples:: // Single matrix multiply A = [[1.0, 1.0], [1.0, 1.0]] B = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]] - linalg_gemm2(A, B, transpose_b = 1, alpha = 2.0) + gemm2(A, B, transpose_b = 1, alpha = 2.0) = [[4.0, 4.0, 4.0], [4.0, 4.0, 4.0]] // Batch matrix multiply A = [[[1.0, 1.0]], [[0.1, 0.1]]] B = [[[1.0, 1.0]], [[0.1, 0.1]]] - linalg_gemm2(A, B, transpose_b = 1, alpha = 2.0 ) + gemm2(A, B, transpose_b = 1, alpha = 2.0 ) = [[[4.0]], [[0.04 ]]] )code" ADD_FILELINE) .set_num_inputs(2) @@ -146,7 +148,8 @@ NNVM_REGISTER_OP(_backward_linalg_gemm2) .set_attr("TIsBackward", true) .set_attr("FCompute", LaOpBackward); -NNVM_REGISTER_OP(linalg_potrf) +NNVM_REGISTER_OP(_linalg_potrf) +.add_alias("linalg_potrf") .describe(R"code(Performs Cholesky factorization of a symmetric positive-definite matrix. Input is a tensor *A* of dimension *n >= 2*. For every *n-2* dimensional index *i* let *A*\ :sub:`i`\ be the matrix given by the last *2* dimensions. @@ -169,11 +172,11 @@ Examples:: // Single matrix factorization A = [[4.0, 1.0], [1.0, 4.25]] - linalg_potrf(A) = [[2.0, 0], [0.5, 2.0]] + potrf(A) = [[2.0, 0], [0.5, 2.0]] // Batch matrix factorization A = [[[4.0, 1.0], [1.0, 4.25]], [[16.0, 4.0], [4.0, 17.0]]] - linalg_potrf(A) = [[[2.0, 0], [0.5, 2.0]], [[4.0, 0], [1.0, 4.0]]] + potrf(A) = [[[2.0, 0], [0.5, 2.0]], [[4.0, 0], [1.0, 4.0]]] )code" ADD_FILELINE) .set_num_inputs(1) .set_num_outputs(1) @@ -198,7 +201,8 @@ NNVM_REGISTER_OP(_backward_linalg_potrf) .set_attr("FCompute", LaOpBackward); -NNVM_REGISTER_OP(linalg_potri) +NNVM_REGISTER_OP(_linalg_potri) +.add_alias("linalg_potri") .describe(R"code(Performs matrix inversion from a Cholesky factorization. Input is a tensor *A* of dimension *n >= 2*. For every *n-2* dimensional index *i* let *A*\ :sub:`i`\ be the matrix given by the last *2* dimensions. @@ -220,11 +224,11 @@ Examples:: // Single matrix inverse A = [[2.0, 0], [0.5, 2.0]] - linalg_potri(A) = [[0.26563, -0.0625], [-0.0625, 0.25]] + potri(A) = [[0.26563, -0.0625], [-0.0625, 0.25]] // Batch matrix inverse A = [[[2.0, 0], [0.5, 2.0]], [[4.0, 0], [1.0, 4.0]]] - linalg_potri(A) = [[[0.26563, -0.0625], [-0.0625, 0.25]], + potri(A) = [[[0.26563, -0.0625], [-0.0625, 0.25]], [[0.06641, -0.01562], [-0.01562, 0,0625]]] )code" ADD_FILELINE) .set_num_inputs(1) @@ -247,7 +251,8 @@ NNVM_REGISTER_OP(_backward_linalg_potri) .set_attr("TIsBackward", true) .set_attr("FCompute", LaOpBackward); -NNVM_REGISTER_OP(linalg_trmm) +NNVM_REGISTER_OP(_linalg_trmm) +.add_alias("linalg_trmm") .describe(R"code(Performs multiplication with a triangular matrix. Input are two tensors *A*, *B* each of dimension *n >= 2* and each having the same shape on the leading *n-2* dimensions. For every *n-2* dimensional index *i* let @@ -275,12 +280,12 @@ Examples:: // Single matrix multiply A = [[1.0, 0], [1.0, 1.0]] B = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]] - linalg_trmm(A, B, alpha = 2.0) = [[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]] + trmm(A, B, alpha = 2.0) = [[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]] // Batch matrix multiply A = [[[1.0, 0], [1.0, 1.0]], [[1.0, 0], [1.0, 1.0]]] B = [[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], [[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]] - linalg_trmm(A, B, alpha = 2.0 ) = [[[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]], + trmm(A, B, alpha = 2.0 ) = [[[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]], [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]] )code" ADD_FILELINE) @@ -310,7 +315,8 @@ NNVM_REGISTER_OP(_backward_linalg_trmm) .set_attr("TIsBackward", true) .set_attr("FCompute", LaOpBackward); -NNVM_REGISTER_OP(linalg_trsm) +NNVM_REGISTER_OP(_linalg_trsm) +.add_alias("linalg_trsm") .describe(R"code(Solves matrix equations involving a triangular matrix. Input are two tensors *A*, *B* each of dimension *n >= 2* and each having the same shape on the leading *n-2* dimensions. For every *n-2* dimensional index *i* let @@ -338,13 +344,13 @@ Examples:: // Single matrix solve A = [[1.0, 0], [1.0, 1.0]] B = [[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]] - linalg_trsm(A, B, alpha = 0.5) = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]] + trsm(A, B, alpha = 0.5) = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]] // Batch matrix solve A = [[[1.0, 0], [1.0, 1.0]], [[1.0, 0], [1.0, 1.0]]] B = [[[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]], [[4.0, 4.0, 4.0], [8.0, 8.0, 8.0]]] - linalg_trsm(A, B, alpha = 0.5 ) = [[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + trsm(A, B, alpha = 0.5 ) = [[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], [[2.0, 2.0, 2.0 ], [2.0, 2.0, 2.0]]] )code" ADD_FILELINE) .set_num_inputs(2) @@ -373,7 +379,8 @@ NNVM_REGISTER_OP(_backward_linalg_trsm) .set_attr("TIsBackward", true) .set_attr("FCompute", LaOpBackward); -NNVM_REGISTER_OP(linalg_sumlogdiag) +NNVM_REGISTER_OP(_linalg_sumlogdiag) +.add_alias("linalg_sumlogdiag") .describe(R"code(Computes the sum of the logarithms of all diagonal elements in a matrix. Input is a tensor *A* of dimension *n >= 2*. For every *n-2* dimensional index *i* let *A*\ :sub:`i`\ be the matrix given by the last *2* dimensions. @@ -389,11 +396,11 @@ Examples:: // Single matrix reduction A = [[1.0, 1.0], [1.0, 7.0]] - linalg_sumlogdiag(A) = [1.9459] + sumlogdiag(A) = [1.9459] // Batch matrix reduction A = [[[1.0, 1.0], [1.0, 7.0]], [[3.0, 0], [0, 17.0]]] - linalg_sumlogdiag(A) = [1.9459, 3.9318] + sumlogdiag(A) = [1.9459, 3.9318] )code" ADD_FILELINE) .set_num_inputs(1) .set_num_outputs(1) diff --git a/src/operator/tensor/la_op.cu b/src/operator/tensor/la_op.cu index a89d98fd7f82..e5d5b272c08a 100644 --- a/src/operator/tensor/la_op.cu +++ b/src/operator/tensor/la_op.cu @@ -27,37 +27,37 @@ namespace mxnet { namespace op { -NNVM_REGISTER_OP(linalg_gemm) +NNVM_REGISTER_OP(_linalg_gemm) .set_attr("FCompute", LaOpForward); NNVM_REGISTER_OP(_backward_linalg_gemm) .set_attr("FCompute", LaOpBackward); -NNVM_REGISTER_OP(linalg_gemm2) +NNVM_REGISTER_OP(_linalg_gemm2) .set_attr("FCompute", LaOpForward); NNVM_REGISTER_OP(_backward_linalg_gemm2) .set_attr("FCompute", LaOpBackward); -NNVM_REGISTER_OP(linalg_trmm) +NNVM_REGISTER_OP(_linalg_trmm) .set_attr("FCompute", LaOpForward); NNVM_REGISTER_OP(_backward_linalg_trmm) .set_attr("FCompute", LaOpBackward); -NNVM_REGISTER_OP(linalg_trsm) +NNVM_REGISTER_OP(_linalg_trsm) .set_attr("FCompute", LaOpForward); NNVM_REGISTER_OP(_backward_linalg_trsm) .set_attr("FCompute", LaOpBackward); -NNVM_REGISTER_OP(linalg_sumlogdiag) +NNVM_REGISTER_OP(_linalg_sumlogdiag) .set_attr("FCompute", LaOpForward); NNVM_REGISTER_OP(_backward_linalg_sumlogdiag) .set_attr("FCompute", LaOpBackward); -NNVM_REGISTER_OP(linalg_potri) +NNVM_REGISTER_OP(_linalg_potri) .set_attr("FCompute", LaOpForward); NNVM_REGISTER_OP(_backward_linalg_potri) @@ -65,7 +65,7 @@ NNVM_REGISTER_OP(_backward_linalg_potri) #if MXNET_USE_CUSOLVER == 1 -NNVM_REGISTER_OP(linalg_potrf) +NNVM_REGISTER_OP(_linalg_potrf) .set_attr("FCompute", LaOpForward); NNVM_REGISTER_OP(_backward_linalg_potrf) diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc index d409b9ec6056..dfc4a5aee253 100644 --- a/src/operator/tensor/matrix_op.cc +++ b/src/operator/tensor/matrix_op.cc @@ -247,6 +247,7 @@ will return a new array with shape ``(2,1,3,4)``. .add_arguments(ExpandDimParam::__FIELDS__()); NNVM_REGISTER_OP(slice) +.add_alias("_sparse_slice") .add_alias("crop") .describe(R"code(Slices a contiguous region of the array. diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 04426d4ec2e6..b2f620cb3b61 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -43,7 +43,7 @@ def check_countsketch(in_dim,out_dim,n): - sym = mx.contrib.sym.count_sketch(name='countsketch',out_dim = out_dim) + sym = mx.sym.contrib.count_sketch(name='countsketch',out_dim = out_dim) shape = [(n,in_dim), (1,in_dim),(1,in_dim)] #shape of input x, hash h and hash s arr = [mx.nd.empty(shape[i]) for i in range(3)] @@ -109,7 +109,7 @@ def check_ifft(shape): shape = tuple(lst) shape_old = shape shape = (shape[0],shape[1],shape[2],shape[3]*2) - sym = mx.contrib.sym.ifft(name='ifft', compute_size = 128) + sym = mx.sym.contrib.ifft(name='ifft', compute_size = 128) init = [np.random.normal(size=shape, scale=1.0)] arr_grad = [mx.nd.empty(shape)] ctx_list = [{'ctx': mx.gpu(0),'ifft_data': shape, 'type_dict': {'ifft_data': np.float32}}] @@ -175,7 +175,7 @@ def test_ifft(): check_ifft(shape) def check_fft(shape): - sym = mx.contrib.sym.fft(name='fft', compute_size = 128) + sym = mx.sym.contrib.fft(name='fft', compute_size = 128) if len(shape) == 2: if shape[1]%2 != 0: lst = list(shape) @@ -1143,7 +1143,7 @@ def test_psroipooling_with_type(): 'psroipool_rois': np.array([[0, 10, 22, 161, 173], [0, 20, 15, 154, 160]])} # plain psroipooling - sym = mx.contrib.sym.PSROIPooling(spatial_scale=0.0625, output_dim=2, pooled_size=3, name='psroipool') + sym = mx.sym.contrib.PSROIPooling(spatial_scale=0.0625, output_dim=2, pooled_size=3, name='psroipool') ctx_list = [{'ctx': mx.gpu(0), 'psroipool_data': (1, 18, 14, 14), 'psroipool_rois': (2, 5), @@ -1167,7 +1167,7 @@ def test_deformable_psroipooling_with_type(): 'deformable_psroipool_rois': np.array([[0, 10, 22, 161, 173], [0, 20, 15, 154, 160]])} # deformable psroipooling - sym = mx.contrib.sym.DeformablePSROIPooling(spatial_scale=0.0625, sample_per_part=4, group_size=3, pooled_size=3, + sym = mx.sym.contrib.DeformablePSROIPooling(spatial_scale=0.0625, sample_per_part=4, group_size=3, pooled_size=3, output_dim=2, trans_std=0.1, no_trans=False, name='deformable_psroipool') ctx_list = [{'ctx': mx.gpu(0), @@ -1196,7 +1196,7 @@ def test_deformable_psroipooling_with_type(): def test_deformable_convolution_with_type(): np.random.seed(1234) - sym = mx.contrib.sym.DeformableConvolution(num_filter=3, kernel=(3,3), name='deformable_conv') + sym = mx.sym.contrib.DeformableConvolution(num_filter=3, kernel=(3,3), name='deformable_conv') # since atomicAdd does not support fp16 (which deformable conv uses in backward), we do not test fp16 here ctx_list = [{'ctx': mx.gpu(0), 'deformable_conv_data': (2, 2, 10, 10), @@ -1241,7 +1241,7 @@ def test_deformable_convolution_options(): # 'deformable_offset': (2, 18, 7, 7), # 'type_dict': {'deformable_conv_data': np.float16, 'deformable_offset': np.float16}}, ] - sym = mx.contrib.sym.DeformableConvolution(num_filter=3, kernel=(3,3), pad=(1,1), name='deformable_conv') + sym = mx.sym.contrib.DeformableConvolution(num_filter=3, kernel=(3,3), pad=(1,1), name='deformable_conv') check_consistency(sym, ctx_list) # Stride > 1 @@ -1259,7 +1259,7 @@ def test_deformable_convolution_options(): # 'deformable_conv_offset': (2, 18, 3, 3), # 'type_dict': {'deformable_conv_data': np.float16, 'deformable_offset': np.float16}}, ] - sym = mx.contrib.sym.DeformableConvolution(num_filter=3, kernel=(3,3), stride=(2,2), name='deformable_conv') + sym = mx.sym.contrib.DeformableConvolution(num_filter=3, kernel=(3,3), stride=(2,2), name='deformable_conv') check_consistency(sym, ctx_list) # Dilate > 1 @@ -1277,7 +1277,7 @@ def test_deformable_convolution_options(): # 'deformable_conv_offset': (2, 18, 3, 3), # 'type_dict': {'deformable_conv_data': np.float16, 'deformable_offset': np.float16}}, ] - sym = mx.contrib.sym.DeformableConvolution(num_filter=3, kernel=(3,3), dilate=(2,2), name='deformable_conv') + sym = mx.sym.contrib.DeformableConvolution(num_filter=3, kernel=(3,3), dilate=(2,2), name='deformable_conv') check_consistency(sym, ctx_list) # Deformable group > 1 @@ -1295,7 +1295,7 @@ def test_deformable_convolution_options(): # 'deformable_conv_offset': (2, 36, 5, 5), # 'type_dict': {'deformable_conv_data': np.float16, 'deformable_offset': np.float16}}, ] - sym = mx.contrib.sym.DeformableConvolution(num_filter=4, kernel=(3,3), num_deformable_group=2, + sym = mx.sym.contrib.DeformableConvolution(num_filter=4, kernel=(3,3), num_deformable_group=2, name='deformable_conv') def test_residual_fused(): diff --git a/tests/python/unittest/test_autograd.py b/tests/python/unittest/test_autograd.py index 37bb5626f765..9e1eb66e711e 100644 --- a/tests/python/unittest/test_autograd.py +++ b/tests/python/unittest/test_autograd.py @@ -344,9 +344,9 @@ def backward(self, dm, dn): return dx, dy f = func() - x = mx.nd.random_uniform(shape=(10,)) + x = mx.nd.random.uniform(shape=(10,)) x.attach_grad() - y = mx.nd.random_uniform(shape=(10,)) + y = mx.nd.random.uniform(shape=(10,)) y.attach_grad() with record(): m, n = f(x, y) diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index 726213dd5455..f64e95b75806 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -323,7 +323,7 @@ def check_split_data(x, num_slice, batch_axis, **kwargs): def test_split_data(): - x = mx.nd.random_uniform(shape=(128, 33, 64)) + x = mx.nd.random.uniform(shape=(128, 33, 64)) check_split_data(x, 8, 0) check_split_data(x, 3, 1) diff --git a/tests/python/unittest/test_gluon_model_zoo.py b/tests/python/unittest/test_gluon_model_zoo.py index 6fbcf8b3dac8..756979d27aea 100644 --- a/tests/python/unittest/test_gluon_model_zoo.py +++ b/tests/python/unittest/test_gluon_model_zoo.py @@ -45,7 +45,7 @@ def test_concurrent(): def test_identity(): model = Identity() - x = mx.nd.random_uniform(shape=(128, 33, 64)) + x = mx.nd.random.uniform(shape=(128, 33, 64)) mx.test_utils.assert_almost_equal(model(x).asnumpy(), x.asnumpy()) @@ -68,7 +68,7 @@ def test_models(): print(model) if not test_pretrain: model.collect_params().initialize() - model(mx.nd.random_uniform(shape=data_shape)).wait_to_read() + model(mx.nd.random.uniform(shape=data_shape)).wait_to_read() if __name__ == '__main__': diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py index 9d8d76f5aa92..8a5fd90a1e61 100644 --- a/tests/python/unittest/test_module.py +++ b/tests/python/unittest/test_module.py @@ -587,8 +587,8 @@ def test_forward_reshape(): mod.init_optimizer(optimizer_params={'learning_rate': 0.01}) # Train with original data shapes - data_batch = mx.io.DataBatch(data=[mx.nd.random_uniform(0, 9, dshape1), - mx.nd.random_uniform(5, 15, dshape2)], + data_batch = mx.io.DataBatch(data=[mx.nd.random.uniform(0, 9, dshape1), + mx.nd.random.uniform(5, 15, dshape2)], label=[mx.nd.ones(lshape)]) mod.forward(data_batch) assert mod.get_outputs()[0].shape == tuple([lshape[0], num_class]) @@ -599,8 +599,8 @@ def test_forward_reshape(): dshape1 = (3, 3, 64, 64) dshape2 = (3, 3, 32, 32) lshape = (3,) - data_batch = mx.io.DataBatch(data=[mx.nd.random_uniform(0, 9, dshape1), - mx.nd.random_uniform(5, 15, dshape2)], + data_batch = mx.io.DataBatch(data=[mx.nd.random.uniform(0, 9, dshape1), + mx.nd.random.uniform(5, 15, dshape2)], label=[mx.nd.ones(lshape)]) mod.forward(data_batch) assert mod.get_outputs()[0].shape == tuple([lshape[0], num_class]) @@ -610,8 +610,8 @@ def test_forward_reshape(): dshape1 = (20, 3, 64, 64) dshape2 = (20, 3, 32, 32) lshape = (20,) - data_batch = mx.io.DataBatch(data=[mx.nd.random_uniform(3, 5, dshape1), - mx.nd.random_uniform(10, 25, dshape2)], + data_batch = mx.io.DataBatch(data=[mx.nd.random.uniform(3, 5, dshape1), + mx.nd.random.uniform(10, 25, dshape2)], label=[mx.nd.ones(lshape)]) mod.forward(data_batch) assert mod.get_outputs()[0].shape == tuple([lshape[0], num_class]) @@ -622,8 +622,8 @@ def test_forward_reshape(): dshape1 = (20, 3, 120, 120) dshape2 = (20, 3, 32, 64) lshape = (20,) - data_batch = mx.io.DataBatch(data=[mx.nd.random_uniform(0, 9, dshape1), - mx.nd.random_uniform(5, 15, dshape2)], + data_batch = mx.io.DataBatch(data=[mx.nd.random.uniform(0, 9, dshape1), + mx.nd.random.uniform(5, 15, dshape2)], label=[mx.nd.ones(lshape)]) mod.forward(data_batch) assert mod.get_outputs()[0].shape == tuple([lshape[0], num_class]) @@ -633,8 +633,8 @@ def test_forward_reshape(): dshape1 = (5, 3, 28, 40) dshape2 = (5, 3, 24, 16) lshape = (5,) - data_batch = mx.io.DataBatch(data=[mx.nd.random_uniform(0, 9, dshape1), - mx.nd.random_uniform(15, 25, dshape2)], + data_batch = mx.io.DataBatch(data=[mx.nd.random.uniform(0, 9, dshape1), + mx.nd.random.uniform(15, 25, dshape2)], label=[mx.nd.ones(lshape)]) mod.forward(data_batch) assert mod.get_outputs()[0].shape == tuple([lshape[0], num_class]) @@ -646,8 +646,8 @@ def test_forward_reshape(): dataset_shape2 = (30, 3, 20, 40) labelset_shape = (30,) - eval_dataiter = mx.io.NDArrayIter(data=[mx.nd.random_uniform(0, 9, dataset_shape1), - mx.nd.random_uniform(15, 25, dataset_shape2)], + eval_dataiter = mx.io.NDArrayIter(data=[mx.nd.random.uniform(0, 9, dataset_shape1), + mx.nd.random.uniform(15, 25, dataset_shape2)], label=[mx.nd.ones(labelset_shape)], batch_size=5) assert len(mod.score(eval_data=eval_dataiter, eval_metric='acc')) == 1 @@ -658,8 +658,8 @@ def test_forward_reshape(): dataset_shape1 = (10, 3, 30, 30) dataset_shape2 = (10, 3, 20, 40) - pred_dataiter = mx.io.NDArrayIter(data=[mx.nd.random_uniform(0, 9, dataset_shape1), - mx.nd.random_uniform(15, 25, dataset_shape2)]) + pred_dataiter = mx.io.NDArrayIter(data=[mx.nd.random.uniform(0, 9, dataset_shape1), + mx.nd.random.uniform(15, 25, dataset_shape2)]) mod.bind(data_shapes=[('data1', dshape1), ('data2', dshape2)], for_training=False, force_rebind=True) assert mod.predict(pred_dataiter).shape == tuple([10, num_class]) diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index f2c6a834cbcd..7d11dbe9fd62 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -284,7 +284,7 @@ def test_ndarray_slice(): assert same(A[3:8].asnumpy(), A2[3:8]) shape = (3,4,5,6,7) - A = mx.nd.random_uniform(shape=shape) + A = mx.nd.random.uniform(shape=shape) A2 = A.asnumpy() assert same(A[1,3:4,:,1:5].asnumpy(), A2[1,3:4,:,1:5]) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index ceb11ed07c02..727312b0100f 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1260,7 +1260,7 @@ def test_bpow(a, b): c = mx.sym.broadcast_power(a, b) check_binary_op_forward(c, lambda a, b: a ** b, gen_broadcast_data, mx_nd_func=mx.nd.power) check_binary_op_backward(c, lambda g_out, a, b: (g_out * a **(b - 1) * b, - g_out * a ** b * np.log(a)), gen_broadcast_data) + g_out * a ** b * np.log(a)), gen_broadcast_data) def test_bequal(a, b): c = mx.sym.broadcast_equal(a, b) @@ -3354,7 +3354,7 @@ def test_pick_helper(index_type=np.int32): def check_ctc_loss(acts, labels, loss_truth): in_var = mx.sym.Variable('input') labels_var = mx.sym.Variable('labels') - ctc = mx.contrib.sym.ctc_loss(in_var, labels_var) + ctc = mx.sym.contrib.ctc_loss(in_var, labels_var) acts_nd = mx.nd.array(acts, ctx=default_context()) labels_nd = mx.nd.array(labels, ctx=default_context()) exe = ctc.bind(ctx=default_context(), args=[acts_nd, labels_nd]) @@ -3397,8 +3397,8 @@ def test_quantization_op(): min0 = mx.nd.array([0.0]) max0 = mx.nd.array([1.0]) a = mx.nd.array([[0.1392, 0.5928], [0.6027, 0.8579]]) - qa, min1, max1 = mx.contrib.nd.quantize(a, min0, max0, out_type='uint8') - a_ = mx.contrib.nd.dequantize(qa, min1, max1, out_type='float32') + qa, min1, max1 = mx.nd.contrib.quantize(a, min0, max0, out_type='uint8') + a_ = mx.nd.contrib.dequantize(qa, min1, max1, out_type='float32') qa_real = mx.nd.array([[35, 151], [154, 219]]) a_real = mx.nd.array([[0.13725491, 0.59215689], [0.60392159, 0.8588236]]) @@ -3480,7 +3480,7 @@ def test_psroipooling(): im_data_var = mx.symbol.Variable(name="im_data") rois_data_var = mx.symbol.Variable(name="rois_data") - op = mx.contrib.sym.PSROIPooling(data=im_data_var, rois=rois_data_var, spatial_scale=spatial_scale, + op = mx.sym.contrib.PSROIPooling(data=im_data_var, rois=rois_data_var, spatial_scale=spatial_scale, group_size=num_group, pooled_size=num_group, output_dim=num_classes, name='test_op') rtol, atol = 1e-2, 1e-3 @@ -3510,7 +3510,7 @@ def test_deformable_convolution(): offset_data_var = mx.symbol.Variable(name="offset_data") weight_var = mx.symbol.Variable(name="weight") bias_var = mx.symbol.Variable(name="bias") - op = mx.contrib.sym.DeformableConvolution(name='test_op', data=im_data_var, + op = mx.sym.contrib.DeformableConvolution(name='test_op', data=im_data_var, offset=offset_data_var, weight=weight_var, bias=bias_var, num_filter=num_channel_data, pad=dilate, @@ -3544,7 +3544,7 @@ def test_deformable_psroipooling(): im_data_var = mx.symbol.Variable(name="im_data") rois_data_var = mx.symbol.Variable(name="rois_data") offset_data_var = mx.symbol.Variable(name="offset_data") - op = mx.contrib.sym.DeformablePSROIPooling(data=im_data_var, rois=rois_data_var, + op = mx.sym.contrib.DeformablePSROIPooling(data=im_data_var, rois=rois_data_var, trans=offset_data_var, spatial_scale=spatial_scale, sample_per_part=4, group_size=num_group, pooled_size=num_group, output_dim=num_classes, @@ -3585,22 +3585,22 @@ def test_laop(): data_in1_t = np.transpose(data_in1) data_in2_t = np.transpose(data_in2) res_gemm = 4*np.dot(data_in1,data_in2)+7*data_in4 - test_gemm = mx.sym.linalg_gemm(data1, data2, data3, alpha = 4, beta = 7) + test_gemm = mx.sym.linalg.gemm(data1, data2, data3, alpha = 4, beta = 7) check_symbolic_forward(test_gemm, [data_in1, data_in2, data_in4], [res_gemm]) if grad_check == 1: check_numeric_gradient(test_gemm, [data_in1, data_in2, data_in4], numeric_eps=1e-3, rtol=1e-1, atol=1e-1) res_gemm = 4*np.dot(data_in1_t,data_in2_t)+7*data_in3 - test_gemm = mx.sym.linalg_gemm(data1, data2, data3, alpha = 4, beta = 7, transpose_a = 1, transpose_b = 1) + test_gemm = mx.sym.linalg.gemm(data1, data2, data3, alpha = 4, beta = 7, transpose_a = 1, transpose_b = 1) check_symbolic_forward(test_gemm, [data_in1, data_in2, data_in3], [res_gemm]) if grad_check == 1: check_numeric_gradient(test_gemm, [data_in1, data_in2, data_in3], numeric_eps=1e-3, rtol=1e-1, atol=1e-1) res_gemm = 4*np.dot(data_in1_t,data_in1)+7*data_in3 - test_gemm = mx.sym.linalg_gemm(data1, data2, data3, alpha = 4, beta = 7, transpose_a = 1) + test_gemm = mx.sym.linalg.gemm(data1, data2, data3, alpha = 4, beta = 7, transpose_a = 1) check_symbolic_forward(test_gemm, [data_in1, data_in1, data_in3], [res_gemm]) if grad_check == 1: check_numeric_gradient(test_gemm, [data_in1, data_in1, data_in3], numeric_eps=1e-3, rtol=1e-1, atol=1e-1) res_gemm = 4*np.dot(data_in1,data_in1_t)+7*data_in4 - test_gemm = mx.sym.linalg_gemm(data1, data2, data3, alpha = 4, beta = 7, transpose_b = 1) + test_gemm = mx.sym.linalg.gemm(data1, data2, data3, alpha = 4, beta = 7, transpose_b = 1) check_symbolic_forward(test_gemm, [data_in1, data_in1, data_in4], [res_gemm]) if grad_check == 1: check_numeric_gradient(test_gemm, [data_in1, data_in1, data_in4], numeric_eps=1e-3, rtol=1e-1, atol=1e-1) @@ -3615,29 +3615,29 @@ def test_laop(): r = 4*np.dot(data_in1,data_in2)+7*data_in4 r = np.tile(r.flatten(),3) r = np.reshape(r,(3,1,2,2)) - test_gemm = mx.sym.linalg_gemm(data1, data2, data3, alpha = 4, beta = 7) + test_gemm = mx.sym.linalg.gemm(data1, data2, data3, alpha = 4, beta = 7) check_symbolic_forward(test_gemm, [a, b, c], [r]) if grad_check == 1: check_numeric_gradient(test_gemm, [a, b, c], numeric_eps=1e-3, rtol=1e-1, atol=1e-1) # Check gemm2 operator same way as gemm. res_gemm = 4*np.dot(data_in1,data_in2) - test_gemm = mx.sym.linalg_gemm2(data1, data2, alpha = 4) + test_gemm = mx.sym.linalg.gemm2(data1, data2, alpha = 4) check_symbolic_forward(test_gemm, [data_in1, data_in2], [res_gemm]) if grad_check == 1: check_numeric_gradient(test_gemm, [data_in1, data_in2], numeric_eps=1e-3, rtol=1e-1, atol=1e-1) res_gemm = 4*np.dot(data_in1_t, data_in2_t) - test_gemm = mx.sym.linalg_gemm2(data1, data2, alpha = 4, transpose_a = 1, transpose_b = 1) + test_gemm = mx.sym.linalg.gemm2(data1, data2, alpha = 4, transpose_a = 1, transpose_b = 1) check_symbolic_forward(test_gemm, [data_in1, data_in2], [res_gemm]) if grad_check == 1: check_numeric_gradient(test_gemm, [data_in1, data_in2], numeric_eps=1e-3, rtol=1e-1, atol=1e-1) res_gemm = 4*np.dot(data_in1_t,data_in1) - test_gemm = mx.sym.linalg_gemm2(data1, data2, alpha = 4, transpose_a = 1) + test_gemm = mx.sym.linalg.gemm2(data1, data2, alpha = 4, transpose_a = 1) check_symbolic_forward(test_gemm, [data_in1, data_in1], [res_gemm]) if grad_check == 1: check_numeric_gradient(test_gemm, [data_in1, data_in1], numeric_eps=1e-3, rtol=1e-1, atol=1e-1) res_gemm = 4*np.dot(data_in1,data_in1_t) - test_gemm = mx.sym.linalg_gemm2(data1, data2, alpha = 4, transpose_b = 1) + test_gemm = mx.sym.linalg.gemm2(data1, data2, alpha = 4, transpose_b = 1) check_symbolic_forward(test_gemm, [data_in1, data_in1], [res_gemm]) if grad_check == 1: check_numeric_gradient(test_gemm, [data_in1, data_in1], numeric_eps=1e-3, rtol=1e-1, atol=1e-1) @@ -3650,7 +3650,7 @@ def test_laop(): r = 4*np.dot(data_in1,data_in2) r = np.tile(r.flatten(),3) r = np.reshape(r,(3,1,2,2)) - test_gemm = mx.sym.linalg_gemm2(data1, data2, alpha = 4) + test_gemm = mx.sym.linalg.gemm2(data1, data2, alpha = 4) check_symbolic_forward(test_gemm, [a, b], [r]) if grad_check == 1: check_numeric_gradient(test_gemm, [a, b], numeric_eps=1e-3, rtol=1e-1, atol=1e-1) @@ -3662,32 +3662,32 @@ def test_laop(): data_in = np.random.uniform(1, 10, shape) # test potrf res_potrf = np.sqrt(data_in) - test_potrf = mx.sym.linalg_potrf(data1) + test_potrf = mx.sym.linalg.potrf(data1) check_symbolic_forward(test_potrf, [data_in], [res_potrf]) if grad_check == 1: check_numeric_gradient(test_potrf, [data_in]) # test potri ones = mx.nd.ones(shape).asnumpy() res_potri = np.divide(ones,data_in*data_in) - test_potri = mx.sym.linalg_potri(data1) + test_potri = mx.sym.linalg.potri(data1) check_symbolic_forward(test_potri, [data_in], [res_potri]) if grad_check == 1: check_numeric_gradient(test_potri, [data_in], atol = 0.01, rtol = 1.5) # test trsm trian_in = data_in *7 - test_trsm = mx.sym.linalg_trsm(data1,data2,alpha = 7) + test_trsm = mx.sym.linalg.trsm(data1,data2,alpha = 7) check_symbolic_forward(test_trsm, [trian_in,data_in], [ones]) if grad_check == 1: check_numeric_gradient(test_trsm, [trian_in,data_in], atol = 0.02, rtol = 2.0) # test trmm trian_in = np.divide(ones,trian_in) - test_trmm = mx.sym.linalg_trmm(data1,data2,alpha = 7, transpose = 1, rightside = 1) + test_trmm = mx.sym.linalg.trmm(data1,data2,alpha = 7, transpose = 1, rightside = 1) check_symbolic_forward(test_trmm, [trian_in,data_in], [ones]) if grad_check == 1: check_numeric_gradient(test_trmm, [trian_in,data_in], atol = 0.02, rtol = 2.0) # test sumlogdiag res_sumlogdiag = np.reshape(np.log(data_in),(4,4)) - test_sumlogdiag = mx.sym.linalg_sumlogdiag(data1) + test_sumlogdiag = mx.sym.linalg.sumlogdiag(data1) check_symbolic_forward(test_sumlogdiag, [data_in], [res_sumlogdiag]) if grad_check == 1: check_numeric_gradient(test_sumlogdiag, [data_in], atol = 0.01, rtol = 2.0) @@ -3733,14 +3733,14 @@ def test_laop(): if grad_check == 1: check_numeric_gradient(test_trsm, [a,b], numeric_eps=1e-3, rtol=1e-2, atol=1e-1) - test_trsm2 = mx.sym.linalg_trsm(data1,data2,alpha = -2, rightside = 1, transpose = 1) + test_trsm2 = mx.sym.linalg.trsm(data1,data2,alpha = -2, rightside = 1, transpose = 1) r = -2*np.reshape(np.array(trian),(4,4)) r = np.reshape(np.tile(np.reshape(r,(16)),3),(3,1,4,4)) check_symbolic_forward(test_trsm2, [a,b], [r]) if grad_check == 1: check_numeric_gradient(test_trsm2, [a,b], numeric_eps=1e-3, rtol=1e-2, atol=1e-1) - test_trsm3 = mx.sym.linalg_trsm(data1,data2,alpha = 0.50, transpose = 1) + test_trsm3 = mx.sym.linalg.trsm(data1,data2,alpha = 0.50, transpose = 1) b = np.transpose(np.reshape(np.array(trian),(4,4))) b = np.reshape(np.tile(np.reshape(b,(16)),3),(3,1,4,4)) r = 0.5*np.reshape(np.array(ident),(4,4)) @@ -3749,7 +3749,7 @@ def test_laop(): if grad_check == 1: check_numeric_gradient(test_trsm3, [a,b], numeric_eps=1e-3, rtol=1e-2, atol=1e-1) - test_trsm4 = mx.sym.linalg_trsm(data1,data2,alpha = -0.5, rightside = 1) + test_trsm4 = mx.sym.linalg.trsm(data1,data2,alpha = -0.5, rightside = 1) b = np.tile(np.array(trian),3) b = np.reshape(b,(3,1,4,4)) r = -0.5*np.reshape(np.array(ident),(4,4)) @@ -3769,21 +3769,21 @@ def test_laop(): if grad_check == 1: check_numeric_gradient(test_trmm, [a,b], numeric_eps=1e-3, rtol=1e-2, atol=1e-1) - test_trmm2 = mx.sym.linalg_trmm(data1,data2,alpha = -2) + test_trmm2 = mx.sym.linalg.trmm(data1,data2,alpha = -2) r = -2*np.dot(np.reshape(np.array(trian),(4,4)),np.reshape(np.array(matrix),(4,4))) r = np.reshape(np.tile(np.reshape(r,(16)),3),(3,1,4,4)) check_symbolic_forward(test_trmm2, [a,b], [r]) if grad_check == 1: check_numeric_gradient(test_trmm2, [a,b], numeric_eps=1e-3, rtol=1e-2, atol=1e-1) - test_trmm3 = mx.sym.linalg_trmm(data1,data2,rightside = 1) + test_trmm3 = mx.sym.linalg.trmm(data1,data2,rightside = 1) r = np.dot(np.reshape(np.array(matrix),(4,4)),np.reshape(np.array(trian),(4,4))) r = np.reshape(np.tile(np.reshape(r,(16)),3),(3,1,4,4)) check_symbolic_forward(test_trmm3, [a,b], [r]) if grad_check == 1: check_numeric_gradient(test_trmm3, [a,b], numeric_eps=1e-3, rtol=1e-2, atol=1e-1) - test_trmm4 = mx.sym.linalg_trmm(data1,data2,alpha = 1.2,transpose = 1) + test_trmm4 = mx.sym.linalg.trmm(data1,data2,alpha = 1.2,transpose = 1) r = 1.2*np.dot(np.transpose(np.reshape(np.array(trian),(4,4))),np.reshape(np.array(matrix),(4,4))) r = np.reshape(np.tile(np.reshape(r,(16)),3),(3,1,4,4)) check_symbolic_forward(test_trmm4, [a,b], [r]) diff --git a/tests/python/unittest/test_random.py b/tests/python/unittest/test_random.py index 6b8311c145f5..01c8b0aa06c6 100644 --- a/tests/python/unittest/test_random.py +++ b/tests/python/unittest/test_random.py @@ -27,7 +27,7 @@ def check_with_device(device, dtype): symbols = [ { 'name': 'normal', - 'symbol': mx.sym.random_normal, + 'symbol': mx.sym.random.normal, 'multisymbol': mx.sym.sample_normal, 'ndop': mx.random.normal, 'params': { 'loc': 10.0, 'scale': 0.5 }, @@ -39,7 +39,7 @@ def check_with_device(device, dtype): }, { 'name': 'uniform', - 'symbol': mx.sym.random_uniform, + 'symbol': mx.sym.random.uniform, 'multisymbol': mx.sym.sample_uniform, 'ndop': mx.random.uniform, 'params': { 'low': -1.5, 'high': 3.0 }, @@ -54,7 +54,7 @@ def check_with_device(device, dtype): symbols.extend([ { 'name': 'gamma', - 'symbol': mx.sym.random_gamma, + 'symbol': mx.sym.random.gamma, 'multisymbol': mx.sym.sample_gamma, 'ndop': mx.random.gamma, 'params': { 'alpha': 9.0, 'beta': 0.5 }, @@ -66,7 +66,7 @@ def check_with_device(device, dtype): }, { 'name': 'exponential', - 'symbol': mx.sym.random_exponential, + 'symbol': mx.sym.random.exponential, 'multisymbol': mx.sym.sample_exponential, 'ndop': mx.random.exponential, 'params': { 'lam': 4.0 }, @@ -78,7 +78,7 @@ def check_with_device(device, dtype): }, { 'name': 'poisson', - 'symbol': mx.sym.random_poisson, + 'symbol': mx.sym.random.poisson, 'ndop': mx.random.poisson, 'multisymbol': mx.sym.sample_poisson, 'params': { 'lam': 4.0 }, @@ -90,7 +90,7 @@ def check_with_device(device, dtype): }, { 'name': 'neg-binomial', - 'symbol': mx.sym.random_negative_binomial, + 'symbol': mx.sym.random.negative_binomial, 'multisymbol': mx.sym.sample_negative_binomial, 'ndop': mx.random.negative_binomial, 'params': { 'k': 3, 'p': 0.4 }, @@ -102,7 +102,7 @@ def check_with_device(device, dtype): }, { 'name': 'gen-neg-binomial', - 'symbol': mx.sym.random_generalized_negative_binomial, + 'symbol': mx.sym.random.generalized_negative_binomial, 'multisymbol': mx.sym.sample_generalized_negative_binomial, 'ndop': mx.random.generalized_negative_binomial, 'params': { 'mu': 2.0, 'alpha': 0.3 }, diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py index a77343436945..35d9713dbbf4 100644 --- a/tests/python/unittest/test_sparse_ndarray.py +++ b/tests/python/unittest/test_sparse_ndarray.py @@ -348,7 +348,7 @@ def test_sparse_nd_transpose(): def test_sparse_nd_output_fallback(): shape = (10, 10) out = mx.nd.zeros(shape=shape, stype='row_sparse') - mx.nd.random_normal(shape=shape, out=out) + mx.nd.random.normal(shape=shape, out=out) assert(np.sum(out.asnumpy()) != 0) def test_sparse_nd_random(): @@ -357,7 +357,7 @@ def test_sparse_nd_random(): if default_context().device_type is 'gpu': return shape = (100, 100) - fns = [mx.nd.random_uniform, mx.nd.random_normal, mx.nd.random_gamma] + fns = [mx.nd.random.uniform, mx.nd.random.normal, mx.nd.random.gamma] for fn in fns: rsp_out = mx.nd.zeros(shape=shape, stype='row_sparse') dns_out = mx.nd.zeros(shape=shape, stype='default')