diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 4508a51e..d2efdf58 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -48,6 +48,8 @@ typedef void *NDArrayHandle; typedef const void *FunctionHandle; /*! \brief handle to a function that takes param and creates symbol */ typedef void *AtomicSymbolCreator; +/*! \brief handle to cached operator */ +typedef void *CachedOpHandle; /*! \brief handle to a symbol that can be bind as operator */ typedef void *SymbolHandle; /*! \brief handle to a AtomicSymbol */ @@ -414,6 +416,26 @@ MXNET_DLL int MXNDArrayGetDType(NDArrayHandle handle, MXNET_DLL int MXNDArrayGetContext(NDArrayHandle handle, int *out_dev_type, int *out_dev_id); +/*! + * \brief detach an ndarray from the computation graph by clearing entry_ + * \param handle NDArray handle + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXNDArrayDetach(NDArrayHandle handle, NDArrayHandle *out); +/*! + * \brief set the flag for gradient array state. + * \param handle NDArray handle + * \param state the new state. + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXNDArraySetGradState(NDArrayHandle handle, int state); +/*! + * \brief get the flag for gradient array state. + * \param handle NDArray handle + * \param out the current state. + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXNDArrayGetGradState(NDArrayHandle handle, int *out); //-------------------------------- // Part 2: functions on NDArray //-------------------------------- @@ -548,6 +570,39 @@ MXNET_DLL int MXAutogradMarkVariables(mx_uint num_var, */ MXNET_DLL int MXAutogradComputeGradient(mx_uint num_output, NDArrayHandle* output_handles); +/*! + * \brief compute the gradient of outputs w.r.t variables + * \param num_output number of output NDArray + * \param output_handles output NDArrays + * \param ograd_handles head gradient for NDArrays + * \param retain_graph whether to keep the graph after backward + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXAutogradBackward(mx_uint num_output, + NDArrayHandle* output_handles, + NDArrayHandle* ograd_handles, + int retain_graph); +/*! + * \brief create cached operator + */ +MXNET_DLL int MXCachedCreateOp(AtomicSymbolCreator creator, + int num_inputs, + int num_params, + const char **param_keys, + const char **param_vals, + CachedOpHandle *out); +/*! + * \brief free cached operator + */ +MXNET_DLL int MXCachedFree(CachedOpHandle handle); +/*! + * \brief invoke cached operator + */ +MXNET_DLL int MXCachedInvoke(CachedOpHandle handle, + int num_inputs, + NDArrayHandle *inputs, + int *num_outputs, + NDArrayHandle **outputs); //-------------------------------------------- // Part 3: symbolic configuration generation //-------------------------------------------- @@ -615,6 +670,19 @@ MXNET_DLL int MXSymbolCreateAtomicSymbol(AtomicSymbolCreator creator, const char **keys, const char **vals, SymbolHandle *out); +/*! + * \brief Create an AtomicSymbol from cached op. + * \param handle cached node attribute. + * \param name name of new symbol. + * \param num_args the number of symbol arguments + * \param args symbol arguments + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXCachedCreateSymbol(CachedOpHandle handle, + const char* name, + mx_uint num_args, + SymbolHandle* args, + SymbolHandle* out); /*! * \brief Create a Variable Symbol.
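For orientation, the three cached-op entry points above are intended to be used in a create → invoke → free sequence. A minimal ctypes sketch of that sequence, assuming `libmxnet` is loaded as `_LIB` the way `python/mxnet/base.py` does, and mirroring the wrappers added below in `python/mxnet/_ctypes/common.py` and `python/mxnet/_ctypes/ndarray.py`; the helper name and the choice of the `exp` operator are illustrative only:

```python
import ctypes
from mxnet.base import _LIB, check_call, c_array, c_str
from mxnet.base import OpHandle, NDArrayHandle, CachedOpHandle  # CachedOpHandle is added by this change


def cached_exp(x):
    """Create a cached 'exp' operator, run it once on NDArray `x`, then free it."""
    # Look up the operator and build the cached NodeAttrs once.
    op_handle = OpHandle()
    check_call(_LIB.NNGetOpHandle(c_str('exp'), ctypes.byref(op_handle)))
    cached = CachedOpHandle()
    check_call(_LIB.MXCachedCreateOp(
        op_handle,
        ctypes.c_int(1),   # num_inputs
        ctypes.c_int(0),   # num_params
        c_array(ctypes.c_char_p, []),
        c_array(ctypes.c_char_p, []),
        ctypes.byref(cached)))

    # Invoke it; the backend allocates the outputs and hands back raw handles.
    num_outputs = ctypes.c_int(0)
    output_vars = ctypes.POINTER(NDArrayHandle)()
    check_call(_LIB.MXCachedInvoke(
        cached,
        ctypes.c_int(1),
        c_array(NDArrayHandle, [x.handle]),
        ctypes.byref(num_outputs),
        ctypes.byref(output_vars)))

    check_call(_LIB.MXCachedFree(cached))
    # The Python frontend wraps these raw handles back into NDArray objects.
    return [output_vars[i] for i in range(num_outputs.value)]
```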
* \param name name of the variable diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index b8cd5501..f30b09a0 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -47,6 +47,7 @@ class AGNodeEntry { } nnvm::NodeEntry nn_entry() const; + bool is_none() const; }; class AutogradRuntime; @@ -149,6 +150,10 @@ class NDArray { inline bool is_none() const { return ptr_.get() == nullptr; } + /*! \return updated grad state in entry_ */ + bool fresh_out_grad() const; + /*! \return updated grad state in entry_ */ + void set_fresh_out_grad(bool state) const; /*! * \brief Block until all the pending write operations with respect * to current NDArray are finished, and read can be performed. @@ -321,6 +326,14 @@ class NDArray { * \return NDArray in new shape */ NDArray Reshape(const TShape &shape) const; + /*! + * \brief Return a copy of this NDArray without autograd history + */ + NDArray Detach() const { + NDArray ret(*this); + ret.entry_ = autograd::AGNodeEntry{nullptr, 0, 0}; + return ret; + } /*! * \brief Allocate the space if it is delayed allocated. * This is an internal function used by system that normal user should not use diff --git a/nnvm b/nnvm index 93072dc8..7796ac76 160000 --- a/nnvm +++ b/nnvm @@ -1 +1 @@ -Subproject commit 93072dc8733aa2a89459ecf16413d96ad0b998db +Subproject commit 7796ac76ccea1fba31afc32056c83f6da38b6c57 diff --git a/python/mxnet/_ctypes/common.py b/python/mxnet/_ctypes/common.py new file mode 100644 index 00000000..5773a6a9 --- /dev/null +++ b/python/mxnet/_ctypes/common.py @@ -0,0 +1,30 @@ +# coding: utf-8 +"""Common code between symbolic and ndarray.""" +from __future__ import absolute_import as _abs + +import ctypes + +from ..base import _LIB +from ..base import c_array, c_str +from ..base import OpHandle, CachedOpHandle +from ..base import check_call + + +class CachedOp(object): + """Cached operator handle.""" + __slots__ = ["handle", "op"] + def __init__(self, op, num_input, **kwargs): + self.op = op + op_handle = OpHandle() + check_call(_LIB.NNGetOpHandle(c_str(op), ctypes.byref(op_handle))) + self.handle = CachedOpHandle() + check_call(_LIB.MXCachedCreateOp( + op_handle, + ctypes.c_int(num_input), + ctypes.c_int(len(kwargs)), + c_array(ctypes.c_char_p, [c_str(key) for key in kwargs.keys()]), + c_array(ctypes.c_char_p, [c_str(str(val)) for val in kwargs.values()]), + ctypes.byref(self.handle))) + + def __del__(self): + check_call(_LIB.MXCachedFree(self.handle)) diff --git a/python/mxnet/_ctypes/ndarray.py b/python/mxnet/_ctypes/ndarray.py index 37879e95..a678e172 100644 --- a/python/mxnet/_ctypes/ndarray.py +++ b/python/mxnet/_ctypes/ndarray.py @@ -1,7 +1,7 @@ # coding: utf-8 # pylint: disable=invalid-name, protected-access, too-many-arguments # pylint: disable=global-statement, unused-import -"""Symbolic configuration API.""" +"""NDArray configuration API.""" from __future__ import absolute_import as _abs import ctypes @@ -13,6 +13,7 @@ from ..base import NDArrayHandle, OpHandle from ..base import check_call from ..ndarray_doc import _build_doc +from .common import CachedOp class NDArrayBase(object): @@ -78,3 +79,33 @@ def _imperative_invoke(handle, ndargs, keys, vals, out): else: return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle)) for i in range(num_output.value)] + + +def invoke(cached_op, args, out=None, name=None): # pylint: disable=unused-argument + """ctypes implementation of imperative invoke wrapper""" + if out is not None: + original_output = out + if isinstance(out, NDArrayBase): + out = (out,) + num_output = 
ctypes.c_int(len(out)) + output_vars = c_array(NDArrayHandle, [i.handle for i in out]) + output_vars = ctypes.cast(output_vars, ctypes.POINTER(NDArrayHandle)) + else: + original_output = None + output_vars = ctypes.POINTER(NDArrayHandle)() + num_output = ctypes.c_int(0) + + check_call(_LIB.MXCachedInvoke( + cached_op.handle, + ctypes.c_int(len(args)), + c_array(NDArrayHandle, [arr.handle for arr in args]), + ctypes.byref(num_output), + ctypes.byref(output_vars))) + + if original_output is not None: + return original_output + if num_output.value == 1: + return _ndarray_cls(ctypes.cast(output_vars[0], NDArrayHandle)) + else: + return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle)) + for i in range(num_output.value)] diff --git a/python/mxnet/_ctypes/symbol.py b/python/mxnet/_ctypes/symbol.py index 00d935d4..2ffa1a93 100644 --- a/python/mxnet/_ctypes/symbol.py +++ b/python/mxnet/_ctypes/symbol.py @@ -4,15 +4,12 @@ from __future__ import absolute_import as _abs import ctypes -import sys -import numpy as _numpy from ..base import _LIB -from ..base import c_array, c_str, mx_uint, py_str -from ..base import SymbolHandle, OpHandle +from ..base import c_array, c_str, mx_uint +from ..base import SymbolHandle from ..base import check_call -from ..symbol_doc import _build_doc from ..name import NameManager -from ..attribute import AttrScope +from .common import CachedOp # pylint: disable=unused-import _symbol_cls = None @@ -105,122 +102,38 @@ def _set_symbol_class(cls): _symbol_cls = cls -def _make_atomic_symbol_function(handle, name): - """Create an atomic symbol function by handle and funciton name.""" - real_name = ctypes.c_char_p() - desc = ctypes.c_char_p() - num_args = mx_uint() - arg_names = ctypes.POINTER(ctypes.c_char_p)() - arg_types = ctypes.POINTER(ctypes.c_char_p)() - arg_descs = ctypes.POINTER(ctypes.c_char_p)() - key_var_num_args = ctypes.c_char_p() - ret_type = ctypes.c_char_p() - - check_call(_LIB.MXSymbolGetAtomicSymbolInfo( - handle, ctypes.byref(real_name), ctypes.byref(desc), - ctypes.byref(num_args), - ctypes.byref(arg_names), - ctypes.byref(arg_types), - ctypes.byref(arg_descs), - ctypes.byref(key_var_num_args), - ctypes.byref(ret_type))) - narg = int(num_args.value) - func_name = name - key_var_num_args = py_str(key_var_num_args.value) - ret_type = py_str(ret_type.value) if ret_type.value is not None else '' - doc_str = _build_doc(func_name, - py_str(desc.value), - [py_str(arg_names[i]) for i in range(narg)], - [py_str(arg_types[i]) for i in range(narg)], - [py_str(arg_descs[i]) for i in range(narg)], - key_var_num_args, - ret_type) - - def creator(*args, **kwargs): - """Activation Operator of Neural Net. - The parameters listed below can be passed in as keyword arguments. - - Parameters - ---------- - name : string, required. - Name of the resulting symbol. 
- - Returns - ------- - symbol: Symbol - the resulting symbol - """ - param_keys = [] - param_vals = [] - symbol_kwargs = {} - - attr = kwargs.pop('attr', None) - kwargs.update(AttrScope.current.get(attr)) - name = kwargs.pop('name', None) - if 'dtype' in kwargs: - kwargs['dtype'] = _numpy.dtype(kwargs['dtype']).name - - if key_var_num_args and key_var_num_args not in kwargs: - param_keys.append(c_str(key_var_num_args)) - param_vals.append(c_str(str(len(args)))) - - for k, v in kwargs.items(): - if isinstance(v, SymbolBase): - symbol_kwargs[k] = v - else: - param_keys.append(c_str(k)) - param_vals.append(c_str(str(v))) - # create atomic symbol - param_keys = c_array(ctypes.c_char_p, param_keys) - param_vals = c_array(ctypes.c_char_p, param_vals) - sym_handle = SymbolHandle() - check_call(_LIB.MXSymbolCreateAtomicSymbol( - handle, - mx_uint(len(param_keys)), - param_keys, param_vals, - ctypes.byref(sym_handle))) - - if len(args) != 0 and len(symbol_kwargs) != 0: - raise TypeError( - '%s can only accept input' - 'Symbols either as positional or keyword arguments, not both' % func_name) - s = _symbol_cls(sym_handle) - - hint = func_name.lower() - name = NameManager.current.get(name, hint) - s._compose(*args, name=name, **symbol_kwargs) - return s - - creator.__name__ = func_name - creator.__doc__ = doc_str - creator.__module__ = 'mxnet.symbol' - return creator - - -def _init_symbol_module(symbol_class, root_namespace): - """List and add all the atomic symbol functions to current module.""" - _set_symbol_class(symbol_class) - plist = ctypes.POINTER(ctypes.c_char_p)() - size = ctypes.c_uint() - - check_call(_LIB.MXListAllOpNames(ctypes.byref(size), - ctypes.byref(plist))) - op_names = [] - for i in range(size.value): - op_names.append(py_str(plist[i])) - - module_obj = sys.modules["%s.symbol" % root_namespace] - module_internal = sys.modules["%s._symbol_internal" % root_namespace] - module_contrib = sys.modules["%s.contrib.symbol" % root_namespace] - for name in op_names: - hdl = OpHandle() - check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl))) - function = _make_atomic_symbol_function(hdl, name) - if function.__name__.startswith('_contrib_'): - function.__name__ = function.__name__[9:] - function.__module__ = 'mxnet.contrib.symbol' - setattr(module_contrib, function.__name__, function) - elif function.__name__.startswith('_'): - setattr(module_internal, function.__name__, function) - else: - setattr(module_obj, function.__name__, function) +def invoke(cached_op, args, name=None): + """Call cached symbolic operator""" + ret = SymbolHandle() + hint = cached_op.op.lower() + name = c_str(NameManager.current.get(name, hint)) + check_call(_LIB.MXCachedCreateSymbol( + cached_op.handle, + name, + mx_uint(len(args)), + c_array(SymbolHandle, [s.handle for s in args]), + ctypes.byref(ret))) + return _symbol_cls(ret) + + +def _symbol_creator(handle, args, kwargs, keys, vals, name): + sym_handle = SymbolHandle() + check_call(_LIB.MXSymbolCreateAtomicSymbol( + ctypes.c_void_p(handle), + mx_uint(len(keys)), + c_array(ctypes.c_char_p, [c_str(i) for i in keys]), + c_array(ctypes.c_char_p, [c_str(str(i)) for i in vals]), + ctypes.byref(sym_handle))) + + if args and kwargs: + raise TypeError( + 'Operators with variable length input can only accept input' + 'Symbols either as positional or keyword arguments, not both') + s = _symbol_cls(sym_handle) + if args: + s._compose(*args, name=name) + elif kwargs: + s._compose(name=name, **kwargs) + else: + s._compose(name=name) + return s diff --git 
a/python/mxnet/base.py b/python/mxnet/base.py index 83d06e5b..aeb7ef8e 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -59,6 +59,7 @@ def _load_lib(): NDArrayHandle = ctypes.c_void_p FunctionHandle = ctypes.c_void_p OpHandle = ctypes.c_void_p +CachedOpHandle = ctypes.c_void_p SymbolHandle = ctypes.c_void_p ExecutorHandle = ctypes.c_void_p DataIterCreatorHandle = ctypes.c_void_p diff --git a/python/mxnet/contrib/autograd.py b/python/mxnet/contrib/autograd.py index 40ab289c..e56361ef 100644 --- a/python/mxnet/contrib/autograd.py +++ b/python/mxnet/contrib/autograd.py @@ -104,24 +104,48 @@ def mark_variables(variables, gradients, grad_reqs='write'): c_array(mx_uint, grad_reqs), c_array(NDArrayHandle, gradient_handles))) -def compute_gradient(outputs): + +def backward(outputs, out_grads=None, retain_graph=False): """Compute the gradients of outputs w.r.t variables. Parameters ---------- outputs: list of NDArray - - Returns - ------- - gradients: list of NDArray + out_grads: list of NDArray or None """ + assert isinstance(outputs, (list, tuple)), \ + "outputs must be a list or tuple of NDArrays" output_handles = [] for arr in outputs: output_handles.append(arr.handle) - check_call(_LIB.MXAutogradComputeGradient( + if out_grads is None: + check_call(_LIB.MXAutogradBackward( + len(output_handles), + c_array(NDArrayHandle, output_handles), + ctypes.c_void_p(0), + ctypes.c_int(retain_graph))) + return + + ograd_handles = [] + for arr in out_grads: + if arr is not None: + ograd_handles.append(arr.handle) + else: + ograd_handles.append(NDArrayHandle(0)) + assert len(ograd_handles) == len(output_handles), \ + "outputs and out_grads must have the same length" + + check_call(_LIB.MXAutogradBackward( len(output_handles), - c_array(NDArrayHandle, output_handles))) + c_array(NDArrayHandle, output_handles), + c_array(NDArrayHandle, ograd_handles), + ctypes.c_int(retain_graph))) + + +def compute_gradient(outputs): + """Deprecated. 
Please use backward""" + backward(outputs) def grad_and_loss(func, argnum=None): diff --git a/python/mxnet/cython/base.pyi b/python/mxnet/cython/base.pyi index a60aaef3..65125813 100644 --- a/python/mxnet/cython/base.pyi +++ b/python/mxnet/cython/base.pyi @@ -7,6 +7,7 @@ from cpython.version cimport PY_MAJOR_VERSION ctypedef void* SymbolHandle ctypedef void* NDArrayHandle ctypedef void* OpHandle +ctypedef void* CachedOpHandle ctypedef unsigned nn_uint cdef py_str(const char* x): @@ -98,3 +99,75 @@ cdef extern from "mxnet/c_api.h": const char **param_keys, const char **param_vals); int MXNDArrayFree(NDArrayHandle handle); + int MXCachedCreateOp(OpHandle creator, + int num_inputs, + int num_params, + const char **param_keys, + const char **param_vals, + CachedOpHandle *out); + int MXCachedFree(CachedOpHandle handle); + int MXCachedInvoke(CachedOpHandle handle, + int num_inputs, + NDArrayHandle *inputs, + int *num_outputs, + NDArrayHandle **outputs); + int MXCachedCreateSymbol(CachedOpHandle handle, + const char* name, + unsigned num_args, + SymbolHandle* args, + SymbolHandle* out); + + +cdef class CachedOp: + """Cached operator handle.""" + cdef CachedOpHandle chandle + cdef string cop + + cdef _set_handle(self, handle): + cdef unsigned long long ptr + if handle is None: + self.chandle = NULL + else: + ptr = handle.value + self.chandle = (ptr) + + property handle: + def __get__(self): + if self.chandle == NULL: + return None + else: + return _ctypes.cast(self.chandle, _ctypes.c_void_p) + def __set__(self, value): + self._set_handle(value) + + property op: + def __get__(self): + return py_str(self.cop.c_str()) + def __set__(self, value): + self.cop = c_str(value) + + def __init__(self, op, num_input, **kwargs): + cdef OpHandle op_handle + cdef vector[string] ckeys + cdef vector[string] cvals + + self.op = op + CALL(NNGetOpHandle(self.cop.c_str(), &op_handle)) + + for k, v in kwargs.items(): + ckeys.push_back(c_str(k)) + cvals.push_back(c_str(str(v))) + + cdef vector[const char*] param_keys = SVec2Ptr(ckeys) + cdef vector[const char*] param_vals = SVec2Ptr(cvals) + + CALL(MXCachedCreateOp( + op_handle, + num_input, + len(kwargs), + CBeginPtr(param_keys), + CBeginPtr(param_vals), + &self.chandle)) + + def __del__(self): + CALL(MXCachedFree(self.chandle)) diff --git a/python/mxnet/cython/ndarray.pyx b/python/mxnet/cython/ndarray.pyx index 6071f796..24e37b54 100644 --- a/python/mxnet/cython/ndarray.pyx +++ b/python/mxnet/cython/ndarray.pyx @@ -60,9 +60,52 @@ cdef NewArray(NDArrayHandle handle): (nd).cwritable = True return nd + +def invoke(cached_op, args, out=None, name=None): + """ctypes implementation of imperative invoke wrapper""" + cdef vector[NDArrayHandle] ndvars + cdef vector[NDArrayHandle] output_vars + cdef NDArrayHandle* p_output_vars + cdef NDArrayHandle ret_handle + cdef int num_output + + for i in args: + ndvars.push_back((i).chandle) + + original_output = None + if out is not None: + original_output = out + if isinstance(out, NDArrayBase): + output_vars.push_back((out).chandle) + else: + for i in out: + output_vars.push_back((i).chandle) + + num_output = output_vars.size() + if output_vars.size() == 0: + output_vars.resize(1) + p_output_vars = NULL + else: + p_output_vars = &output_vars[0] + + CALL(MXCachedInvoke( + (cached_op).chandle, + len(args), + &ndvars[0] if ndvars.size() != 0 else NULL, + &num_output, + &p_output_vars)) + + if original_output is not None: + return original_output + if num_output == 1: + return NewArray(p_output_vars[0]) + else: + return 
tuple(NewArray(p_output_vars[i]) for i in range(num_output)) + + def _imperative_invoke(handle, ndargs, keys, vals, out): """cython implementation of imperative invoke wrapper""" - cdef int64_t ihandle = handle + cdef unsigned long long ihandle = handle cdef OpHandle chandle = ihandle cdef vector[string] ckeys cdef vector[string] cvals diff --git a/python/mxnet/cython/symbol.pyx b/python/mxnet/cython/symbol.pyx index 40184f62..e8787fba 100644 --- a/python/mxnet/cython/symbol.pyx +++ b/python/mxnet/cython/symbol.pyx @@ -68,7 +68,7 @@ cdef SymbolSetAttr(SymbolHandle handle, dict kwargs): _symbol_cls = SymbolBase -cdef _set_symbol_class(cls): +def _set_symbol_class(cls): global _symbol_cls _symbol_cls = cls @@ -78,129 +78,68 @@ cdef NewSymbol(SymbolHandle handle): (sym).chandle = handle return sym -cdef _make_atomic_symbol_function(OpHandle handle, string name): - """Create an atomic symbol function by handle and funciton name.""" - cdef const char *real_name - cdef const char *desc - cdef nn_uint num_args - cdef const char** arg_names - cdef const char** arg_types - cdef const char** arg_descs - cdef const char* return_type - cdef const char* key_var_num_args - - CALL(MXSymbolGetAtomicSymbolInfo( - handle, &real_name, &desc, - &num_args, &arg_names, - &arg_types, &arg_descs, - &key_var_num_args, &return_type)) - func_name = py_str(name.c_str()) - - key_vargs = py_str(key_var_num_args) - num_args = int(num_args) - doc_str = _build_doc(func_name, - py_str(desc), - [py_str(arg_names[i]) for i in range(num_args)], - [py_str(arg_types[i]) for i in range(num_args)], - [py_str(arg_descs[i]) for i in range(num_args)], - key_vargs, - py_str(return_type) if return_type != NULL else '') - - func_hint = func_name.lower() - - def creator(*args, **kwargs): - cdef vector[string] sparam_keys - cdef vector[string] sparam_vals - cdef vector[SymbolHandle] symbol_args - cdef vector[string] ssymbol_keys - cdef SymbolHandle ret_handle - attr = kwargs.pop("attr", None) - kwargs.update(AttrScope.current.get(attr)) - name = kwargs.pop("name", None) - - if key_vargs: - if key_vargs not in kwargs: - sparam_keys.push_back(c_str(key_vargs)) - sparam_vals.push_back(c_str(str(len(args)))) - - if len(kwargs) != 0: - for k, v in kwargs.items(): - if isinstance(v, SymbolBase): - ssymbol_keys.push_back(c_str(k)) - symbol_args.push_back((v).chandle) - elif k == 'dtype': - sparam_keys.push_back(c_str(k)) - sparam_vals.push_back(c_str(_numpy.dtype(v).name)) - else: - sparam_keys.push_back(c_str(k)) - sparam_vals.push_back(c_str(str(v))) - - if len(args) != 0: - if symbol_args.size() != 0: - raise TypeError("compose only accept input Symbols\ - either as positional or keyword arguments, not both") - for v in args: - if not isinstance(v, SymbolBase): - raise TypeError('Compose expect `Symbol` as arguments') - symbol_args.push_back((v).chandle) - - cdef vector[const char*] param_keys = SVec2Ptr(sparam_keys) - cdef vector[const char*] param_vals = SVec2Ptr(sparam_vals) - cdef vector[const char*] symbol_keys = SVec2Ptr(ssymbol_keys) - - CALL(MXSymbolCreateAtomicSymbol( - handle, - param_keys.size(), - CBeginPtr(param_keys), - CBeginPtr(param_vals), - &ret_handle)) - num_args = (symbol_args.size()) - - name = NameManager.current.get(name, func_hint) - - cdef const char* c_name = NULL - - if name: - name = c_str(name) - c_name = name - - CALL(NNSymbolCompose( - ret_handle, - c_name, - num_args, - &symbol_keys[0] if symbol_keys.size() != 0 else NULL, - &symbol_args[0] if symbol_args.size() != 0 else NULL)) - return NewSymbol(ret_handle) 
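At the frontend level, both the ctypes and Cython backends end up exposing the same pair, `CachedOp` and `invoke`, for NDArrays and Symbols. A rough usage sketch: the imperative half mirrors `test_cached` in `tests/python/unittest/test_ndarray.py` further down, while the symbolic half is an assumed sketch of the same pattern using the `CachedOp`/`invoke` names re-exported from `mxnet.symbol`:

```python
import mxnet as mx

# Imperative: create the cached descriptor once, reuse it across calls.
op = mx.nd.CachedOp('Convolution', 3, kernel=(3, 3), num_filter=10)
data = mx.nd.ones((3, 4, 10, 10))
weight = mx.nd.ones((10, 4, 3, 3))
bias = mx.nd.ones((10,))
out = mx.nd.invoke(op, [data, weight, bias])      # NDArray result

# Symbolic: the same cached attributes become a graph node via
# MXCachedCreateSymbol; the node name comes from NameManager if omitted.
sop = mx.sym.CachedOp('Convolution', 3, kernel=(3, 3), num_filter=10)
conv = mx.sym.invoke(sop, [mx.sym.Variable('data'),
                           mx.sym.Variable('weight'),
                           mx.sym.Variable('bias')])
```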
- - creator.__name__ = func_name - creator.__doc__ = doc_str - creator.__module__ = 'mxnet.symbol' - return creator - - -def _init_symbol_module(symbol_class, root_namespace): - """List and add all the atomic symbol functions to current module.""" - cdef const char** op_name_ptrs - cdef nn_uint size - cdef vector[string] op_names - cdef OpHandle handle - - _set_symbol_class(symbol_class) - CALL(MXListAllOpNames(&size, &op_name_ptrs)) - for i in range(size): - op_names.push_back(string(op_name_ptrs[i])) - - module_obj = _sys.modules["%s.symbol" % root_namespace] - module_internal = _sys.modules["%s._symbol_internal" % root_namespace] - module_contrib = _sys.modules["%s.contrib.symbol" % root_namespace] - for i in range(op_names.size()): - CALL(NNGetOpHandle(op_names[i].c_str(), &handle)) - function = _make_atomic_symbol_function(handle, op_names[i]) - if function.__name__.startswith('_contrib_'): - function.__name__ = function.__name__[9:] - function.__module__ = 'mxnet.contrib.symbol' - setattr(module_contrib, function.__name__, function) - elif function.__name__.startswith('_'): - setattr(module_internal, function.__name__, function) - else: - setattr(module_obj, function.__name__, function) + +def invoke(cached_op, args, name=None): + cdef SymbolHandle ret + cdef vector[SymbolHandle] sym_args + hint = cached_op.op.lower() + cdef string cname = c_str(NameManager.current.get(name, hint)) + for i in args: + sym_args.push_back((i).chandle) + CALL(MXCachedCreateSymbol( + (cached_op).chandle, + cname.c_str(), + len(args), + &sym_args[0] if sym_args.size() != 0 else NULL, + &ret)) + return NewSymbol(ret) + + +def _symbol_creator(handle, args, kwargs, keys, vals, name): + cdef unsigned long long ihandle = handle + cdef OpHandle chandle = ihandle + cdef vector[string] ckeys + cdef vector[string] cvals + cdef vector[string] sym_keys + cdef vector[SymbolHandle] sym_args + cdef SymbolHandle ret_handle + cdef string cname = c_str(name) + + for i in keys: + ckeys.push_back(c_str(i)) + for i in vals: + cvals.push_back(c_str(str(i))) + + cdef vector[const char*] param_keys = SVec2Ptr(ckeys) + cdef vector[const char*] param_vals = SVec2Ptr(cvals) + + CALL(MXSymbolCreateAtomicSymbol( + chandle, + param_keys.size(), + CBeginPtr(param_keys), + CBeginPtr(param_vals), + &ret_handle)) + + if args and kwargs: + raise TypeError( + 'Operators with variable length input can only accept input' + 'Symbols either as positional or keyword arguments, not both') + + if args: + for i in args: + sym_args.push_back((i).chandle) + elif kwargs: + for k, v in kwargs.items(): + sym_keys.push_back(c_str(k)) + sym_args.push_back((v).chandle) + + cdef vector[const char*] csym_keys = SVec2Ptr(sym_keys) + + CALL(NNSymbolCompose( + ret_handle, + cname.c_str(), + sym_args.size(), + &csym_keys[0] if csym_keys.size() != 0 else NULL, + &sym_args[0] if sym_args.size() != 0 else NULL)) + + return NewSymbol(ret_handle) diff --git a/python/mxnet/ndarray.py b/python/mxnet/ndarray.py index f86404eb..c5d67545 100644 --- a/python/mxnet/ndarray.py +++ b/python/mxnet/ndarray.py @@ -31,15 +31,19 @@ # pylint: disable=unused-import try: if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0: - from ._ctypes.ndarray import NDArrayBase, _set_ndarray_class, _imperative_invoke + from ._ctypes.ndarray import NDArrayBase, _set_ndarray_class + from ._ctypes.ndarray import invoke, CachedOp, _imperative_invoke elif _sys.version_info >= (3, 0): from ._cy3.ndarray import NDArrayBase, _set_ndarray_class, _imperative_invoke + from ._cy3.ndarray import invoke, 
CachedOp, _imperative_invoke else: from ._cy2.ndarray import NDArrayBase, _set_ndarray_class, _imperative_invoke + from ._cy2.ndarray import invoke, CachedOp, _imperative_invoke except ImportError: if int(_os.environ.get("MXNET_ENFORCE_CYTHON", False)) != 0: raise ImportError("Cython Module cannot be loaded but MXNET_ENFORCE_CYTHON=1") from ._ctypes.ndarray import NDArrayBase, _set_ndarray_class, _imperative_invoke + from ._ctypes.ndarray import invoke, CachedOp, _imperative_invoke # pylint: enable=unused-import # pylint: disable= no-member @@ -749,6 +753,24 @@ def T(self): return transpose(self) # pylint: enable= invalid-name, undefined-variable + @property + def _fresh_grad(self): + """Whether this array's corresponding gradient array + (registered via `autograd.mark_variables`) has been + updated by `autograd.backward` since last reset. + + `_fresh_grad` need to be manually set to False + after consuming gradient (usually after updating this + array). + """ + out = ctypes.c_int() + check_call(_LIB.MXNDArrayGetGradState(self.handle, ctypes.byref(out))) + return out.value + + @_fresh_grad.setter + def _fresh_grad(self, state): + check_call(_LIB.MXNDArraySetGradState(self.handle, ctypes.c_int(state))) + def asnumpy(self): """Returns a ``numpy.ndarray`` object with value copied from this array. @@ -910,7 +932,7 @@ def detach(self): check_call(_LIB.MXNDArrayDetach(self.handle, ctypes.byref(hdl))) return NDArray(hdl) - def backward(self, out_grad=None): + def backward(self, out_grad=None, retain_graph=False): """Compute the gradients of this NDArray w.r.t variables. Parameters @@ -924,7 +946,8 @@ def backward(self, out_grad=None): check_call(_LIB.MXAutogradBackward( 1, c_array(NDArrayHandle, [self.handle]), - c_array(NDArrayHandle, ograd_handles))) + c_array(NDArrayHandle, ograd_handles), + ctypes.c_int(retain_graph))) def onehot_encode(indices, out): @@ -2327,7 +2350,7 @@ def %s(*%s, **kwargs):"""%(func_name, arr_name)) for i in {}: assert isinstance(i, NDArrayBase), \\ "Positional arguments must have NDArray type, " \\ - "but got %s"%str(type(i)) + "but got %s"%str(i) ndargs.append(i)""".format(arr_name)) if dtype_name is not None: code.append(""" @@ -2335,10 +2358,7 @@ def %s(*%s, **kwargs):"""%(func_name, arr_name)) kwargs['%s'] = np.dtype(kwargs['%s']).name"""%( dtype_name, dtype_name, dtype_name)) code.append(""" - try: - kwargs.pop('name') - except: - pass + _ = kwargs.pop('name', None) out = kwargs.pop('out', None) keys = list(kwargs.keys()) vals = list(kwargs.values())""") @@ -2353,7 +2373,7 @@ def %s(%s): code.append(""" if {name} is not None: assert isinstance({name}, NDArrayBase), \\ - "Argument {name} must have NDArray type, but got %s"%str(type({name})) + "Argument {name} must have NDArray type, but got %s"%str({name}) ndargs.append({name})""".format(name=name)) # kwargs for name in kwarg_names: diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index d09de16f..16cbeae3 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -14,27 +14,33 @@ from .base import _LIB, numeric_types from .base import c_array, c_str, mx_uint, py_str, string_types, mx_real_t -from .base import NDArrayHandle, ExecutorHandle, SymbolHandle -from .base import check_call, MXNetError +from .base import NDArrayHandle, ExecutorHandle, SymbolHandle, OpHandle +from .base import check_call, MXNetError, _Null # pylint: disable=unused-import from .context import Context, cpu from .ndarray import NDArray, zeros as _nd_zeros, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP +from .name import NameManager # 
pylint: disable=unused-import from .executor import Executor from . import _symbol_internal as _internal from .attribute import AttrScope +from .symbol_doc import _build_doc # Use different version of SymbolBase # When possible, use cython to speedup part of computation. try: if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0: - from ._ctypes.symbol import SymbolBase, _init_symbol_module + from ._ctypes.symbol import SymbolBase, _set_symbol_class + from ._ctypes.symbol import CachedOp, invoke, _symbol_creator # pylint: disable=unused-import elif _sys.version_info >= (3, 0): - from ._cy3.symbol import SymbolBase, _init_symbol_module + from ._cy3.symbol import SymbolBase, _set_symbol_class + from ._cy3.symbol import CachedOp, invoke, _symbol_creator # pylint: disable=unused-import else: - from ._cy2.symbol import SymbolBase, _init_symbol_module + from ._cy2.symbol import SymbolBase, _set_symbol_class + from ._cy2.symbol import CachedOp, invoke, _symbol_creator # pylint: disable=unused-import except ImportError: if int(_os.environ.get("MXNET_ENFORCE_CYTHON", False)) != 0: raise ImportError("Cython Module cannot be loaded but MXNET_ENFORCE_CYTHON=1") - from ._ctypes.symbol import SymbolBase, _init_symbol_module + from ._ctypes.symbol import SymbolBase, _set_symbol_class + from ._ctypes.symbol import CachedOp, invoke, _symbol_creator # pylint: disable=unused-import _GRAD_REQ_MAP = {'null': 0, 'write': 1, 'add': 3} @@ -1651,9 +1657,6 @@ def load_json(json_str): return Symbol(handle) -# Initialize the atomic symbol in startups -_init_symbol_module(Symbol, "mxnet") - # pylint: disable=no-member # pylint: disable=redefined-builtin def pow(base, exp): @@ -1901,3 +1904,189 @@ def arange(start, stop=None, step=1.0, repeat=1, name=None, dtype=None): dtype = _numpy.float32 return _internal._arange(start=start, stop=stop, step=step, repeat=repeat, name=name, dtype=dtype) + + +def _make_atomic_symbol_function(handle, name): + """Create an atomic symbol function by handle and funciton name.""" + real_name = ctypes.c_char_p() + desc = ctypes.c_char_p() + num_args = mx_uint() + arg_names = ctypes.POINTER(ctypes.c_char_p)() + arg_types = ctypes.POINTER(ctypes.c_char_p)() + arg_descs = ctypes.POINTER(ctypes.c_char_p)() + key_var_num_args = ctypes.c_char_p() + ret_type = ctypes.c_char_p() + + check_call(_LIB.MXSymbolGetAtomicSymbolInfo( + handle, ctypes.byref(real_name), ctypes.byref(desc), + ctypes.byref(num_args), + ctypes.byref(arg_names), + ctypes.byref(arg_types), + ctypes.byref(arg_descs), + ctypes.byref(key_var_num_args), + ctypes.byref(ret_type))) + narg = int(num_args.value) + arg_names = [py_str(arg_names[i]) for i in range(narg)] + arg_types = [py_str(arg_types[i]) for i in range(narg)] + func_name = name + key_var_num_args = py_str(key_var_num_args.value) + ret_type = py_str(ret_type.value) if ret_type.value is not None else '' + doc_str = _build_doc(func_name, + py_str(desc.value), + arg_names, + arg_types, + [py_str(arg_descs[i]) for i in range(narg)], + key_var_num_args, + ret_type) + + dtype_name = None + arr_name = None + ndsignature = [] + signature = [] + ndarg_names = [] + kwarg_names = [] + for i in range(narg): + name, atype = arg_names[i], arg_types[i] + if name == 'dtype': + dtype_name = name + signature.append('%s=_Null'%name) + elif atype.startswith('NDArray') or atype.startswith('Symbol'): + assert not arr_name, \ + "Op can only have one argument with variable " \ + "size and it must be the last argument." 
+ if atype.endswith('[]'): + ndsignature.append('*%s'%name) + arr_name = name + else: + ndsignature.append('%s=None'%name) + ndarg_names.append(name) + else: + signature.append('%s=_Null'%name) + kwarg_names.append(name) + #signature.append('is_train=False') + signature.append('name=None') + signature.append('attr=None') + signature.append('out=None') + signature.append('**kwargs') + signature = ndsignature + signature + + code = [] + if arr_name: + code.append(""" +def %s(*%s, **kwargs):"""%(func_name, arr_name)) + code.append(""" + sym_args = [] + for i in {}: + assert isinstance(i, SymbolBase), \\ + "Positional arguments must be Symbol instances, " \\ + "but got %s"%str(i) + sym_args.append(i)""".format(arr_name)) + if dtype_name is not None: + code.append(""" + if '%s' in kwargs: + kwargs['%s'] = _numpy.dtype(kwargs['%s']).name"""%( + dtype_name, dtype_name, dtype_name)) + code.append(""" + attr = kwargs.pop('attr', None) + kwargs.update(AttrScope.current.get(attr)) + name = kwargs.pop('name', None) + name = NameManager.current.get(name, '%s') + _ = kwargs.pop('out', None) + keys = [] + vals = [] + sym_kwargs = dict() + for k, v in kwargs.items(): + if isinstance(v, SymbolBase): + sym_kwargs[k] = v + else: + keys.append(k) + vals.append(v)"""%(func_name.lower())) + if key_var_num_args: + code.append(""" + if '%s' not in kwargs: + keys.append('%s') + vals.append(len(sym_args) + len(sym_kwargs))"""%( + key_var_num_args, key_var_num_args)) + + code.append(""" + return _symbol_creator(%d, sym_args, sym_kwargs, keys, vals, name)"""%( + handle.value)) + else: + code.append(""" +def %s(%s): + kwargs.update(AttrScope.current.get(attr)) + sym_kwargs = dict() + keys = [] + vals = []"""%(func_name, ', '.join(signature))) + code.append(""" + for k, v in kwargs.items(): + if isinstance(v, SymbolBase): + sym_kwargs[k] = v + else: + keys.append(k) + vals.append(v)""") + # NDArray args + for name in ndarg_names: + code.append(""" + if {name} is not None: + assert isinstance({name}, SymbolBase), \\ + "Argument {name} must be Symbol instances, but got %s"%str({name}) + sym_kwargs['{name}'] = {name}""".format(name=name)) + # kwargs + for name in kwarg_names: + code.append(""" + if %s is not _Null: + keys.append('%s') + vals.append(%s)"""%(name, name, name)) + # dtype + if dtype_name is not None: + code.append(""" + if %s is not _Null: + keys.append('%s') + vals.append(_numpy.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name)) + + code.append(""" + name = NameManager.current.get(name, '%s') + return _symbol_creator(%d, None, sym_kwargs, keys, vals, name)"""%( + func_name.lower(), handle.value)) + + local = {} + exec(''.join(code), None, local) # pylint: disable=exec-used + symbol_function = local[func_name] + symbol_function.__name__ = func_name + symbol_function.__doc__ = doc_str + symbol_function.__module__ = 'mxnet.symbol' + return symbol_function + + +def _init_symbol_module(symbol_class, root_namespace): + """List and add all the atomic symbol functions to current module.""" + _set_symbol_class(symbol_class) + plist = ctypes.POINTER(ctypes.c_char_p)() + size = ctypes.c_uint() + + check_call(_LIB.MXListAllOpNames(ctypes.byref(size), + ctypes.byref(plist))) + op_names = [] + for i in range(size.value): + op_names.append(py_str(plist[i])) + + module_obj = _sys.modules["%s.symbol" % root_namespace] + module_internal = _sys.modules["%s._symbol_internal" % root_namespace] + module_contrib = _sys.modules["%s.contrib.symbol" % root_namespace] + for name in op_names: + hdl = OpHandle() + 
check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl))) + function = _make_atomic_symbol_function(hdl, name) + if function.__name__.startswith('_contrib_'): + function.__name__ = function.__name__[9:] + function.__module__ = 'mxnet.contrib.symbol' + setattr(module_contrib, function.__name__, function) + elif function.__name__.startswith('_'): + setattr(module_internal, function.__name__, function) + else: + setattr(module_obj, function.__name__, function) + + +# Initialize the atomic symbol in startups +_init_symbol_module(Symbol, "mxnet") diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 41986a0d..9d60c861 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -398,6 +398,27 @@ int MXNDArrayGetContext(NDArrayHandle handle, API_END(); } +int MXNDArrayDetach(NDArrayHandle handle, NDArrayHandle *out) { + API_BEGIN(); + NDArray *arr = static_cast(handle); + *out = new NDArray(arr->Detach()); + API_END(); +} + +int MXNDArraySetGradState(NDArrayHandle handle, int state) { + API_BEGIN(); + NDArray *arr = static_cast(handle); + arr->set_fresh_out_grad(static_cast(state)); + API_END(); +} + +int MXNDArrayGetGradState(NDArrayHandle handle, int *out) { + API_BEGIN(); + NDArray *arr = static_cast(handle); + *out = arr->fresh_out_grad(); + API_END(); +} + int MXListFunctions(mx_uint *out_size, FunctionHandle **out_array) { API_BEGIN(); diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc index 66a237a4..0be1d357 100644 --- a/src/c_api/c_api_ndarray.cc +++ b/src/c_api/c_api_ndarray.cc @@ -87,7 +87,8 @@ void SetNDInputsOutputs(const nnvm::Op* op, ndoutputs.resize(infered_num_outputs); } else { CHECK(!AutogradRuntime::Get()->IsTraining()) - << "Cannot assign to NDArray or specify 'out' when training with autograd"; + << "Inplace operations (+=, -=, op(..., out=x) etc.) 
and assignment are " + << "not supported when you are inside a train_section using autograd."; CHECK(*num_outputs == infered_num_outputs || *num_outputs == num_visible_outputs) << "Expecting " << infered_num_outputs << " (all) or " << num_visible_outputs << " (visible only) outputs, got " @@ -321,26 +322,18 @@ void PushOperator(std::shared_ptr opr, 0, PROFILER_MESSAGE(op->name.c_str())); } -int MXImperativeInvoke(AtomicSymbolCreator creator, - int num_inputs, - NDArrayHandle *inputs, - int *num_outputs, - NDArrayHandle **outputs, - int num_params, - const char **param_keys, - const char **param_vals) { +void ImperativeInvokeImpl(const nnvm::NodeAttrs& attrs, + int num_inputs, + NDArrayHandle *inputs, + int *num_outputs, + NDArrayHandle **outputs) { static auto& fcpu = nnvm::Op::GetAttr("FCompute"); static auto& fgpu = nnvm::Op::GetAttr("FCompute"); static auto& ndfunc = nnvm::Op::GetAttr("FNDArrayFunction"); static auto& createop = nnvm::Op::GetAttr("FCreateLayerOp"); - const nnvm::Op* op = static_cast(creator); - NDArray** outarray = *reinterpret_cast(outputs); MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); - - API_BEGIN(); - nnvm::NodeAttrs attrs; - SetOpAttrs(op, &attrs, - num_inputs, num_params, param_keys, param_vals); + NDArray** outarray = *reinterpret_cast(outputs); + const nnvm::Op *op = attrs.op; int infered_num_outputs; int num_visible_outputs; @@ -408,6 +401,57 @@ int MXImperativeInvoke(AtomicSymbolCreator creator, *outarray[i] = std::move(ndoutputs[i]); } } +} + +int MXImperativeInvoke(AtomicSymbolCreator creator, + int num_inputs, + NDArrayHandle *inputs, + int *num_outputs, + NDArrayHandle **outputs, + int num_params, + const char **param_keys, + const char **param_vals) { + const nnvm::Op* op = static_cast(creator); + + API_BEGIN(); + nnvm::NodeAttrs attrs; + SetOpAttrs(op, &attrs, num_inputs, num_params, param_keys, param_vals); + ImperativeInvokeImpl(attrs, num_inputs, inputs, num_outputs, outputs); + API_END(); +} + +int MXCachedCreateOp(AtomicSymbolCreator creator, + int num_inputs, + int num_params, + const char **param_keys, + const char **param_vals, + CachedOpHandle *out) { + const nnvm::Op* op = static_cast(creator); + + API_BEGIN(); + nnvm::NodeAttrs *attrs = new nnvm::NodeAttrs; + SetOpAttrs(op, attrs, num_inputs, num_params, param_keys, param_vals); + *out = attrs; + API_END(); +} + +int MXCachedFree(CachedOpHandle handle) { + nnvm::NodeAttrs *attrs = static_cast(handle); + + API_BEGIN(); + delete attrs; + API_END(); +} + +int MXCachedInvoke(CachedOpHandle handle, + int num_inputs, + NDArrayHandle *inputs, + int *num_outputs, + NDArrayHandle **outputs) { + nnvm::NodeAttrs *attrs = static_cast(handle); + + API_BEGIN(); + ImperativeInvokeImpl(*attrs, num_inputs, inputs, num_outputs, outputs); API_END(); } @@ -438,16 +482,31 @@ int MXAutogradMarkVariables(mx_uint num_var, int MXAutogradComputeGradient(mx_uint num_output, NDArrayHandle *output_handles) { + return MXAutogradBackward(num_output, output_handles, nullptr, 0); +} + +int MXAutogradBackward(mx_uint num_output, + NDArrayHandle *output_handles, + NDArrayHandle *ograd_handles, + int retain_graph) { API_BEGIN(); MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); - std::vector outputs; + std::vector outputs, ograds; outputs.reserve(num_output); for (mx_uint i = 0; i < num_output; ++i) { outputs.emplace_back(*static_cast(output_handles[i])); } - AutogradRuntime::Get()->ComputeGradient(outputs); + ograds.reserve(num_output); + for (mx_uint i = 0; i < num_output; ++i) { + if (ograd_handles != 
nullptr && ograd_handles[i] != nullptr) { + ograds.emplace_back(*static_cast(ograd_handles[i])); + } else { + ograds.emplace_back(); + } + } + AutogradRuntime::Get()->ComputeGradient(outputs, ograds, retain_graph); API_END(); } diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index fdf095b0..27df5b2d 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -124,6 +124,22 @@ int MXSymbolCreateAtomicSymbol(AtomicSymbolCreator creator, API_END_HANDLE_ERROR(delete s;); } +int MXCachedCreateSymbol(CachedOpHandle handle, + const char* name, + mx_uint num_args, + SymbolHandle* args, + SymbolHandle* out) { + nnvm::Symbol *s = new nnvm::Symbol(); + const nnvm::NodeAttrs *attrs = static_cast(handle); + API_BEGIN(); + *s = nnvm::Symbol::CreateFunctor(*attrs); + nnvm::array_view parg( + (nnvm::Symbol**)args, (nnvm::Symbol**)args + num_args); // NOLINT(*) + s->Compose(parg, std::unordered_map(), name); + *out = s; + API_END_HANDLE_ERROR(delete s;) +} + int MXSymbolCreateVariable(const char *name, SymbolHandle *out) { return NNSymbolCreateVariable(name, out); } diff --git a/src/ndarray/autograd.cc b/src/ndarray/autograd.cc index e7b57956..ce1b98f0 100644 --- a/src/ndarray/autograd.cc +++ b/src/ndarray/autograd.cc @@ -49,6 +49,10 @@ nnvm::NodeEntry AGNodeEntry::nn_entry() const { return nnvm::NodeEntry{ag_node->nn_node, index, version}; } +bool AGNodeEntry::is_none() const { + return ag_node == nullptr || ag_node->outputs.empty(); +} + AutogradRuntime::AutogradRuntime() {} void AutogradRuntime::MarkVariables( @@ -56,13 +60,21 @@ void AutogradRuntime::MarkVariables( const std::vector& grad_reqs, const std::vector& gradients) { for (uint32_t i = 0; i < variables.size(); ++i) { + std::string str_c(std::to_string(variable_count_++)); + AGNodeEntry e{AGNode::Create(Node::Create()), 0, 0}; variables[i]->entry_.clear(); - e.ag_node->outputs.push_back(*variables[i]); + e.ag_node->outputs.emplace_back(*variables[i]); + + AGNodeEntry ge{AGNode::Create(Node::Create()), 0, 0}; gradients[i]->entry_.clear(); - e.ag_node->out_grads.push_back(*gradients[i]); + ge.ag_node->outputs.emplace_back(*gradients[i]); + ge.ag_node->nn_node->attrs.name = "grad" + str_c; + gradients[i]->entry_ = std::move(ge); + e.ag_node->out_grads.emplace_back(*gradients[i]); + e.ag_node->grad_req = static_cast(grad_reqs[i]); - e.ag_node->nn_node->attrs.name = "agvar" + std::to_string(variable_count_++); + e.ag_node->nn_node->attrs.name = "var" + str_c; variables[i]->entry_ = std::move(e); // assign last to prevent cyclic reference } } @@ -102,30 +114,28 @@ AGNodePtr AutogradRuntime::RecordOp(const nnvm::Op* op, NodePtr nn_node = Node::Create(); nn_node->attrs = attrs; - nn_node->attrs.name = "agnode_" + std::to_string(node_count_++); + nn_node->attrs.name = "node_" + std::to_string(node_count_++); AGNodePtr ag_node = AGNode::Create(nn_node); ag_node->opr = opr; for (uint32_t i = 0; i < outputs.size(); ++i) { - if (outputs[i].entry_.ag_node == nullptr || - !outputs[i].entry_.ag_node->out_grads.size()) { - outputs[i].entry_.clear(); - ag_node->outputs.push_back(outputs[i]); - outputs[i].entry_ = AGNodeEntry{ag_node, i, 0}; - } else { - NDArray copy = outputs[i]; - copy.entry_.clear(); - ag_node->outputs.push_back(copy); - } + CHECK(outputs[i].entry_.is_none()) + << "Output NDArray is non-empty and already in another computation graph. " + << "Assigning to it will cause undefined behavior when evaluating gradients. 
" + << "Please call backward first to clear the graph or do this out side of " + << "a train section. "; + outputs[i].entry_.clear(); + ag_node->outputs.push_back(outputs[i]); + outputs[i].entry_ = AGNodeEntry{ag_node, i, 0}; } for (size_t i = 0; i < inputs.size(); ++i) { - if (inputs[i].entry_.ag_node.get() == nullptr) { + if (inputs[i].entry_.is_none()) { AGNodeEntry e{AGNode::Create(Node::Create()), 0, 0}; e.ag_node->outputs.emplace_back(inputs[i]); e.ag_node->out_grads.emplace_back(); - e.ag_node->nn_node->attrs.name = "agvar_" + std::to_string(variable_count_++); + e.ag_node->nn_node->attrs.name = "var_" + std::to_string(variable_count_++); inputs[i].entry_ = std::move(e); // assign last to prevent cyclic reference } nn_node->inputs.push_back(inputs[i].entry_.nn_entry()); @@ -135,15 +145,19 @@ AGNodePtr AutogradRuntime::RecordOp(const nnvm::Op* op, return ag_node; } -void AutogradRuntime::ComputeGradient(const std::vector& outputs) { +void AutogradRuntime::ComputeGradient(const std::vector& outputs, + const std::vector& ograds, + bool retain_graph) { static auto& fmutate_inputs = nnvm::Op::GetAttr("FMutateInputs"); std::vector heads; Symbol sym; NodeEntryMap feed_dict; for (const auto& i : outputs) { - CHECK(i.entry_.ag_node.get() != nullptr) - << "Cannot differentiate node because it doesn't have " - << "computation history. Did you forget to set is_training?"; + CHECK(!i.entry_.is_none()) + << "Cannot differentiate node because it is not in a computational graph. " + << "You need to set is_training to true or use a train_section to save " + << "computational graphs for backward. If you want to differentiate the same " + << "graph twice, you need to pass retain_graph=True to backward."; heads.emplace_back(i.entry_); sym.outputs.emplace_back(i.entry_.nn_entry()); } @@ -176,6 +190,9 @@ void AutogradRuntime::ComputeGradient(const std::vector& outputs) { if (mutable_set.count(n.get())) { aux_states.push_back(n->outputs[0]); } else { + if (n->grad_req != kNullOp) { + n->fresh_out_grad = true; + } args.push_back(n->outputs[0]); args_grad.push_back(n->out_grads[0]); grad_reqs.push_back(n->grad_req); @@ -193,19 +210,27 @@ void AutogradRuntime::ComputeGradient(const std::vector& outputs) { std::vector head_grads; head_grads.reserve(exec->head_grad_array_.size()); - - for (size_t i = 0; i < exec->output_arrays_.size(); ++i) { - NDArray grad(exec->output_arrays_[i].shape(), exec->output_arrays_[i].ctx()); - grad = static_cast(1.0); - head_grads.push_back(grad); + CHECK_EQ(ograds.size(), exec->output_arrays_.size()); + + for (size_t i = 0; i < ograds.size(); ++i) { + if (ograds[i].is_none()) { + head_grads.emplace_back( + exec->output_arrays_[i].shape(), exec->output_arrays_[i].ctx(), + false, exec->output_arrays_[i].dtype()); + head_grads.back() = static_cast(1.0); + } else { + head_grads.emplace_back(ograds[i]); + } } exec->Backward(head_grads); delete exec; } - for (auto& i : heads) { - i.ag_node->clear_history(); + if (!retain_graph) { + for (auto& i : heads) { + i.ag_node->clear_history(); + } } } diff --git a/src/ndarray/autograd.h b/src/ndarray/autograd.h index 3603b0a1..e6868064 100644 --- a/src/ndarray/autograd.h +++ b/src/ndarray/autograd.h @@ -20,6 +20,7 @@ namespace mxnet { namespace autograd { + class AGNode { public: OpReqType grad_req; @@ -28,9 +29,10 @@ class AGNode { std::vector inputs; std::vector outputs; std::vector out_grads; + bool fresh_out_grad; explicit AGNode(const nnvm::NodePtr& nn_node_) : - grad_req(kNullOp), nn_node(nn_node_) {} + grad_req(kNullOp), 
nn_node(nn_node_), fresh_out_grad(false) {} static AGNodePtr Create(const nnvm::NodePtr& nn_node_) { return std::make_shared(nn_node_); @@ -77,7 +79,9 @@ class AutogradRuntime { std::vector* p_inputs, std::vector* p_outputs); /*! \brief compute the gradient of outputs w.r.t variables. */ - void ComputeGradient(const std::vector& outputs); + void ComputeGradient(const std::vector& outputs, + const std::vector& ograds, + bool retain_graph); /*! \return AutogradRuntime singleton */ static AutogradRuntime* Get(); /*! \brief Get shared pointer reference to AutogradRuntime singleton. diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 717ba170..025624c9 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -26,11 +26,12 @@ namespace mxnet { NDArray NDArray::Reshape(const TShape &shape) const { using namespace autograd; - CHECK_GE(shape_.Size(), shape.Size()) - << "NDArray.Reshape: target shape size is different from current shape"; - NDArray ret = *this; - ret.shape_ = shape; if (AutogradRuntime::Get()->IsTraining()) { + CHECK_GE(shape_.Size(), shape.Size()) + << "NDArray.Reshape: target shape must have the same size as " + << "current shape when in train_section."; + NDArray ret = *this; + ret.shape_ = shape; // fake a Reshape op ret.entry_.clear(); const nnvm::Op* op = nnvm::Op::Get("Reshape"); @@ -47,6 +48,10 @@ NDArray NDArray::Reshape(const TShape &shape) const { op, attrs, &inputs, &outputs); return outputs[0]; } else { + CHECK_GE(shape_.Size(), shape.Size()) + << "NDArray.Reshape: target shape size is larger than current shape"; + NDArray ret = *this; + ret.shape_ = shape; return ret; } } @@ -91,6 +96,20 @@ NDArray NDArray::At(index_t idx) const { } } + +bool NDArray::fresh_out_grad() const { + if (entry_.ag_node != nullptr) return entry_.ag_node->fresh_out_grad; + return false; +} + + +void NDArray::set_fresh_out_grad(bool state) const { + CHECK(entry_.ag_node != nullptr) + << "NDArray has not been marked as a variable and does not have gradient state"; + entry_.ag_node->fresh_out_grad = state; +} + + /*! * \brief run a ternary operation * \param lhs left operand diff --git a/src/operator/custom/custom.cc b/src/operator/custom/custom.cc index 29f624ea..8fb324c1 100644 --- a/src/operator/custom/custom.cc +++ b/src/operator/custom/custom.cc @@ -193,10 +193,10 @@ The custom operator must be registered before it can be used. Please check the tutorial here: http://mxnet.io/how_to/new_op.html. )code") +.add_argument("data", "NDArray-or-Symbol[]", "Input data for the custom operator.") .add_argument("op_type", "string", "Name of the custom operator.
" "This is the name that is passed to `mx.operator.register` " - "to register the operator.") -.add_argument("data", "NDArray-or-Symbol", "Input data for the custom operator."); + "to register the operator."); } // namespace op diff --git a/src/operator/custom/native_op.cc b/src/operator/custom/native_op.cc index 7ab0614a..2ccd286e 100644 --- a/src/operator/custom/native_op.cc +++ b/src/operator/custom/native_op.cc @@ -21,6 +21,7 @@ DMLC_REGISTER_PARAMETER(NativeOpParam); MXNET_REGISTER_OP_PROPERTY(_Native, NativeOpProp) .describe("Stub for implementing an operator implemented in native frontend language.") +.add_argument("data", "NDArray-or-Symbol[]", "Input data for the custom operator.") .add_arguments(NativeOpParam::__FIELDS__()); } // namespace op diff --git a/src/operator/custom/ndarray_op.cc b/src/operator/custom/ndarray_op.cc index 773fe775..9815f888 100644 --- a/src/operator/custom/ndarray_op.cc +++ b/src/operator/custom/ndarray_op.cc @@ -126,6 +126,7 @@ DMLC_REGISTER_PARAMETER(NDArrayOpParam); MXNET_REGISTER_OP_PROPERTY(_NDArray, NDArrayOpProp) .describe("Stub for implementing an operator implemented in native frontend language with ndarray.") +.add_argument("data", "NDArray-or-Symbol[]", "Input data for the custom operator.") .add_arguments(NDArrayOpParam::__FIELDS__()); } // namespace op diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h index 2288224a..13f112b6 100644 --- a/src/operator/tensor/broadcast_reduce_op.h +++ b/src/operator/tensor/broadcast_reduce_op.h @@ -20,6 +20,7 @@ namespace op { struct ReduceAxesParam : public dmlc::Parameter { TShape axis; bool keepdims; + bool exclude; DMLC_DECLARE_PARAMETER(ReduceAxesParam) { DMLC_DECLARE_FIELD(axis).set_default(TShape()) .describe(R"code(The axis or axes along which to perform the reduction. @@ -30,10 +31,15 @@ struct ReduceAxesParam : public dmlc::Parameter { If `axis` is int, a reduction is performed on a particular axis. If `axis` is a tuple of ints, a reduction is performed on all the axes - specified in the tuple.)code"); + specified in the tuple. 
+ + If `exclude` is true, reduction will be performed on the axes that are + NOT in axis instead.)code"); DMLC_DECLARE_FIELD(keepdims).set_default(false) .describe("If this is set to `True`, the reduced axes are left " "in the result as dimension with size one."); + DMLC_DECLARE_FIELD(exclude).set_default(false) + .describe("Whether to perform reduction on axis that are NOT in axis instead."); } }; @@ -150,42 +156,68 @@ inline bool ReduceAxisShape(const nnvm::NodeAttrs& attrs, return true; } -inline bool ReduceAxesShape(const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { - CHECK_EQ(in_attrs->size(), 1U); - CHECK_EQ(out_attrs->size(), 1U); - if ((*in_attrs)[0].ndim() == 0) return false; - const ReduceAxesParam& param = nnvm::get(attrs.parsed); - TShape &ishape = (*in_attrs)[0]; - TShape oshape; - if (param.axis.ndim() == 0) { - if (param.keepdims) { - oshape = TShape(ishape.ndim()); +inline TShape ReduceAxesShapeImpl(const TShape& ishape, const TShape& axis, + bool keepdims, bool exclude) { + if (axis.ndim() == 0) { + if (keepdims) { + return TShape(ishape.ndim()); } else { - oshape = TShape(1); + return TShape(1); } - } else { - if (param.keepdims) { - oshape = ishape; - for (index_t i = 0; i < param.axis.ndim(); ++i) { - oshape[param.axis[i]] = 1; - } - } else { - CHECK_LT(param.axis[param.axis.ndim()-1], ishape.ndim()) - << "Reduction axis " << param.axis[param.axis.ndim()-1] - << " Exceeds input dimensions " << ishape; - oshape = TShape(std::max(1, ishape.ndim() - param.axis.ndim())); - for (index_t i = 0, j = 0, k = 0; i < ishape.ndim(); ++i) { - if (j < param.axis.ndim() && i == param.axis[j]) { + } + + CHECK_LT(axis[axis.ndim()-1], ishape.ndim()) + << "Reduction axis " << axis[axis.ndim()-1] + << " Exceeds input dimensions " << ishape; + + if (keepdims) { + TShape oshape(ishape); + if (exclude) { + for (index_t i = 0, j = 0; i < ishape.ndim(); ++i) { + if (j < axis.ndim() && i == axis[j]) { ++j; continue; } - oshape[k++] = ishape[i]; + oshape[i] = 1; } + return oshape; + } + + for (index_t i = 0; i < axis.ndim(); ++i) { + oshape[axis[i]] = 1; } + return oshape; } - SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); + + if (exclude) { + TShape oshape = TShape(axis.ndim()); + for (index_t i = 0; i < axis.ndim(); ++i) { + oshape[i] = ishape[axis[i]]; + } + return oshape; + } + + TShape oshape = TShape(std::max(1, ishape.ndim() - axis.ndim())); + for (index_t i = 0, j = 0, k = 0; i < ishape.ndim(); ++i) { + if (j < axis.ndim() && i == axis[j]) { + ++j; + continue; + } + oshape[k++] = ishape[i]; + } + return oshape; +} + +inline bool ReduceAxesShape(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + if ((*in_attrs)[0].ndim() == 0) return false; + const ReduceAxesParam& param = nnvm::get(attrs.parsed); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, + ReduceAxesShapeImpl((*in_attrs)[0], param.axis, + param.keepdims, param.exclude)); return true; } @@ -332,20 +364,12 @@ void ReduceAxesCompute(const nnvm::NodeAttrs& attrs, const std::vector& inputs, const std::vector& req, const std::vector& outputs) { - // using namespace mshadow; - // using namespace mshadow::expr; const ReduceAxesParam& param = nnvm::get(attrs.parsed); TShape small; - if (!param.keepdims) { - if (param.axis.ndim() == 0) { - small = TShape(inputs[0].shape_.ndim()); - } else { - small = inputs[0].shape_; - for (index_t i = 0; i < param.axis.ndim(); ++i) - small[param.axis[i]] = 1; - } - } else { + if 
(param.keepdims) { small = outputs[0].shape_; + } else { + small = ReduceAxesShapeImpl(inputs[0].shape_, param.axis, true, param.exclude); } ReduceAxesComputeImpl(attrs, ctx, inputs, req, outputs, small); @@ -362,12 +386,10 @@ void ReduceAxesBackwardUseInOut(const nnvm::NodeAttrs& attrs, using namespace mshadow::expr; const ReduceAxesParam& param = nnvm::get(attrs.parsed); TShape small; - if (param.axis.ndim() == 0) { - small = TShape(outputs[0].shape_.ndim()); + if (param.keepdims) { + small = inputs[0].shape_; } else { - small = outputs[0].shape_; - for (index_t i = 0; i < param.axis.ndim(); ++i) - small[param.axis[i]] = 1; + small = ReduceAxesShapeImpl(outputs[0].shape_, param.axis, true, param.exclude); } TShape src_shape, dst_shape; @@ -452,13 +474,12 @@ inline void ReduceAxesBackwardUseNone(const nnvm::NodeAttrs& attrs, using namespace mshadow::expr; const ReduceAxesParam& param = nnvm::get(attrs.parsed); TShape small; - if (param.axis.ndim() == 0) { - small = TShape(outputs[0].shape_.ndim()); + if (param.keepdims) { + small = inputs[0].shape_; } else { - small = outputs[0].shape_; - for (index_t i = 0; i < param.axis.ndim(); ++i) - small[param.axis[i]] = 1; + small = ReduceAxesShapeImpl(outputs[0].shape_, param.axis, true, param.exclude); } + BroadcastComputeImpl(attrs, ctx, inputs, req, outputs, small); if (normalize) { Stream *s = ctx.get_stream(); diff --git a/src/operator/tensor/elemwise_unary_op.cc b/src/operator/tensor/elemwise_unary_op.cc index ce29a2fd..073bbe16 100644 --- a/src/operator/tensor/elemwise_unary_op.cc +++ b/src/operator/tensor/elemwise_unary_op.cc @@ -113,6 +113,10 @@ MXNET_OPERATOR_REGISTER_UNARY(make_loss) // identity output as first input, but attributes are constrainted to be like rhs NNVM_REGISTER_OP(_identity_with_attr_like_rhs) .set_num_inputs(2) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"lhs", "rhs"}; + }) .set_attr( "FInplaceOption", [](const NodeAttrs& attrs) { return std::vector >{{0, 0}}; @@ -131,7 +135,9 @@ NNVM_REGISTER_OP(_identity_with_attr_like_rhs) {n->inputs[1]}, nullptr, &n); lhs.push_back(nnvm::NodeEntry{ng, 0, 0}); return lhs; - }); + }) +.add_argument("lhs", "NDArray-or-Symbol", "First input.") +.add_argument("rhs", "NDArray-or-Symbol", "Second input."); DMLC_REGISTER_PARAMETER(CastParam); NNVM_REGISTER_OP(Cast) diff --git a/tests/python/unittest/test_autograd.py b/tests/python/unittest/test_autograd.py index c84438d7..24b417af 100644 --- a/tests/python/unittest/test_autograd.py +++ b/tests/python/unittest/test_autograd.py @@ -84,10 +84,84 @@ def test_training(): assert (y.asnumpy() == x.asnumpy()).all() +def test_out_grads(): + x = nd.ones((3, 5)) + dx = nd.zeros_like(x) + mark_variables([x], [dx]) + da = None + db = nd.array([1,2,3,4,5]) + dc = nd.array([5,4,3,2,1]) + + with train_section(): + a, b, c = nd.split(x, axis=0, num_outputs=3, squeeze_axis=True) + backward([a, b, c], [da, db, dc]) + + assert (dx.asnumpy() == np.array( + [[1,1,1,1,1], + [1,2,3,4,5], + [5,4,3,2,1]])).all() + + +def test_detach_updated_grad(): + x = nd.ones((2, 2)) + dx = nd.zeros_like(x) + y = nd.ones_like(x) + dy = nd.zeros_like(x) + mark_variables([x, y], [dx, dy]) + assert x._fresh_grad == False + assert y._fresh_grad == False + + with train_section(): + x2 = x + 2 + y2 = x2 + y + y2.backward() + assert (dx.asnumpy() == 1).all() + assert x._fresh_grad == True + assert y._fresh_grad == True + + dx[:] = 0 + x._fresh_grad = False + y._fresh_grad = False + assert x._fresh_grad == False + assert y._fresh_grad == False + 
with train_section(): + x2 = x + 2 + x2 = x2.detach() + y2 = x2 + y + y2.backward() + assert (dx.asnumpy() == 0).all() + assert y._fresh_grad == True + assert x._fresh_grad == False + + +def test_retain_grad(): + x = mx.nd.ones((2, 2)) + dx = mx.nd.zeros((2, 2)) + mark_variables([x], [dx], grad_reqs='add') + with train_section(): + y = x + 1 + y.backward(retain_graph=False) + assert (dx.asnumpy() == 1).all() + + dx[:] = 0 + with train_section(): + y = x + 1 + y.backward(retain_graph=True) + y.backward(retain_graph=False) + assert (dx.asnumpy() == 2).all() + + try: + with train_section(): + y = x + 1 + y.backward() + y.backward() + except Exception: + return + + raise AssertionError( + "differentiating the same graph twice without retain_graph should fail") + if __name__ == "__main__": - test_training() - test_unary_func() - test_binary_func() - test_operator_with_state() - test_argnum() + import nose + nose.runmodule() diff --git a/tests/python/unittest/test_init.py b/tests/python/unittest/test_init.py index 372ad355..79862269 100644 --- a/tests/python/unittest/test_init.py +++ b/tests/python/unittest/test_init.py @@ -29,6 +29,6 @@ def test_aux_init(): if __name__ == '__main__': - test_default_init() test_variable_init() + test_default_init() test_aux_init() diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index fcc7d70f..2be95a97 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -627,6 +627,17 @@ def test_iter(): assert same(y[i].asnumpy(), x[i].asnumpy()) +def test_cached(): + op = mx.nd.CachedOp('Convolution', 3, kernel=(3, 3), num_filter=10) + data = mx.nd.ones((3, 4, 10, 10)) + weight = mx.nd.ones((10, 4, 3, 3)) + bias = mx.nd.ones((10,)) + o1 = mx.nd.invoke(op, [data, weight, bias]) + bias[:] = 2 + o2 = mx.nd.invoke(op, [data, weight, bias]) + assert_almost_equal(o2.asnumpy(), o1.asnumpy()+1) + + if __name__ == '__main__': import nose nose.runmodule() diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 82c20cdb..f0c4ea6b 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1217,7 +1217,7 @@ def test_reshape_new(src_shape, shape_args, reverse, dst_shape): assert_allclose(exe.grad_arrays[0].asnumpy(), out_grad_npy.reshape((5, 4, 3, 7))) def test_reduce(): - sample_num = 200 + sample_num = 500 def test_reduce_inner(numpy_reduce_func, numpy_reduce_grad_func, mx_reduce_sym, nan_prob = 0): for i in range(sample_num): # Generate random data that has ndim between 1-7 and all the shape dims between 1-5 @@ -1226,6 +1226,7 @@ def test_reduce_inner(numpy_reduce_func, numpy_reduce_grad_func, mx_reduce_sym, shape = np.random.randint(1, 6, size=(ndim,)) axis_num = np.random.randint(0, ndim, size=1) axis_flags = np.random.randint(0, 2, size=ndim) + exclude = np.random.randint(0, 2) axes = [] for (axis, flag) in enumerate(axis_flags): if flag: @@ -1240,6 +1241,9 @@ def test_reduce_inner(numpy_reduce_func, numpy_reduce_grad_func, mx_reduce_sym, a = mx.symbol.Variable('a') if axes is None: b = mx_reduce_sym(a, keepdims=keepdims) + elif exclude and isinstance(axes, tuple) and len(axes) < ndim: + naxes = [i for i in range(ndim) if i not in axes] + b = mx_reduce_sym(a, axis=naxes, keepdims=keepdims, exclude=True) else: b = mx_reduce_sym(a, axis=axes, keepdims=keepdims) dat_npy = np.random.rand(*shape) @@ -1267,6 +1271,7 @@ def test_reduce_inner(numpy_reduce_func, numpy_reduce_grad_func, mx_reduce_sym, bc_grad_groundtruth = 
np.broadcast_to(grad_groundtruth, grad_nd.shape) equal_backward = almost_equal_ignore_nan(grad_nd.asnumpy(), bc_grad_groundtruth, 1E-4, 1E-4) assert equal_backward + test_reduce_inner(lambda data, axis, keepdims:np_reduce(data, axis, keepdims, np.sum), lambda outgrad, data, outdata, axis, keepdims, keepdim_shape: outgrad.reshape(keepdim_shape), @@ -3012,7 +3017,7 @@ def test_pick_helper(index_type=np.int32): test_pick_helper(np.int32) test_pick_helper(np.float32) - + def check_ctc_loss(acts, labels, loss_truth): in_var = mx.sym.Variable('input') labels_var = mx.sym.Variable('labels') @@ -3053,7 +3058,7 @@ def test_ctc_loss(): true_loss = np.array([7.3557, 5.4091], dtype=np.float32) # from Torch check_ctc_loss(acts2, labels2, true_loss) - + def test_quantization_op(): min0 = mx.nd.array([0.0]) max0 = mx.nd.array([1.0]) @@ -3110,71 +3115,5 @@ def create_operator(self, ctx, shapes, dtypes): if __name__ == '__main__': - test_custom_op() - test_log_softmax() - test_new_softmax() - test_pick() - test_l2_normalization() - test_sequence_mask() - test_roipooling() - test_batchnorm_training() - test_order() - test_grid_generator() - test_dot() - test_cast() - test_clip() - test_index2d() - test_scalarop() - test_reduce() - test_init() - test_expand_dims() - test_slice_axis() - test_softmax() - test_broadcast_binary_op() - test_flip() - test_crop() - test_transpose() - test_convolution_grouping() - test_nearest_upsampling() - test_binary_op_duplicate_input() - test_elementwise_sum() - test_concat() - test_slice_channel() - test_regression() - test_python_op() - test_swapaxes() - test_scalar_pow() - test_symbol_pow() - test_pow_fn() - test_embedding() - test_rsqrt_cos_sin() - test_maximum_minimum() - test_maximum_minimum_scalar() - test_abs() - test_round_ceil_floor() - test_deconvolution() - check_softmax_with_ignore_label(default_context()) - test_convolution_dilated_impulse_response() - test_reshape() - test_broadcast() - test_stn() - test_batch_dot() - test_correlation() - test_support_vector_machine_l1_svm() - test_support_vector_machine_l2_svm() - test_pad() - test_instance_normalization() - test_mathematical() - test_special_functions_using_scipy() - test_blockgrad() - test_take() - test_bilinear_sampler() - test_binary_logic() - test_repeat() - test_tile() - test_one_hot() - test_where() - test_ctc_loss() - test_quantization_op() - test_relu() - test_sigmoid() + import nose + nose.runmodule() diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py index ab25f48e..28fc8a4f 100644 --- a/tests/python/unittest/test_symbol.py +++ b/tests/python/unittest/test_symbol.py @@ -224,17 +224,20 @@ def test_zero_prop2(): assert False + +def test_cached(): + op = mx.sym.CachedOp('Convolution', 3, kernel=(3, 3), num_filter=10) + data = mx.sym.var('data') + weight = mx.sym.var('weight') + bias = mx.sym.var('bias') + out = mx.sym.invoke(op, [data, weight, bias], 'conv') + assert out.list_arguments() == ['data', 'weight', 'bias'] + assert out.list_outputs() == ['conv_output'] + with mx.name.Prefix('test_'): + assert mx.sym.invoke(op, [data, weight, bias]).name == 'test_convolution0' + assert mx.sym.invoke(op, [data, weight, bias]).name == 'test_convolution1' + + if __name__ == '__main__': - test_zero_prop2() - test_zero_prop() - test_blockgrad() - test_symbol_children() - test_load_000800() - test_symbol_infer_shape_var() - test_symbol_infer_shape() - test_symbol_infer_type() - test_symbol_internal() - test_symbol_basic() - test_symbol_compose() - test_symbol_saveload() - 
test_symbol_pickle() + import nose + nose.runmodule()
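The new `exclude` flag on ReduceAxesParam flips the meaning of `axis`: the reduction runs over every axis that is NOT listed. A minimal sketch of the resulting shapes, following ReduceAxesShapeImpl above; it assumes the generated mx.nd.sum frontend forwards the keyword the same way the symbolic call in the updated test_reduce does, so the exact imperative signature may differ:

import mxnet as mx

x = mx.nd.ones((2, 3, 4))

# exclude=True reduces over axes 1 and 2 (everything except axis 0).
y = mx.nd.sum(x, axis=0, exclude=True)
assert y.shape == (2,)                 # kept dims are ishape[axis[i]]
assert (y.asnumpy() == 12).all()       # 3 * 4 ones summed per slice

# With keepdims the excluded (reduced) axes stay as size-1 dimensions.
z = mx.nd.sum(x, axis=0, exclude=True, keepdims=True)
assert z.shape == (2, 1, 1)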
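The autograd tests added above exercise three new frontend pieces: per-output head gradients for backward, retain_graph, and NDArray.detach(). A condensed sketch of the same flow; the import path for the helpers is an assumption, since the test file's import block is outside this diff:

import mxnet.ndarray as nd
from mxnet.contrib.autograd import mark_variables, train_section, backward  # assumed import path

x = nd.ones((3, 5))
dx = nd.zeros_like(x)
mark_variables([x], [dx])

with train_section():
    a, b, c = nd.split(x, axis=0, num_outputs=3, squeeze_axis=True)
    # One head gradient per output; None falls back to ones for that output.
    backward([a, b, c], [None, nd.array([1, 2, 3, 4, 5]), nd.array([5, 4, 3, 2, 1])])
# dx now holds [[1,1,1,1,1], [1,2,3,4,5], [5,4,3,2,1]]

with train_section():
    y = x + 1
    y.backward(retain_graph=True)    # keep the recorded graph ...
    y.backward(retain_graph=False)   # ... so a second backward pass is legal

with train_section():
    x2 = (x + 2).detach()            # detach() drops autograd history
    y2 = x2 + x
    y2.backward()                    # gradient reaches x only through "+ x"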
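The CachedOp path added in the C API and ctypes layer parses the operator name and attribute kwargs once, then reuses the cached handle for both imperative and symbolic calls. A sketch mirroring the new test_cached tests:

import mxnet as mx

# Parse 'Convolution' and its kwargs a single time into a cached handle.
op = mx.nd.CachedOp('Convolution', 3, kernel=(3, 3), num_filter=10)

data = mx.nd.ones((3, 4, 10, 10))
weight = mx.nd.ones((10, 4, 3, 3))
bias = mx.nd.ones((10,))

# Invoke repeatedly; only the input NDArrays change between calls.
out1 = mx.nd.invoke(op, [data, weight, bias])
bias[:] = 2
out2 = mx.nd.invoke(op, [data, weight, bias])

# The symbolic frontend can stamp symbols out of a cached op as well.
sym_op = mx.sym.CachedOp('Convolution', 3, kernel=(3, 3), num_filter=10)
conv = mx.sym.invoke(sym_op, [mx.sym.var('data'), mx.sym.var('weight'), mx.sym.var('bias')], 'conv')
assert conv.list_outputs() == ['conv_output']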