Skip to content

Commit

Permalink
ndarray ops
Browse files Browse the repository at this point in the history
  • Loading branch information
pluskid committed Sep 21, 2016
1 parent 7a2432c commit e6b0508
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 138 deletions.
2 changes: 1 addition & 1 deletion src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ macro mx_define_handle_t(name, destructor)
end

@mx_define_handle_t(MX_NDArrayHandle, MXNDArrayFree)
@mx_define_handle_t(MX_FunctionHandle, nop)
@mx_define_handle_t(MX_OpHandle, nop)
@mx_define_handle_t(MX_SymbolHandle, MXSymbolFree)
@mx_define_handle_t(MX_ExecutorHandle, MXExecutorFree)
@mx_define_handle_t(MX_DataIterHandle, MXDataIterFree)
Expand Down
197 changes: 62 additions & 135 deletions src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ scenarios are supported
"""
function setindex!(arr :: NDArray, val :: Real, ::Colon)
@assert(arr.writable)
_set_value(convert(eltype(arr), val), arr)
_set_value(out=arr, src=convert(eltype(arr), val))
return arr
end
function setindex!{T<:Real}(arr :: NDArray, val :: Array{T}, ::Colon)
Expand Down Expand Up @@ -948,156 +948,83 @@ Those functions always return the output arguments. If there is only one output
object (`NDArray`) is returned. Otherwise, a tuple containing all the outputs will be returned.
"""

function _get_ndarray_functions()
n = Ref{MX_uint}(0)
handles = Ref{Ptr{MX_handle}}(0)

@mxcall(:MXListFunctions, (Ref{MX_uint}, Ref{Ptr{MX_handle}}), n, handles)

funcs = unsafe_wrap(Array, handles[], n[])
return funcs
end

const _function_cache = Dict{Symbol, MX_handle}()
function _get_function(name :: Symbol)
if !haskey(_function_cache, name)
handle = Ref{MX_handle}(0)

@mxcall(:MXGetFunction, (Cstring, Ref{MX_handle}), name, handle)
_function_cache[name] = handle[]
return handle[]
else
return _function_cache[name]
end
end

function _get_function_description(handle :: MX_handle)
# get function information (human readable)
ref_name = Ref{char_p}(0)
ref_desc = Ref{char_p}(0)
ref_narg = Ref{MX_uint}(0)

ref_arg_names = Ref{char_pp}(0)
ref_arg_types = Ref{char_pp}(0)
ref_arg_descs = Ref{char_pp}(0)

ref_ret_type = Ref{char_p}(0)

@mxcall(:MXFuncGetInfo,
(MX_handle, Ref{char_p}, Ref{char_p}, Ref{MX_uint}, Ref{char_pp},
Ref{char_pp}, Ref{char_pp}, Ref{char_p}),
handle, ref_name, ref_desc, ref_narg, ref_arg_names,
ref_arg_types, ref_arg_descs, ref_ret_type)

name = Symbol(unsafe_wrap(String, ref_name[]))
signature = _format_signature(Int(ref_narg[]), ref_arg_names)
desc = " " * string(name) * "(" * signature * ")\n\n"
desc *= unsafe_wrap(String, ref_desc[]) * "\n\n"
desc *= "# Arguments\n"
desc *= _format_docstring(Int(ref_narg[]), ref_arg_names, ref_arg_types, ref_arg_descs)
return name, desc
end

function _get_function_expressions(handle :: MX_handle, name)
# get function specification
ref_n_use_vars = Ref{MX_uint}(0)
ref_n_scalars = Ref{MX_uint}(0)
ref_n_mut_vars = Ref{MX_uint}(0)
ref_type_mask = Ref{Cint}(0)
@mxcall(:MXFuncDescribe,
(MX_handle, Ref{MX_uint}, Ref{MX_uint}, Ref{MX_uint}, Ref{Cint}),
handle, ref_n_use_vars, ref_n_scalars, ref_n_mut_vars, ref_type_mask)

n_used_vars = ref_n_use_vars[]
n_scalars = ref_n_scalars[]
n_mutate_vars = ref_n_mut_vars[]
type_mask = ref_type_mask[]
accept_empty_mutate = (type_mask & convert(Cint,ACCEPT_EMPTY_MUTATE_TARGET)) != 0
arg_before_scalar = (type_mask & convert(Cint,NDARRAY_ARG_BEFORE_SCALAR)) != 0

# general ndarray function
if arg_before_scalar
args = vcat([Expr(:(::), Symbol("in$i"), NDArray) for i=1:n_used_vars],
[Expr(:(::), Symbol("sca$i"), Real) for i=1:n_scalars],
[Expr(:(::), Symbol("out$i"), NDArray) for i=1:n_mutate_vars])
else
args = vcat([Expr(:(::), Symbol("sca$i"), Real) for i=1:n_scalars],
[Expr(:(::), Symbol("in$i"), NDArray) for i=1:n_used_vars],
[Expr(:(::), Symbol("out$i"), NDArray) for i=1:n_mutate_vars])
end

_use_vars = Expr(:ref, :MX_handle, [Symbol("in$i") for i=1:n_used_vars]...)
_scalars = Expr(:ref, :MX_float, [Symbol("sca$i") for i=1:n_scalars]...)
_mut_vars = Expr(:ref, :MX_handle, [Symbol("out$i") for i=1:n_mutate_vars]...)

# XXX: hacky way of solving the problem that the arguments of `dot` should be swapped
# See https://github.com/dmlc/MXNet.jl/issues/55
if name == :dot
_use_vars.args[2:end] = flipdim(_use_vars.args[2:end], 1)
end

# XXX: hacky way of solving the semantic difference of the axes parameter in Julia
# and in libmxnet.
# See https://github.com/dmlc/MXNet.jl/pull/123
if name == :transpose
transform = quote
kwargs = Any[key != :axes ? (key, arg) : (key, reverse(map(i->length(arg)-i, arg))) for (key, arg) in kwargs]
end
else
transform = :()
end

stmt_call = quote
local handle = _get_function($(QuoteNode(name)))
_invoke_mxfunction(handle, $_use_vars, $_scalars, $_mut_vars; kwargs...)
end
if n_mutate_vars == 1
stmt_ret = :(return out1)
else
stmt_ret = Expr(:return, Expr(:tuple, [Symbol("out$i") for i=1:n_mutate_vars]...))
end
function _get_ndarray_function_def(name :: String)
func_name = Symbol(name)

func_def = quote
function $name($(args...); kwargs...)
$transform
$stmt_call
$stmt_ret
end
end
function $func_name(args::NDArray...; out=nothing, kwargs...)
if out != nothing
output_vars = out
if isa(output_vars, NDArray)
output_vars = NDArray[output_vars]
end
num_outputs = length(output_vars)
else
output_vars = NDArray[]
num_outputs = 0
end

# XXX: hacky way of solving the problem that the arguments of `dot` should be swapped
# See https://github.com/dmlc/MXNet.jl/issues/55
if $name == "dot"
args = flipdim(args, 1)
end

if accept_empty_mutate
args0 = args[1:n_used_vars+n_scalars]
_mut_vars0 = [:(NDArray(_ndarray_alloc())) for i=1:n_mutate_vars]
# XXX: hacky way of solving the semantic difference of the axes parameter in Julia
# and in libmxnet.
# See https://github.com/dmlc/MXNet.jl/pull/123
if $name == "transpose"
kwargs = Any[key != :axes ? (key, arg) : (key, reverse(map(i->length(arg)-i, arg))) for (key, arg) in kwargs]
end

func_def0 = quote
function $name($(args0...); kwargs...)
$name($(args0...), $(_mut_vars0...); kwargs...)
output_handles = [Base.cconvert(MX_handle, x) for x in output_vars]
output_handles_pp = [Base.cconvert(Ptr{MX_handle}, output_handles)]
num_outputs_p = [convert(Cint, num_outputs)]

kw_keys_str = String[string(x[1]) for x in kwargs]
kw_vals_str = String[string(x[2]) for x in kwargs]

args = collect(args) # tuple to list
op_handle = _get_cached_libmx_op_handle($(QuoteNode(name)))
@mxcall(:MXImperativeInvoke,
(MX_handle, Cint, Ptr{MX_handle},
Ptr{Cint}, Ptr{Ptr{MX_handle}},
Cint, char_pp, char_pp),
op_handle, length(args), args,
num_outputs_p, output_handles_pp,
length(kwargs), kw_keys_str, kw_vals_str)

if out == nothing
handle_array = unsafe_wrap(Array, output_handles_pp[], num_outputs_p[])
arrays = [NDArray(hdr) for hdr in handle_array]
if mx_num_outputs == 1
return arrays[1]
else
return arrays
end
else
return out
end
end
return func_def, func_def0
else
return func_def, :()
end

return func_def
end

macro _import_ndarray_functions()
funcs = _get_ndarray_functions()
func_exprs = Expr[]
names = _get_libmx_op_names()
func_exprs = map(names) do name
op_handle = _get_libmx_op_handle(name)

for i = 1:length(funcs)
handle = funcs[i]

name, desc = _get_function_description(handle)
func_def, func_def0 = _get_function_expressions(handle, name)
desc, key_narg = _get_libmx_op_description(name, op_handle)
func_def = _get_ndarray_function_def(name)

func_name = Symbol(name)
expr = quote
$(isdefined(Base, name) ? :(import Base.$name) : :())
$(isdefined(Base, func_name) ? :(import Base.$func_name) : :())
@doc $desc ->
$func_def
$func_def0
end

push!(func_exprs, expr)
end

esc(quote
Expand Down
4 changes: 2 additions & 2 deletions src/symbolic-node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ function _define_atomic_symbol_creator(hdr :: MX_handle)
else
name = ""
end

# XXX: hacky way of solving the problem that the arguments of `dot` should be swapped
# See https://github.com/dmlc/MXNet.jl/issues/55
if $func_name_s == "dot"
Expand Down Expand Up @@ -755,7 +755,7 @@ macro _import_atomic_symbol_creators()
end)
end

@_import_atomic_symbol_creators()
#@_import_atomic_symbol_creators()

################################################################################
# Utility macros to chain up symbols
Expand Down
70 changes: 70 additions & 0 deletions src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,76 @@ end
################################################################################
# Internal Utilities
################################################################################
function _get_libmx_op_names()
n = Ref{MX_uint}(0)
names = Ref{char_pp}(0)

@mxcall(:MXListAllOpNames, (Ref{MX_uint}, Ref{char_pp}), n, names)

names = unsafe_wrap(Array, names[], n[])
return [unsafe_string(x) for x in names]
end
function _get_libmx_op_handle(name :: String)
handle = Ref{MX_handle}(0)
@mxcall(:NNGetOpHandle, (char_p, Ref{MX_handle}), name, handle)
return MX_OpHandle(handle[])
end

# We keep a cache and retrieve the address everytime
# we run Julia, instead of pre-compiling with macro,
# because the actual handle might change in different
# runs
const _libmx_op_cache = Dict{String, MX_OpHandle}()
function _get_cached_libmx_op_handle(name :: String)
if !haskey(_libmx_op_cache, name)
handle = _get_libmx_op_handle(name)
_libmx_op_cache[name] = handle
return handle
else
return _libmx_op_cache[name]
end
end

function _get_libmx_op_description(name :: String, handle :: MX_OpHandle)
# get operator information (human readable)
ref_real_name = Ref{char_p}(0)
ref_desc = Ref{char_p}(0)
ref_narg = Ref{MX_uint}(0)

ref_arg_names = Ref{char_pp}(0)
ref_arg_types = Ref{char_pp}(0)
ref_arg_descs = Ref{char_pp}(0)

ref_key_narg = Ref{char_p}(0)
ref_ret_type = Ref{char_p}(0)

@mxcall(:MXSymbolGetAtomicSymbolInfo,
(MX_handle, Ref{char_p}, Ref{char_p}, Ref{MX_uint}, Ref{char_pp},
Ref{char_pp}, Ref{char_pp}, Ref{char_p}, Ref{char_p}),
handle, ref_real_name, ref_desc, ref_narg, ref_arg_names,
ref_arg_types, ref_arg_descs, ref_key_narg, ref_ret_type)

real_name = unsafe_string(ref_real_name[])
signature = _format_signature(Int(ref_narg[]), ref_arg_names)
desc = " " * name * "(" * signature * ")\n\n"
if real_name != name
desc *= name * " is an alias of " * real_name * ".\n\n"
end

key_narg = unsafe_string(ref_key_narg[])
if key_narg != ""
desc *= "**Note**: " * name * " takes variable number of positional inputs. "
desc *= "So instead of calling as $name([x, y, z], $key_narg=3), "
desc *= "one should call via $name(x, y, z), and $key_narg will be "
desc *= "determined automatically.\n\n"
end

desc *= unsafe_string(ref_desc[]) * "\n\n"
desc *= "# Arguments\n"
desc *= _format_docstring(Int(ref_narg[]), ref_arg_names, ref_arg_types, ref_arg_descs)
return desc, key_narg
end

function _format_typestring(typestr :: String)
replace(typestr, r"\bSymbol\b", "SymbolicNode")
end
Expand Down

0 comments on commit e6b0508

Please sign in to comment.