From e6b050805ea81535232c52c4847b8478e54f318a Mon Sep 17 00:00:00 2001 From: Chiyuan Zhang Date: Wed, 21 Sep 2016 17:58:08 -0400 Subject: [PATCH] ndarray ops --- src/base.jl | 2 +- src/ndarray.jl | 197 ++++++++++++++----------------------------- src/symbolic-node.jl | 4 +- src/util.jl | 70 +++++++++++++++ 4 files changed, 135 insertions(+), 138 deletions(-) diff --git a/src/base.jl b/src/base.jl index cca45c273b96..a864125757cd 100644 --- a/src/base.jl +++ b/src/base.jl @@ -102,7 +102,7 @@ macro mx_define_handle_t(name, destructor) end @mx_define_handle_t(MX_NDArrayHandle, MXNDArrayFree) -@mx_define_handle_t(MX_FunctionHandle, nop) +@mx_define_handle_t(MX_OpHandle, nop) @mx_define_handle_t(MX_SymbolHandle, MXSymbolFree) @mx_define_handle_t(MX_ExecutorHandle, MXExecutorFree) @mx_define_handle_t(MX_DataIterHandle, MXDataIterFree) diff --git a/src/ndarray.jl b/src/ndarray.jl index 2e2c806552f7..c1cd7ccf16b2 100644 --- a/src/ndarray.jl +++ b/src/ndarray.jl @@ -333,7 +333,7 @@ scenarios are supported """ function setindex!(arr :: NDArray, val :: Real, ::Colon) @assert(arr.writable) - _set_value(convert(eltype(arr), val), arr) + _set_value(out=arr, src=convert(eltype(arr), val)) return arr end function setindex!{T<:Real}(arr :: NDArray, val :: Array{T}, ::Colon) @@ -948,156 +948,83 @@ Those functions always return the output arguments. If there is only one output object (`NDArray`) is returned. Otherwise, a tuple containing all the outputs will be returned. """ -function _get_ndarray_functions() - n = Ref{MX_uint}(0) - handles = Ref{Ptr{MX_handle}}(0) - - @mxcall(:MXListFunctions, (Ref{MX_uint}, Ref{Ptr{MX_handle}}), n, handles) - - funcs = unsafe_wrap(Array, handles[], n[]) - return funcs -end - -const _function_cache = Dict{Symbol, MX_handle}() -function _get_function(name :: Symbol) - if !haskey(_function_cache, name) - handle = Ref{MX_handle}(0) - - @mxcall(:MXGetFunction, (Cstring, Ref{MX_handle}), name, handle) - _function_cache[name] = handle[] - return handle[] - else - return _function_cache[name] - end -end - -function _get_function_description(handle :: MX_handle) - # get function information (human readable) - ref_name = Ref{char_p}(0) - ref_desc = Ref{char_p}(0) - ref_narg = Ref{MX_uint}(0) - - ref_arg_names = Ref{char_pp}(0) - ref_arg_types = Ref{char_pp}(0) - ref_arg_descs = Ref{char_pp}(0) - - ref_ret_type = Ref{char_p}(0) - - @mxcall(:MXFuncGetInfo, - (MX_handle, Ref{char_p}, Ref{char_p}, Ref{MX_uint}, Ref{char_pp}, - Ref{char_pp}, Ref{char_pp}, Ref{char_p}), - handle, ref_name, ref_desc, ref_narg, ref_arg_names, - ref_arg_types, ref_arg_descs, ref_ret_type) - - name = Symbol(unsafe_wrap(String, ref_name[])) - signature = _format_signature(Int(ref_narg[]), ref_arg_names) - desc = " " * string(name) * "(" * signature * ")\n\n" - desc *= unsafe_wrap(String, ref_desc[]) * "\n\n" - desc *= "# Arguments\n" - desc *= _format_docstring(Int(ref_narg[]), ref_arg_names, ref_arg_types, ref_arg_descs) - return name, desc -end - -function _get_function_expressions(handle :: MX_handle, name) - # get function specification - ref_n_use_vars = Ref{MX_uint}(0) - ref_n_scalars = Ref{MX_uint}(0) - ref_n_mut_vars = Ref{MX_uint}(0) - ref_type_mask = Ref{Cint}(0) - @mxcall(:MXFuncDescribe, - (MX_handle, Ref{MX_uint}, Ref{MX_uint}, Ref{MX_uint}, Ref{Cint}), - handle, ref_n_use_vars, ref_n_scalars, ref_n_mut_vars, ref_type_mask) - - n_used_vars = ref_n_use_vars[] - n_scalars = ref_n_scalars[] - n_mutate_vars = ref_n_mut_vars[] - type_mask = ref_type_mask[] - accept_empty_mutate = (type_mask & convert(Cint,ACCEPT_EMPTY_MUTATE_TARGET)) != 0 - arg_before_scalar = (type_mask & convert(Cint,NDARRAY_ARG_BEFORE_SCALAR)) != 0 - - # general ndarray function - if arg_before_scalar - args = vcat([Expr(:(::), Symbol("in$i"), NDArray) for i=1:n_used_vars], - [Expr(:(::), Symbol("sca$i"), Real) for i=1:n_scalars], - [Expr(:(::), Symbol("out$i"), NDArray) for i=1:n_mutate_vars]) - else - args = vcat([Expr(:(::), Symbol("sca$i"), Real) for i=1:n_scalars], - [Expr(:(::), Symbol("in$i"), NDArray) for i=1:n_used_vars], - [Expr(:(::), Symbol("out$i"), NDArray) for i=1:n_mutate_vars]) - end - - _use_vars = Expr(:ref, :MX_handle, [Symbol("in$i") for i=1:n_used_vars]...) - _scalars = Expr(:ref, :MX_float, [Symbol("sca$i") for i=1:n_scalars]...) - _mut_vars = Expr(:ref, :MX_handle, [Symbol("out$i") for i=1:n_mutate_vars]...) - - # XXX: hacky way of solving the problem that the arguments of `dot` should be swapped - # See https://github.com/dmlc/MXNet.jl/issues/55 - if name == :dot - _use_vars.args[2:end] = flipdim(_use_vars.args[2:end], 1) - end - - # XXX: hacky way of solving the semantic difference of the axes parameter in Julia - # and in libmxnet. - # See https://github.com/dmlc/MXNet.jl/pull/123 - if name == :transpose - transform = quote - kwargs = Any[key != :axes ? (key, arg) : (key, reverse(map(i->length(arg)-i, arg))) for (key, arg) in kwargs] - end - else - transform = :() - end - - stmt_call = quote - local handle = _get_function($(QuoteNode(name))) - _invoke_mxfunction(handle, $_use_vars, $_scalars, $_mut_vars; kwargs...) - end - if n_mutate_vars == 1 - stmt_ret = :(return out1) - else - stmt_ret = Expr(:return, Expr(:tuple, [Symbol("out$i") for i=1:n_mutate_vars]...)) - end +function _get_ndarray_function_def(name :: String) + func_name = Symbol(name) func_def = quote - function $name($(args...); kwargs...) - $transform - $stmt_call - $stmt_ret - end - end + function $func_name(args::NDArray...; out=nothing, kwargs...) + if out != nothing + output_vars = out + if isa(output_vars, NDArray) + output_vars = NDArray[output_vars] + end + num_outputs = length(output_vars) + else + output_vars = NDArray[] + num_outputs = 0 + end + + # XXX: hacky way of solving the problem that the arguments of `dot` should be swapped + # See https://github.com/dmlc/MXNet.jl/issues/55 + if $name == "dot" + args = flipdim(args, 1) + end - if accept_empty_mutate - args0 = args[1:n_used_vars+n_scalars] - _mut_vars0 = [:(NDArray(_ndarray_alloc())) for i=1:n_mutate_vars] + # XXX: hacky way of solving the semantic difference of the axes parameter in Julia + # and in libmxnet. + # See https://github.com/dmlc/MXNet.jl/pull/123 + if $name == "transpose" + kwargs = Any[key != :axes ? (key, arg) : (key, reverse(map(i->length(arg)-i, arg))) for (key, arg) in kwargs] + end - func_def0 = quote - function $name($(args0...); kwargs...) - $name($(args0...), $(_mut_vars0...); kwargs...) + output_handles = [Base.cconvert(MX_handle, x) for x in output_vars] + output_handles_pp = [Base.cconvert(Ptr{MX_handle}, output_handles)] + num_outputs_p = [convert(Cint, num_outputs)] + + kw_keys_str = String[string(x[1]) for x in kwargs] + kw_vals_str = String[string(x[2]) for x in kwargs] + + args = collect(args) # tuple to list + op_handle = _get_cached_libmx_op_handle($(QuoteNode(name))) + @mxcall(:MXImperativeInvoke, + (MX_handle, Cint, Ptr{MX_handle}, + Ptr{Cint}, Ptr{Ptr{MX_handle}}, + Cint, char_pp, char_pp), + op_handle, length(args), args, + num_outputs_p, output_handles_pp, + length(kwargs), kw_keys_str, kw_vals_str) + + if out == nothing + handle_array = unsafe_wrap(Array, output_handles_pp[], num_outputs_p[]) + arrays = [NDArray(hdr) for hdr in handle_array] + if mx_num_outputs == 1 + return arrays[1] + else + return arrays + end + else + return out end end - return func_def, func_def0 - else - return func_def, :() end + + return func_def end macro _import_ndarray_functions() - funcs = _get_ndarray_functions() - func_exprs = Expr[] + names = _get_libmx_op_names() + func_exprs = map(names) do name + op_handle = _get_libmx_op_handle(name) - for i = 1:length(funcs) - handle = funcs[i] - - name, desc = _get_function_description(handle) - func_def, func_def0 = _get_function_expressions(handle, name) + desc, key_narg = _get_libmx_op_description(name, op_handle) + func_def = _get_ndarray_function_def(name) + func_name = Symbol(name) expr = quote - $(isdefined(Base, name) ? :(import Base.$name) : :()) + $(isdefined(Base, func_name) ? :(import Base.$func_name) : :()) @doc $desc -> $func_def - $func_def0 end - - push!(func_exprs, expr) end esc(quote diff --git a/src/symbolic-node.jl b/src/symbolic-node.jl index dfc54c3c3b1c..c1e6f7d8e8c6 100644 --- a/src/symbolic-node.jl +++ b/src/symbolic-node.jl @@ -629,7 +629,7 @@ function _define_atomic_symbol_creator(hdr :: MX_handle) else name = "" end - + # XXX: hacky way of solving the problem that the arguments of `dot` should be swapped # See https://github.com/dmlc/MXNet.jl/issues/55 if $func_name_s == "dot" @@ -755,7 +755,7 @@ macro _import_atomic_symbol_creators() end) end -@_import_atomic_symbol_creators() +#@_import_atomic_symbol_creators() ################################################################################ # Utility macros to chain up symbols diff --git a/src/util.jl b/src/util.jl index a53647790c06..11d12e7f9dd2 100644 --- a/src/util.jl +++ b/src/util.jl @@ -62,6 +62,76 @@ end ################################################################################ # Internal Utilities ################################################################################ +function _get_libmx_op_names() + n = Ref{MX_uint}(0) + names = Ref{char_pp}(0) + + @mxcall(:MXListAllOpNames, (Ref{MX_uint}, Ref{char_pp}), n, names) + + names = unsafe_wrap(Array, names[], n[]) + return [unsafe_string(x) for x in names] +end +function _get_libmx_op_handle(name :: String) + handle = Ref{MX_handle}(0) + @mxcall(:NNGetOpHandle, (char_p, Ref{MX_handle}), name, handle) + return MX_OpHandle(handle[]) +end + +# We keep a cache and retrieve the address everytime +# we run Julia, instead of pre-compiling with macro, +# because the actual handle might change in different +# runs +const _libmx_op_cache = Dict{String, MX_OpHandle}() +function _get_cached_libmx_op_handle(name :: String) + if !haskey(_libmx_op_cache, name) + handle = _get_libmx_op_handle(name) + _libmx_op_cache[name] = handle + return handle + else + return _libmx_op_cache[name] + end +end + +function _get_libmx_op_description(name :: String, handle :: MX_OpHandle) + # get operator information (human readable) + ref_real_name = Ref{char_p}(0) + ref_desc = Ref{char_p}(0) + ref_narg = Ref{MX_uint}(0) + + ref_arg_names = Ref{char_pp}(0) + ref_arg_types = Ref{char_pp}(0) + ref_arg_descs = Ref{char_pp}(0) + + ref_key_narg = Ref{char_p}(0) + ref_ret_type = Ref{char_p}(0) + + @mxcall(:MXSymbolGetAtomicSymbolInfo, + (MX_handle, Ref{char_p}, Ref{char_p}, Ref{MX_uint}, Ref{char_pp}, + Ref{char_pp}, Ref{char_pp}, Ref{char_p}, Ref{char_p}), + handle, ref_real_name, ref_desc, ref_narg, ref_arg_names, + ref_arg_types, ref_arg_descs, ref_key_narg, ref_ret_type) + + real_name = unsafe_string(ref_real_name[]) + signature = _format_signature(Int(ref_narg[]), ref_arg_names) + desc = " " * name * "(" * signature * ")\n\n" + if real_name != name + desc *= name * " is an alias of " * real_name * ".\n\n" + end + + key_narg = unsafe_string(ref_key_narg[]) + if key_narg != "" + desc *= "**Note**: " * name * " takes variable number of positional inputs. " + desc *= "So instead of calling as $name([x, y, z], $key_narg=3), " + desc *= "one should call via $name(x, y, z), and $key_narg will be " + desc *= "determined automatically.\n\n" + end + + desc *= unsafe_string(ref_desc[]) * "\n\n" + desc *= "# Arguments\n" + desc *= _format_docstring(Int(ref_narg[]), ref_arg_names, ref_arg_types, ref_arg_descs) + return desc, key_narg +end + function _format_typestring(typestr :: String) replace(typestr, r"\bSymbol\b", "SymbolicNode") end