Skip to content

Commit

Permalink
Merge pull request apache#123 from vchuravy/vc/ndarray_funcs
Browse files Browse the repository at this point in the history
handle kwargs for ndarray functions
  • Loading branch information
pluskid authored Sep 2, 2016
2 parents 57aa885 + f5e80af commit dc43bfe
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 18 deletions.
50 changes: 32 additions & 18 deletions src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -885,17 +885,17 @@ with corresponding support (see `load`).
* `filename::String`: path to the binary file to write to.
* `data`: data to save to file. Data can be a`NDArray`, a `Vector{NDArray}`, or a `Dict{Base.Symbol, NDArray}`.
"""
function save(filename::AbstractString, data::NDArray)
function save(filename::String, data::NDArray)
save(filename, [data])
end
function save(filename::AbstractString, data::Vector{NDArray})
function save(filename::String, data::Vector{NDArray})
@mxcall(:MXNDArraySave, (char_p, MX_uint, Ptr{MX_handle}, char_pp),
filename, length(data), MX_handle[data...], char_pp(0))
end
function save(filename::AbstractString, data::Dict{Base.Symbol,NDArray})
function save(filename::String, data::Dict{Base.Symbol,NDArray})
names = [k for k in keys(data)]
arrays = MX_handle[data[k] for k in names]
names = AbstractString[string(k) for k in names]
names = String[string(k) for k in names]

@mxcall(:MXNDArraySave, (char_p, MX_uint, Ptr{MX_handle}, char_pp),
filename, length(names), arrays, names)
Expand All @@ -904,10 +904,12 @@ end
################################################################################
# NDArray functions dynamically imported from libmxnet
################################################################################
function _invoke_mxfunction(func_handle::MX_handle, use_vars, scalars, mut_vars)
@mxcall(:MXFuncInvoke,
(MX_handle, Ptr{MX_handle}, Ptr{MX_float}, Ptr{MX_handle}),
func_handle, use_vars, scalars, mut_vars)
function _invoke_mxfunction(func_handle::MX_handle, use_vars, scalars, mut_vars; kwargs...)
names = String[string(entry[1]) for entry in kwargs]
args = String[string(entry[2]) for entry in kwargs]
@mxcall(:MXFuncInvokeEx,
(MX_handle, Ptr{MX_handle}, Ptr{MX_float}, Ptr{MX_handle}, Cint, char_pp, char_pp),
func_handle, use_vars, scalars, mut_vars, length(names), names, args)
end

@enum(LIBMX_FUNC_TYPE_MASK,
Expand Down Expand Up @@ -1033,30 +1035,42 @@ function _get_function_expressions(handle :: MX_handle, name)
if name == :dot
_use_vars.args[2:end] = flipdim(_use_vars.args[2:end], 1)
end

if name == :transpose
transform = quote
kwargs = Any[key != :axes ? (key, arg) : (key, reverse(map(i->length(arg)-i, arg))) for (key, arg) in kwargs]
end
else
transform = :()
end

stmt_call = quote
local handle = _get_function($(QuoteNode(name)))
_invoke_mxfunction(handle, $_use_vars, $_scalars, $_mut_vars)
_invoke_mxfunction(handle, $_use_vars, $_scalars, $_mut_vars; kwargs...)
end
if n_mutate_vars == 1
stmt_ret = :(return out1)
else
stmt_ret = Expr(:return, Expr(:tuple, [Symbol("out$i") for i=1:n_mutate_vars]...))
end

func_body = Expr(:block, stmt_call, stmt_ret)
func_head = Expr(:call, name, args...)

func_def = Expr(:function, func_head, func_body)
func_def = quote
function $name($(args...); kwargs...)
$transform
$stmt_call
$stmt_ret
end
end

if accept_empty_mutate
args0 = args[1:n_used_vars+n_scalars]
func_head0 = Expr(:call, name, args0...)
_mut_vars0 = [:(NDArray(_ndarray_alloc())) for i=1:n_mutate_vars]
stmt_call0 = Expr(:call, name, args0..., _mut_vars0...)
func_body0 = Expr(:block, stmt_call0)
func_head0 = Expr(:call, name, args0...)

func_def0 = Expr(:function, func_head0, func_body0)
func_def0 = quote
function $name($(args0...); kwargs...)
$name($(args0...), $(_mut_vars0...); kwargs...)
end
end
return func_def, func_def0
else
return func_def, :()
Expand Down
5 changes: 5 additions & 0 deletions src/symbolic-node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,11 @@ function _define_atomic_symbol_creator(hdr :: MX_handle)
name = ""
end

if $func_name == :transpose
kwargs = Any[key != :axes ? (key, arg) : (key, reverse(map(i->length(arg)-i, arg))) for (key, arg) in kwargs]
end


param_keys = String[]
param_vals = String[]
symbol_kws = Dict{Symbol, SymbolicNode}()
Expand Down
13 changes: 13 additions & 0 deletions test/unittest/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,18 @@ function test_eltype()
end
end

function test_kwargs()
info("NDArray::kwargs")
dims1 = (2,3,4)

A = rand(Float32, dims1)
x = mx.NDArray(A)
tx = mx.transpose(x, axes=(2,1,3))
tA = permutedims(A, [2,1,3])
@test size(tx) == size(tA)
@test all(copy(tx) .== tA)
end

################################################################################
# Run tests
################################################################################
Expand All @@ -315,5 +327,6 @@ test_sqrt()
test_eltype()
test_nd_as_jl()
test_dot()
test_kwargs()

end

0 comments on commit dc43bfe

Please sign in to comment.