Skip to content

Commit

Permalink
backward headgrads and detach (#6332)
Browse files Browse the repository at this point in the history
* backward headgrads and detach

lint

fix

add updated grad

add retain grad

exclude reduce

cpp cached invoke

cached symbol

move symbol init module

symbol cython

update

updated_grad->fresh_grad

fix

* fix
  • Loading branch information
piiswrong committed May 31, 2017
1 parent 215ae4a commit 32ced38
Show file tree
Hide file tree
Showing 29 changed files with 1,018 additions and 474 deletions.
68 changes: 68 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ typedef void *NDArrayHandle;
typedef const void *FunctionHandle;
/*! \brief handle to a function that takes param and creates symbol */
typedef void *AtomicSymbolCreator;
/*! \brief handle to a cached operator */
typedef void *CachedOpHandle;
/*! \brief handle to a symbol that can be bound as an operator */
typedef void *SymbolHandle;
/*! \brief handle to an AtomicSymbol */
Expand Down Expand Up @@ -414,6 +416,26 @@ MXNET_DLL int MXNDArrayGetDType(NDArrayHandle handle,
MXNET_DLL int MXNDArrayGetContext(NDArrayHandle handle,
int *out_dev_type,
int *out_dev_id);
/*!
 * \brief detach an ndarray from the computation graph by clearing its entry_
 * \param handle NDArray handle
 * \param out handle of the detached copy
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXNDArrayDetach(NDArrayHandle handle, NDArrayHandle *out);
/*!
 * \brief set the flag for gradient array state.
 * \param handle NDArray handle
 * \param state the new state.
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXNDArraySetGradState(NDArrayHandle handle, int state);
/*!
 * \brief get the flag for gradient array state.
 * \param handle NDArray handle
 * \param out the current state.
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXNDArrayGetGradState(NDArrayHandle handle, int *out);
//--------------------------------
// Part 2: functions on NDArray
//--------------------------------
Expand Down Expand Up @@ -548,6 +570,39 @@ MXNET_DLL int MXAutogradMarkVariables(mx_uint num_var,
*/
MXNET_DLL int MXAutogradComputeGradient(mx_uint num_output,
NDArrayHandle* output_handles);
/*!
 * \brief compute the gradient of outputs w.r.t variables
 * \param num_output number of output NDArray
 * \param output_handles output NDArrays
 * \param ograd_handles head gradient for NDArrays
 * \param retain_graph whether to keep the graph after backward
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXAutogradBackward(mx_uint num_output,
NDArrayHandle* output_handles,
NDArrayHandle* ograd_handles,
int retain_graph);
/*!
 * \brief create cached operator
 * \param creator the operator creator to cache
 * \param num_inputs number of input NDArrays the operator expects
 * \param num_params number of attribute key/value pairs
 * \param param_keys attribute keys
 * \param param_vals attribute values (stringified)
 * \param out the created cached-op handle
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXCachedCreateOp(AtomicSymbolCreator creator,
int num_inputs,
int num_params,
const char **param_keys,
const char **param_vals,
CachedOpHandle *out);
/*!
 * \brief free cached operator
 * \param handle the cached-op handle to free
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXCachedFree(CachedOpHandle handle);
/*!
 * \brief invoke cached operator
 * \param handle the cached op to invoke
 * \param num_inputs number of input NDArrays
 * \param inputs input NDArray handles
 * \param num_outputs in: number of caller-preallocated outputs (0 if none);
 *        out: number of outputs produced
 * \param outputs output NDArray handles
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXCachedInvoke(CachedOpHandle handle,
int num_inputs,
NDArrayHandle *inputs,
int *num_outputs,
NDArrayHandle **outputs);
//--------------------------------------------
// Part 3: symbolic configuration generation
//--------------------------------------------
Expand Down Expand Up @@ -615,6 +670,19 @@ MXNET_DLL int MXSymbolCreateAtomicSymbol(AtomicSymbolCreator creator,
const char **keys,
const char **vals,
SymbolHandle *out);
/*!
 * \brief Create an AtomicSymbol from cached op.
 * \param handle cached node attribute.
 * \param name name of new symbol.
 * \param num_args the number of symbol arguments
 * \param args symbol arguments
 * \param out the created symbol handle
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXCachedCreateSymbol(CachedOpHandle handle,
const char* name,
mx_uint num_args,
SymbolHandle* args,
SymbolHandle* out);
/*!
* \brief Create a Variable Symbol.
* \param name name of the variable
Expand Down
13 changes: 13 additions & 0 deletions include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class AGNodeEntry {
}

nnvm::NodeEntry nn_entry() const;
bool is_none() const;
};

class AutogradRuntime;
Expand Down Expand Up @@ -149,6 +150,10 @@ class NDArray {
/*! \return true when this NDArray holds no backing data (ptr_ is null) */
inline bool is_none() const {
return ptr_.get() == nullptr;
}
/*! \return the fresh-grad state stored in entry_ */
bool fresh_out_grad() const;
/*! \brief set the fresh-grad state in entry_ \param state the new state */
void set_fresh_out_grad(bool state) const;
/*!
* \brief Block until all the pending write operations with respect
* to current NDArray are finished, and read can be performed.
Expand Down Expand Up @@ -321,6 +326,14 @@ class NDArray {
* \return NDArray in new shape
*/
NDArray Reshape(const TShape &shape) const;
/*!
 * \brief Return a copy of this NDArray without autograd history.
 * The copy's autograd entry_ is reset so backward will not flow through it.
 * NOTE(review): the copy is made via NDArray's copy constructor — presumably
 * a shallow handle copy sharing the underlying data; confirm NDArray's copy
 * semantics before relying on this.
 */
NDArray Detach() const {
NDArray ret(*this);
// clear the autograd graph entry on the copy only
ret.entry_ = autograd::AGNodeEntry{nullptr, 0, 0};
return ret;
}
/*!
* \brief Allocate the space if it is delayed allocated.
* This is an internal function used by system that normal user should not use
Expand Down
2 changes: 1 addition & 1 deletion nnvm
30 changes: 30 additions & 0 deletions python/mxnet/_ctypes/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# coding: utf-8
"""Common code between symbolic and ndarray."""
from __future__ import absolute_import as _abs

import ctypes

from ..base import _LIB
from ..base import c_array, c_str
from ..base import OpHandle, CachedOpHandle
from ..base import check_call


class CachedOp(object):
    """Handle to a cached operator created through the MXNet C API.

    Caches the parsed operator attributes on the C side so repeated
    invocations skip re-parsing the keyword arguments.

    Parameters
    ----------
    op : str
        Name of the operator to cache.
    num_input : int
        Number of input NDArrays the operator expects.
    **kwargs
        Operator attributes; values are stringified before the C call.
    """
    __slots__ = ["handle", "op"]

    def __init__(self, op, num_input, **kwargs):
        self.op = op
        op_handle = OpHandle()
        check_call(_LIB.NNGetOpHandle(c_str(op), ctypes.byref(op_handle)))
        self.handle = CachedOpHandle()
        check_call(_LIB.MXCachedCreateOp(
            op_handle,
            ctypes.c_int(num_input),
            ctypes.c_int(len(kwargs)),
            c_array(ctypes.c_char_p, [c_str(key) for key in kwargs.keys()]),
            c_array(ctypes.c_char_p, [c_str(str(val)) for val in kwargs.values()]),
            ctypes.byref(self.handle)))

    def __del__(self):
        # __del__ can run even when __init__ raised before `handle` was
        # assigned (e.g. NNGetOpHandle failed); guard so garbage collection
        # does not surface a spurious AttributeError.
        if hasattr(self, 'handle'):
            check_call(_LIB.MXCachedFree(self.handle))
33 changes: 32 additions & 1 deletion python/mxnet/_ctypes/ndarray.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# coding: utf-8
# pylint: disable=invalid-name, protected-access, too-many-arguments
# pylint: disable=global-statement, unused-import
"""Symbolic configuration API."""
"""NDArray configuration API."""
from __future__ import absolute_import as _abs

import ctypes
Expand All @@ -13,6 +13,7 @@
from ..base import NDArrayHandle, OpHandle
from ..base import check_call
from ..ndarray_doc import _build_doc
from .common import CachedOp


class NDArrayBase(object):
Expand Down Expand Up @@ -78,3 +79,33 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):
else:
return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle))
for i in range(num_output.value)]


def invoke(cached_op, args, out=None, name=None):  # pylint: disable=unused-argument
    """ctypes implementation of imperative invoke wrapper"""
    if out is None:
        # Let the backend allocate the outputs.
        original_output = None
        output_vars = ctypes.POINTER(NDArrayHandle)()
        num_output = ctypes.c_int(0)
    else:
        # Caller supplied preallocated output array(s).
        original_output = out
        if isinstance(out, NDArrayBase):
            out = (out,)
        num_output = ctypes.c_int(len(out))
        raw_handles = c_array(NDArrayHandle, [arr.handle for arr in out])
        output_vars = ctypes.cast(raw_handles, ctypes.POINTER(NDArrayHandle))

    check_call(_LIB.MXCachedInvoke(
        cached_op.handle,
        ctypes.c_int(len(args)),
        c_array(NDArrayHandle, [arr.handle for arr in args]),
        ctypes.byref(num_output),
        ctypes.byref(output_vars)))

    if original_output is not None:
        # Results were written into the caller's arrays; hand those back.
        return original_output
    if num_output.value == 1:
        return _ndarray_cls(ctypes.cast(output_vars[0], NDArrayHandle))
    return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle))
            for i in range(num_output.value)]
163 changes: 38 additions & 125 deletions python/mxnet/_ctypes/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,12 @@
from __future__ import absolute_import as _abs

import ctypes
import sys
import numpy as _numpy
from ..base import _LIB
from ..base import c_array, c_str, mx_uint, py_str
from ..base import SymbolHandle, OpHandle
from ..base import c_array, c_str, mx_uint
from ..base import SymbolHandle
from ..base import check_call
from ..symbol_doc import _build_doc
from ..name import NameManager
from ..attribute import AttrScope
from .common import CachedOp # pylint: disable=unused-import

_symbol_cls = None

Expand Down Expand Up @@ -105,122 +102,38 @@ def _set_symbol_class(cls):
_symbol_cls = cls


def _make_atomic_symbol_function(handle, name):
    """Create an atomic symbol function by handle and function name.

    Queries the C API for the operator's argument names/types and doc
    strings, then returns a ``creator`` closure that builds a Symbol for
    operator ``name``.
    """
    # Out-parameters filled in by MXSymbolGetAtomicSymbolInfo below.
    real_name = ctypes.c_char_p()
    desc = ctypes.c_char_p()
    num_args = mx_uint()
    arg_names = ctypes.POINTER(ctypes.c_char_p)()
    arg_types = ctypes.POINTER(ctypes.c_char_p)()
    arg_descs = ctypes.POINTER(ctypes.c_char_p)()
    key_var_num_args = ctypes.c_char_p()
    ret_type = ctypes.c_char_p()

    check_call(_LIB.MXSymbolGetAtomicSymbolInfo(
        handle, ctypes.byref(real_name), ctypes.byref(desc),
        ctypes.byref(num_args),
        ctypes.byref(arg_names),
        ctypes.byref(arg_types),
        ctypes.byref(arg_descs),
        ctypes.byref(key_var_num_args),
        ctypes.byref(ret_type)))
    narg = int(num_args.value)
    func_name = name
    key_var_num_args = py_str(key_var_num_args.value)
    ret_type = py_str(ret_type.value) if ret_type.value is not None else ''
    doc_str = _build_doc(func_name,
                         py_str(desc.value),
                         [py_str(arg_names[i]) for i in range(narg)],
                         [py_str(arg_types[i]) for i in range(narg)],
                         [py_str(arg_descs[i]) for i in range(narg)],
                         key_var_num_args,
                         ret_type)

    def creator(*args, **kwargs):
        """Build a Symbol for this operator.

        Keyword arguments that are Symbols are composed as inputs; all
        other keyword arguments are stringified and passed as operator
        parameters.

        Parameters
        ----------
        name : string, required.
            Name of the resulting symbol.

        Returns
        -------
        symbol: Symbol
            the resulting symbol
        """
        param_keys = []
        param_vals = []
        symbol_kwargs = {}

        attr = kwargs.pop('attr', None)
        kwargs.update(AttrScope.current.get(attr))
        name = kwargs.pop('name', None)
        if 'dtype' in kwargs:
            # normalize dtype to its canonical numpy name (e.g. 'float32')
            kwargs['dtype'] = _numpy.dtype(kwargs['dtype']).name

        # variable-length-input ops: supply the count unless caller gave it
        if key_var_num_args and key_var_num_args not in kwargs:
            param_keys.append(c_str(key_var_num_args))
            param_vals.append(c_str(str(len(args))))

        # split kwargs into Symbol inputs vs. string parameters
        for k, v in kwargs.items():
            if isinstance(v, SymbolBase):
                symbol_kwargs[k] = v
            else:
                param_keys.append(c_str(k))
                param_vals.append(c_str(str(v)))
        # create atomic symbol
        param_keys = c_array(ctypes.c_char_p, param_keys)
        param_vals = c_array(ctypes.c_char_p, param_vals)
        sym_handle = SymbolHandle()
        check_call(_LIB.MXSymbolCreateAtomicSymbol(
            handle,
            mx_uint(len(param_keys)),
            param_keys, param_vals,
            ctypes.byref(sym_handle)))

        # NOTE(review): this message is missing a space between 'input'
        # and 'Symbols' (renders as "inputSymbols").
        if len(args) != 0 and len(symbol_kwargs) != 0:
            raise TypeError(
                '%s can only accept input'
                'Symbols either as positional or keyword arguments, not both' % func_name)
        s = _symbol_cls(sym_handle)

        hint = func_name.lower()
        name = NameManager.current.get(name, hint)
        s._compose(*args, name=name, **symbol_kwargs)
        return s

    # expose the generated closure under the operator's own name/doc
    creator.__name__ = func_name
    creator.__doc__ = doc_str
    creator.__module__ = 'mxnet.symbol'
    return creator


def _init_symbol_module(symbol_class, root_namespace):
    """List and add all the atomic symbol functions to current module.

    Registers every operator known to the backend as a function on
    ``<root_namespace>.symbol``, or on the contrib / internal submodules
    depending on the operator-name prefix.
    """
    _set_symbol_class(symbol_class)
    plist = ctypes.POINTER(ctypes.c_char_p)()
    size = ctypes.c_uint()

    check_call(_LIB.MXListAllOpNames(ctypes.byref(size),
                                     ctypes.byref(plist)))
    op_names = []
    for i in range(size.value):
        op_names.append(py_str(plist[i]))

    module_obj = sys.modules["%s.symbol" % root_namespace]
    module_internal = sys.modules["%s._symbol_internal" % root_namespace]
    module_contrib = sys.modules["%s.contrib.symbol" % root_namespace]
    for name in op_names:
        hdl = OpHandle()
        check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl)))
        function = _make_atomic_symbol_function(hdl, name)
        if function.__name__.startswith('_contrib_'):
            # strip the '_contrib_' prefix (9 chars) and expose under contrib
            function.__name__ = function.__name__[9:]
            function.__module__ = 'mxnet.contrib.symbol'
            setattr(module_contrib, function.__name__, function)
        elif function.__name__.startswith('_'):
            # leading underscore: internal-only operator
            setattr(module_internal, function.__name__, function)
        else:
            setattr(module_obj, function.__name__, function)
def invoke(cached_op, args, name=None):
    """Call cached symbolic operator"""
    out_handle = SymbolHandle()
    # derive a default name from the cached op if none was supplied
    sym_name = c_str(NameManager.current.get(name, cached_op.op.lower()))
    arg_handles = c_array(SymbolHandle, [sym.handle for sym in args])
    check_call(_LIB.MXCachedCreateSymbol(
        cached_op.handle,
        sym_name,
        mx_uint(len(args)),
        arg_handles,
        ctypes.byref(out_handle)))
    return _symbol_cls(out_handle)


def _symbol_creator(handle, args, kwargs, keys, vals, name):
    """Create a Symbol from an atomic-symbol creator handle and compose it.

    Parameters
    ----------
    handle : int
        Raw AtomicSymbolCreator handle (cast to a void pointer).
    args : list of Symbol or None
        Positional input symbols.
    kwargs : dict of str -> Symbol or None
        Keyword input symbols.
    keys, vals : list
        Operator parameter names and values (values are stringified).
    name : str
        Name for the resulting symbol.

    Returns
    -------
    Symbol
        The composed symbol.

    Raises
    ------
    TypeError
        If both positional and keyword input symbols are given.
    """
    # Validate before calling into the backend so we do not leak a
    # C-side symbol handle on the error path.
    if args and kwargs:
        raise TypeError(
            'Operators with variable length input can only accept input '
            'Symbols either as positional or keyword arguments, not both')

    sym_handle = SymbolHandle()
    check_call(_LIB.MXSymbolCreateAtomicSymbol(
        ctypes.c_void_p(handle),
        mx_uint(len(keys)),
        c_array(ctypes.c_char_p, [c_str(i) for i in keys]),
        c_array(ctypes.c_char_p, [c_str(str(i)) for i in vals]),
        ctypes.byref(sym_handle)))

    s = _symbol_cls(sym_handle)
    if args:
        s._compose(*args, name=name)
    elif kwargs:
        s._compose(name=name, **kwargs)
    else:
        s._compose(name=name)
    return s
1 change: 1 addition & 0 deletions python/mxnet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def _load_lib():
# Opaque ctypes aliases for handles returned by the MXNet C API.
# Each is a bare void pointer; Python never dereferences them.
NDArrayHandle = ctypes.c_void_p
FunctionHandle = ctypes.c_void_p
OpHandle = ctypes.c_void_p
CachedOpHandle = ctypes.c_void_p
SymbolHandle = ctypes.c_void_p
ExecutorHandle = ctypes.c_void_p
DataIterCreatorHandle = ctypes.c_void_p
Expand Down
Loading

0 comments on commit 32ced38

Please sign in to comment.