Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Refactor random linalg contrib namespaces (#7604)
Browse files Browse the repository at this point in the history
* Refactor namespaces contrib, linalg, random, and sparse for op registration

Change examples in documentation

Change namespace usage in examples

Fix pylint

Remove unused import

Switch name and alias in linalg and random

Change stype comparison from string to int for functions used internally

Change documentation to use the right namespace

Register ops under ndarray/op.py and symbol/op.py

Remove unused import

Change .cu op names

* Add __all__ to ndarray and symbol modules

* Revert "Add __all__ to ndarray and symbol modules"

This reverts commit 8bc5de7.

* Add __all__ to ndarray and symbol modules
  • Loading branch information
reminisce authored and piiswrong committed Aug 29, 2017
1 parent 5591f42 commit 94a2c60
Show file tree
Hide file tree
Showing 51 changed files with 579 additions and 428 deletions.
2 changes: 1 addition & 1 deletion benchmark/python/sparse/dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def run_benchmark(mini_path):
weight_row_dim = batch_size if transpose else feature_dim
weight_shape = (weight_row_dim, output_dim)
if not rsp:
weight = mx.nd.random_uniform(low=0, high=1, shape=weight_shape)
weight = mx.nd.random.uniform(low=0, high=1, shape=weight_shape)
else:
weight = rand_ndarray(weight_shape, "row_sparse", density=0.05, distribution="uniform")
total_cost = {}
Expand Down
6 changes: 3 additions & 3 deletions benchmark/python/sparse/sparse_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def get_iter(path, data_shape, batch_size):
# model
data_shape = (k, )
train_iter = get_iter(mini_path, data_shape, batch_size)
weight = mx.nd.random_uniform(low=0, high=1, shape=(k, m))
weight = mx.nd.random.uniform(low=0, high=1, shape=(k, m))

csr_data = []
dns_data = []
Expand Down Expand Up @@ -154,7 +154,7 @@ def measure_cost_backward_baseline(repeat, dot, transpose, lhs, rhs):

def bench_dot_forward(m, k, n, density, ctx, repeat):
set_default_context(ctx)
dns = mx.nd.random_uniform(shape=(k, n)).copyto(ctx)
dns = mx.nd.random.uniform(shape=(k, n)).copyto(ctx)
data_shape = (m, k)
csr_data = rand_ndarray(data_shape, 'csr', density)
dns_data = csr_data.tostype('default')
Expand Down Expand Up @@ -183,7 +183,7 @@ def bench_dot_forward(m, k, n, density, ctx, repeat):

def bench_dot_backward(m, k, n, density, ctx, repeat):
set_default_context(ctx)
dns = mx.nd.random_uniform(shape=(m, n)).copyto(ctx)
dns = mx.nd.random.uniform(shape=(m, n)).copyto(ctx)
data_shape = (m, k)
csr_data = rand_ndarray(data_shape, 'csr', density)
dns_data = csr_data.tostype('default')
Expand Down
2 changes: 1 addition & 1 deletion example/ctc/lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def lstm_unroll(num_lstm_layer, seq_len,
pred_fc = mx.sym.FullyConnected(data=hidden_concat, num_hidden=11)
pred_ctc = mx.sym.Reshape(data=pred_fc, shape=(-4, seq_len, -1, 0))

loss = mx.contrib.sym.ctc_loss(data=pred_ctc, label=label)
loss = mx.sym.contrib.ctc_loss(data=pred_ctc, label=label)
ctc_loss = mx.sym.MakeLoss(loss)

softmax_class = mx.symbol.SoftmaxActivation(data=pred_fc)
Expand Down
2 changes: 1 addition & 1 deletion example/gluon/dcgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def transformer(data, label):
###########################
# train with real_t
data = data.as_in_context(ctx)
noise = mx.nd.random_normal(0, 1, shape=(opt.batch_size, nz, 1, 1), ctx=ctx)
noise = mx.nd.random.normal(0, 1, shape=(opt.batch_size, nz, 1, 1), ctx=ctx)

with autograd.record():
output = netD(data)
Expand Down
6 changes: 3 additions & 3 deletions example/gluon/lstm_crf.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,13 @@ def __init__(self, vocab_size, tag2idx, embedding_dim, hidden_dim):

# Matrix of transition parameters. Entry i,j is the score of
# transitioning *to* i *from* j.
self.transitions = nd.random_normal(shape=(self.tagset_size, self.tagset_size))
self.transitions = nd.random.normal(shape=(self.tagset_size, self.tagset_size))

self.hidden = self.init_hidden()

def init_hidden(self):
return [nd.random_normal(shape=(2, 1, self.hidden_dim // 2)),
nd.random_normal(shape=(2, 1, self.hidden_dim // 2))]
return [nd.random.normal(shape=(2, 1, self.hidden_dim // 2)),
nd.random.normal(shape=(2, 1, self.hidden_dim // 2))]

def _forward_alg(self, feats):
# Do the forward algorithm to compute the partition function
Expand Down
4 changes: 2 additions & 2 deletions example/rcnn/rcnn/symbol/symbol_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def get_resnet_train(num_classes=config.NUM_CLASSES, num_anchors=config.NUM_ANCH
rpn_cls_act_reshape = mx.symbol.Reshape(
data=rpn_cls_act, shape=(0, 2 * num_anchors, -1, 0), name='rpn_cls_act_reshape')
if config.TRAIN.CXX_PROPOSAL:
rois = mx.contrib.symbol.Proposal(
rois = mx.symbol.contrib.Proposal(
cls_prob=rpn_cls_act_reshape, bbox_pred=rpn_bbox_pred, im_info=im_info, name='rois',
feature_stride=config.RPN_FEAT_STRIDE, scales=tuple(config.ANCHOR_SCALES), ratios=tuple(config.ANCHOR_RATIOS),
rpn_pre_nms_top_n=config.TRAIN.RPN_PRE_NMS_TOP_N, rpn_post_nms_top_n=config.TRAIN.RPN_POST_NMS_TOP_N,
Expand Down Expand Up @@ -189,7 +189,7 @@ def get_resnet_test(num_classes=config.NUM_CLASSES, num_anchors=config.NUM_ANCHO
rpn_cls_prob_reshape = mx.symbol.Reshape(
data=rpn_cls_prob, shape=(0, 2 * num_anchors, -1, 0), name='rpn_cls_prob_reshape')
if config.TEST.CXX_PROPOSAL:
rois = mx.contrib.symbol.Proposal(
rois = mx.symbol.contrib.Proposal(
cls_prob=rpn_cls_prob_reshape, bbox_pred=rpn_bbox_pred, im_info=im_info, name='rois',
feature_stride=config.RPN_FEAT_STRIDE, scales=tuple(config.ANCHOR_SCALES), ratios=tuple(config.ANCHOR_RATIOS),
rpn_pre_nms_top_n=config.TEST.RPN_PRE_NMS_TOP_N, rpn_post_nms_top_n=config.TEST.RPN_POST_NMS_TOP_N,
Expand Down
6 changes: 3 additions & 3 deletions example/rcnn/rcnn/symbol/symbol_vgg.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def get_vgg_rpn_test(num_anchors=config.NUM_ANCHORS):
rpn_cls_prob_reshape = mx.symbol.Reshape(
data=rpn_cls_prob, shape=(0, 2 * num_anchors, -1, 0), name='rpn_cls_prob_reshape')
if config.TEST.CXX_PROPOSAL:
group = mx.contrib.symbol.Proposal(
group = mx.symbol.contrib.Proposal(
cls_prob=rpn_cls_prob_reshape, bbox_pred=rpn_bbox_pred, im_info=im_info, name='rois', output_score=True,
feature_stride=config.RPN_FEAT_STRIDE, scales=tuple(config.ANCHOR_SCALES), ratios=tuple(config.ANCHOR_RATIOS),
rpn_pre_nms_top_n=config.TEST.PROPOSAL_PRE_NMS_TOP_N, rpn_post_nms_top_n=config.TEST.PROPOSAL_POST_NMS_TOP_N,
Expand Down Expand Up @@ -290,7 +290,7 @@ def get_vgg_test(num_classes=config.NUM_CLASSES, num_anchors=config.NUM_ANCHORS)
rpn_cls_prob_reshape = mx.symbol.Reshape(
data=rpn_cls_prob, shape=(0, 2 * num_anchors, -1, 0), name='rpn_cls_prob_reshape')
if config.TEST.CXX_PROPOSAL:
rois = mx.contrib.symbol.Proposal(
rois = mx.symbol.contrib.Proposal(
cls_prob=rpn_cls_prob_reshape, bbox_pred=rpn_bbox_pred, im_info=im_info, name='rois',
feature_stride=config.RPN_FEAT_STRIDE, scales=tuple(config.ANCHOR_SCALES), ratios=tuple(config.ANCHOR_RATIOS),
rpn_pre_nms_top_n=config.TEST.RPN_PRE_NMS_TOP_N, rpn_post_nms_top_n=config.TEST.RPN_POST_NMS_TOP_N,
Expand Down Expand Up @@ -373,7 +373,7 @@ def get_vgg_train(num_classes=config.NUM_CLASSES, num_anchors=config.NUM_ANCHORS
rpn_cls_act_reshape = mx.symbol.Reshape(
data=rpn_cls_act, shape=(0, 2 * num_anchors, -1, 0), name='rpn_cls_act_reshape')
if config.TRAIN.CXX_PROPOSAL:
rois = mx.contrib.symbol.Proposal(
rois = mx.symbol.contrib.Proposal(
cls_prob=rpn_cls_act_reshape, bbox_pred=rpn_bbox_pred, im_info=im_info, name='rois',
feature_stride=config.RPN_FEAT_STRIDE, scales=tuple(config.ANCHOR_SCALES), ratios=tuple(config.ANCHOR_RATIOS),
rpn_pre_nms_top_n=config.TRAIN.RPN_PRE_NMS_TOP_N, rpn_post_nms_top_n=config.TRAIN.RPN_POST_NMS_TOP_N,
Expand Down
5 changes: 3 additions & 2 deletions example/ssd/symbol/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,9 @@ def multibox_layer(from_layers, num_classes, sizes=[.2, .95],
step = (steps[k], steps[k])
else:
step = '(-1.0, -1.0)'
anchors = mx.contrib.symbol.MultiBoxPrior(from_layer, sizes=size_str, ratios=ratio_str, \
clip=clip, name="{}_anchors".format(from_name), steps=step)
anchors = mx.symbol.contrib.MultiBoxPrior(from_layer, sizes=size_str, ratios=ratio_str,
clip=clip, name="{}_anchors".format(from_name),
steps=step)
anchors = mx.symbol.Flatten(data=anchors)
anchor_layers.append(anchors)

Expand Down
6 changes: 3 additions & 3 deletions example/ssd/symbol/legacy_vgg16_ssd_300.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def get_symbol_train(num_classes=20, nms_thresh=0.5, force_suppress=False,
num_classes, sizes=sizes, ratios=ratios, normalization=normalizations, \
num_channels=num_channels, clip=False, interm_layer=0, steps=steps)

tmp = mx.contrib.symbol.MultiBoxTarget(
tmp = mx.symbol.contrib.MultiBoxTarget(
*[anchor_boxes, label, cls_preds], overlap_threshold=.5, \
ignore_label=-1, negative_mining_ratio=3, minimum_negative_samples=0, \
negative_mining_thresh=.5, variances=(0.1, 0.1, 0.2, 0.2),
Expand All @@ -163,7 +163,7 @@ def get_symbol_train(num_classes=20, nms_thresh=0.5, force_suppress=False,

# monitoring training status
cls_label = mx.symbol.MakeLoss(data=cls_target, grad_scale=0, name="cls_label")
det = mx.contrib.symbol.MultiBoxDetection(*[cls_prob, loc_preds, anchor_boxes], \
det = mx.symbol.contrib.MultiBoxDetection(*[cls_prob, loc_preds, anchor_boxes], \
name="detection", nms_threshold=nms_thresh, force_suppress=force_suppress,
variances=(0.1, 0.1, 0.2, 0.2), nms_topk=nms_topk)
det = mx.symbol.MakeLoss(data=det, grad_scale=0, name="det_out")
Expand Down Expand Up @@ -202,7 +202,7 @@ def get_symbol(num_classes=20, nms_thresh=0.5, force_suppress=False,

cls_prob = mx.symbol.SoftmaxActivation(data=cls_preds, mode='channel', \
name='cls_prob')
out = mx.contrib.symbol.MultiBoxDetection(*[cls_prob, loc_preds, anchor_boxes], \
out = mx.symbol.contrib.MultiBoxDetection(*[cls_prob, loc_preds, anchor_boxes], \
name="detection", nms_threshold=nms_thresh, force_suppress=force_suppress,
variances=(0.1, 0.1, 0.2, 0.2), nms_topk=nms_topk)
return out
6 changes: 3 additions & 3 deletions example/ssd/symbol/legacy_vgg16_ssd_512.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def get_symbol_train(num_classes=20, nms_thresh=0.5, force_suppress=False, nms_t
num_classes, sizes=sizes, ratios=ratios, normalization=normalizations, \
num_channels=num_channels, clip=False, interm_layer=0, steps=steps)

tmp = mx.contrib.symbol.MultiBoxTarget(
tmp = mx.symbol.contrib.MultiBoxTarget(
*[anchor_boxes, label, cls_preds], overlap_threshold=.5, \
ignore_label=-1, negative_mining_ratio=3, minimum_negative_samples=0, \
negative_mining_thresh=.5, variances=(0.1, 0.1, 0.2, 0.2),
Expand All @@ -167,7 +167,7 @@ def get_symbol_train(num_classes=20, nms_thresh=0.5, force_suppress=False, nms_t

# monitoring training status
cls_label = mx.symbol.MakeLoss(data=cls_target, grad_scale=0, name="cls_label")
det = mx.contrib.symbol.MultiBoxDetection(*[cls_prob, loc_preds, anchor_boxes], \
det = mx.symbol.contrib.MultiBoxDetection(*[cls_prob, loc_preds, anchor_boxes], \
name="detection", nms_threshold=nms_thresh, force_suppress=force_suppress,
variances=(0.1, 0.1, 0.2, 0.2), nms_topk=nms_topk)
det = mx.symbol.MakeLoss(data=det, grad_scale=0, name="det_out")
Expand Down Expand Up @@ -205,7 +205,7 @@ def get_symbol(num_classes=20, nms_thresh=0.5, force_suppress=False, nms_topk=40

cls_prob = mx.symbol.SoftmaxActivation(data=cls_preds, mode='channel', \
name='cls_prob')
out = mx.contrib.symbol.MultiBoxDetection(*[cls_prob, loc_preds, anchor_boxes], \
out = mx.symbol.contrib.MultiBoxDetection(*[cls_prob, loc_preds, anchor_boxes], \
name="detection", nms_threshold=nms_thresh, force_suppress=force_suppress,
variances=(0.1, 0.1, 0.2, 0.2), nms_topk=nms_topk)
return out
6 changes: 3 additions & 3 deletions example/ssd/symbol/symbol_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def get_symbol_train(network, num_classes, from_layers, num_filters, strides, pa
num_classes, sizes=sizes, ratios=ratios, normalization=normalizations, \
num_channels=num_filters, clip=False, interm_layer=0, steps=steps)

tmp = mx.contrib.symbol.MultiBoxTarget(
tmp = mx.symbol.contrib.MultiBoxTarget(
*[anchor_boxes, label, cls_preds], overlap_threshold=.5, \
ignore_label=-1, negative_mining_ratio=3, minimum_negative_samples=0, \
negative_mining_thresh=.5, variances=(0.1, 0.1, 0.2, 0.2),
Expand All @@ -106,7 +106,7 @@ def get_symbol_train(network, num_classes, from_layers, num_filters, strides, pa

# monitoring training status
cls_label = mx.symbol.MakeLoss(data=cls_target, grad_scale=0, name="cls_label")
det = mx.contrib.symbol.MultiBoxDetection(*[cls_prob, loc_preds, anchor_boxes], \
det = mx.symbol.contrib.MultiBoxDetection(*[cls_prob, loc_preds, anchor_boxes], \
name="detection", nms_threshold=nms_thresh, force_suppress=force_suppress,
variances=(0.1, 0.1, 0.2, 0.2), nms_topk=nms_topk)
det = mx.symbol.MakeLoss(data=det, grad_scale=0, name="det_out")
Expand Down Expand Up @@ -177,7 +177,7 @@ def get_symbol(network, num_classes, from_layers, num_filters, sizes, ratios,

cls_prob = mx.symbol.SoftmaxActivation(data=cls_preds, mode='channel', \
name='cls_prob')
out = mx.contrib.symbol.MultiBoxDetection(*[cls_prob, loc_preds, anchor_boxes], \
out = mx.symbol.contrib.MultiBoxDetection(*[cls_prob, loc_preds, anchor_boxes], \
name="detection", nms_threshold=nms_thresh, force_suppress=force_suppress,
variances=(0.1, 0.1, 0.2, 0.2), nms_topk=nms_topk)
return out
4 changes: 2 additions & 2 deletions example/ssd/tools/caffe_converter/convert_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def proto2script(proto_file):
finput_dim = float(input_dim[2])
step = '(%f, %f)' % (step_h / finput_dim, step_w / finput_dim)
assert param.offset == 0.5, "currently only support offset = 0.5"
symbol_string += '%s = mx.contrib.symbol.MultiBoxPrior(%s, sizes=%s, ratios=%s, clip=%s, steps=%s, name="%s")\n' % \
symbol_string += '%s = mx.symbol.contrib.MultiBoxPrior(%s, sizes=%s, ratios=%s, clip=%s, steps=%s, name="%s")\n' % \
(name, mapping[layer[i].bottom[0]], sizes, ratios_string, clip, step, name)
symbol_string += '%s = mx.symbol.Flatten(data=%s)\n' % (name, name)
type_string = 'split'
Expand All @@ -281,7 +281,7 @@ def proto2script(proto_file):
assert param.share_location == True
assert param.background_label_id == 0
nms_param = param.nms_param
type_string = 'mx.contrib.symbol.MultiBoxDetection'
type_string = 'mx.symbol.contrib.MultiBoxDetection'
param_string = "nms_threshold=%f, nms_topk=%d" % \
(nms_param.nms_threshold, nms_param.top_k)
if type_string == '':
Expand Down
23 changes: 6 additions & 17 deletions python/mxnet/_ctypes/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,11 @@
from __future__ import absolute_import as _abs

import ctypes
import sys as _sys
import numpy as np

from ..base import _LIB
from ..base import c_array, py_str, c_str, mx_uint, _Null
from ..base import NDArrayHandle, OpHandle, CachedOpHandle
from ..base import c_array, c_str
from ..base import NDArrayHandle, CachedOpHandle
from ..base import check_call
from ..ndarray_doc import _build_doc


_STORAGE_TYPE_ID_TO_STR = {
-1 : 'undefined',
0 : 'default',
1 : 'row_sparse',
2 : 'csr',
}


class NDArrayBase(object):
Expand Down Expand Up @@ -106,10 +95,10 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):
return original_output
if num_output.value == 1:
return _ndarray_cls(ctypes.cast(output_vars[0], NDArrayHandle),
stype=_STORAGE_TYPE_ID_TO_STR[out_stypes[0]])
stype=out_stypes[0])
else:
return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle),
stype=_STORAGE_TYPE_ID_TO_STR[out_stypes[i]])
stype=out_stypes[i])
for i in range(num_output.value)]


Expand Down Expand Up @@ -160,8 +149,8 @@ def __call__(self, *args, **kwargs):
return original_output
if num_output.value == 1:
return _ndarray_cls(ctypes.cast(output_vars[0], NDArrayHandle),
stype=_STORAGE_TYPE_ID_TO_STR[out_stypes[0]])
stype=out_stypes[0])
else:
return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle),
stype=_STORAGE_TYPE_ID_TO_STR[out_stypes[i]])
stype=out_stypes[i])
for i in range(num_output.value)]
3 changes: 2 additions & 1 deletion python/mxnet/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
from .base import _LIB, check_call, string_types
from .base import mx_uint, NDArrayHandle, c_array, MXCallbackList, SymbolHandle
from .ndarray import NDArray
from .symbol import _GRAD_REQ_MAP, Symbol
from .ndarray import _GRAD_REQ_MAP
from .symbol import Symbol


def set_recording(is_recording): #pylint: disable=redefined-outer-name
Expand Down
83 changes: 83 additions & 0 deletions python/mxnet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ def _notify_shutdown():

atexit.register(_notify_shutdown)


def add_fileline_to_docstring(module, incursive=True):
"""Append the definition position to each function contained in module.
Expand Down Expand Up @@ -342,6 +343,7 @@ def _add_fileline(obj):
if inspect.isclass(obj) and incursive:
add_fileline_to_docstring(obj, False)


def _as_list(obj):
"""A utility function that converts the argument to a list if it is not already.
Expand All @@ -359,3 +361,84 @@ def _as_list(obj):
return obj
else:
return [obj]


_OP_NAME_PREFIX_LIST = ['_contrib_', '_linalg_', '_random_', '_sparse_']


def _get_op_name_prefix(op_name):
"""
Check whether the given op_name starts with any words in `_OP_NAME_PREFIX_LIST`.
If found, return the prefix; else, return an empty string.
"""
for prefix in _OP_NAME_PREFIX_LIST:
if op_name.startswith(prefix):
return prefix
return ""


# pylint: enable=too-many-locals, invalid-name
def _init_op_module(root_namespace, module_name, make_op_func):
"""
Registers op functions created by `make_op_func` under
`root_namespace.module_name.[submodule_name]`,
where `submodule_name` is one of `_OP_SUBMODULE_NAME_LIST`.
Parameters
----------
root_namespace : str
Top level module name, `mxnet` in the current cases.
module_name : str
Second level module name, `ndarray` and `symbol` in the current cases.
make_op_func : str
Function for creating op functions for `ndarray` and `symbol` modules.
"""
plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint()

check_call(_LIB.MXListAllOpNames(ctypes.byref(size),
ctypes.byref(plist)))
op_names = []
for i in range(size.value):
op_names.append(py_str(plist[i]))

module_op = sys.modules["%s.%s.op" % (root_namespace, module_name)]
module_internal = sys.modules["%s.%s._internal" % (root_namespace, module_name)]
# contrib module in the old format (deprecated)
# kept here for backward compatibility
# use mx.nd.contrib or mx.sym.contrib from now on
contrib_module_name_old = "%s.contrib.%s" % (root_namespace, module_name)
contrib_module_old = sys.modules[contrib_module_name_old]
submodule_dict = {}
for op_name_prefix in _OP_NAME_PREFIX_LIST:
submodule_dict[op_name_prefix] =\
sys.modules["%s.%s.%s" % (root_namespace, module_name, op_name_prefix[1:-1])]
for name in op_names:
hdl = OpHandle()
check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl)))
function = make_op_func(hdl, name)
op_name_prefix = _get_op_name_prefix(function.__name__)
if len(op_name_prefix) > 0:
# register op under mxnet.module_name.op_name_prefix[1:-1]
# e.g. mxnet.ndarray.sparse.dot, mxnet.symbol.linalg.gemm
function.__name__ = function.__name__[len(op_name_prefix):]
function.__module__ = "%s.%s.%s" % (root_namespace, module_name, op_name_prefix[1:-1])
cur_module = submodule_dict[op_name_prefix]
setattr(cur_module, function.__name__, function)
cur_module.__all__.append(function.__name__)
# if op_name_prefix is '_contrib_', also need to register
# the op under mxnet.contrib.module_name for backward compatibility
if op_name_prefix == '_contrib_':
hdl = OpHandle()
check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl)))
function = make_op_func(hdl, name)
function.__name__ = function.__name__[len(op_name_prefix):]
function.__module__ = contrib_module_name_old
setattr(contrib_module_old, function.__name__, function)
contrib_module_old.__all__.append(function.__name__)
elif function.__name__.startswith('_'):
setattr(module_internal, function.__name__, function)
module_internal.__all__.append(function.__name__)
else:
setattr(module_op, function.__name__, function)
module_op.__all__.append(function.__name__)
Loading

0 comments on commit 94a2c60

Please sign in to comment.