Skip to content

Commit

Permalink
disambiguate nd op and sym op
Browse files Browse the repository at this point in the history
  • Loading branch information
pluskid committed Sep 22, 2016
1 parent 03509ae commit 623fbaf
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 24 deletions.
16 changes: 11 additions & 5 deletions src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -963,8 +963,7 @@ function _get_ndarray_function_def(name :: String)
func_name = Symbol(name)

func_def = quote
function $func_name(args::NDArray...; out=nothing, kwargs...)
println($name)
function $func_name(::Type{NDArray}, args::NDArray...; out=nothing, kwargs...)
if out != nothing
output_vars = out
if isa(output_vars, NDArray)
Expand Down Expand Up @@ -1027,7 +1026,13 @@ function _get_ndarray_function_def(name :: String)
end
end

return func_def
func_def2 = quote
function $func_name(args::NDArray...; out=nothing, kwargs...)
$func_name(NDArray, args...; out=out, kwargs...)
end
end

return func_def, func_def2
end

macro _import_ndarray_functions()
Expand All @@ -1036,13 +1041,14 @@ macro _import_ndarray_functions()
op_handle = _get_libmx_op_handle(name)

desc, key_narg = _get_libmx_op_description(name, op_handle)
func_def = _get_ndarray_function_def(name)
func_def, func_def2 = _get_ndarray_function_def(name)

func_name = Symbol(name)
expr = quote
$(isdefined(Base, func_name) ? :(import Base.$func_name) : :())
@doc $desc ->
$func_def
@doc $desc ->
$func_def2
end
end

Expand Down
20 changes: 15 additions & 5 deletions src/symbolic-node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -596,8 +596,7 @@ function _define_atomic_symbol_creator(name :: String)

func_name = Symbol(name)
func_def = quote
@doc $f_desc ->
function $func_name(args::SymbolicNode...; kwargs...)
function $func_name(::Type{SymbolicNode}, args::SymbolicNode...; kwargs...)
idx = findfirst(x -> x[1] == :name, kwargs)
if idx > 0
name = kwargs[idx][2]
Expand Down Expand Up @@ -687,7 +686,18 @@ function _define_atomic_symbol_creator(name :: String)
return node
end # function
end # quote
return func_def

func_def2 = quote
@doc $f_desc ->
function $func_name(args::SymbolicNode...; kwargs...)
$func_name(SymbolicNode, args...; kwargs...)
end # function
end # quote

return quote
$func_def
$func_def2
end
end

macro _import_atomic_symbol_creators()
Expand All @@ -696,8 +706,8 @@ macro _import_atomic_symbol_creators()
# enough to disambiguate the method for NDArray and SymbolicNode
const ignored_ops = ["_set_value"]

names = _get_libmx_op_names()
func_exprs = map(names) do name
op_names = _get_libmx_op_names()
func_exprs = map(op_names) do name
if name ignored_ops
expr = _define_atomic_symbol_creator(name)
end
Expand Down
6 changes: 3 additions & 3 deletions test/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ end

function mlp2()
data = mx.Variable(:data)
out = mx.FullyConnected(data=data, name=:fc1, num_hidden=1000)
out = mx.Activation(data=out, act_type=:relu)
out = mx.FullyConnected(data=out, name=:fc2, num_hidden=10)
out = mx.FullyConnected(data, name=:fc1, num_hidden=1000)
out = mx.Activation(out, act_type=:relu)
out = mx.FullyConnected(out, name=:fc2, num_hidden=10)
return out
end

21 changes: 10 additions & 11 deletions test/unittest/symbolic-node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ function test_internal()
info("SymbolicNode::internal")

data = mx.Variable(:data)
oldfc = mx.FullyConnected(data=data, name=:fc1, num_hidden=10)
net1 = mx.FullyConnected(data=oldfc, name=:fc2, num_hidden=100)
oldfc = mx.FullyConnected(data, name=:fc1, num_hidden=10)
net1 = mx.FullyConnected(oldfc, name=:fc2, num_hidden=100)

@test mx.list_arguments(net1) == [:data,:fc1_weight,:fc1_bias,:fc2_weight,:fc2_bias]

Expand All @@ -34,12 +34,12 @@ function test_compose()
info("SymbolicNode::compose")

data = mx.Variable(:data)
net1 = mx.FullyConnected(data=data, name=:fc1, num_hidden=10)
net1 = mx.FullyConnected(data=net1, name=:fc2, num_hidden=100)
net1 = mx.FullyConnected(data, name=:fc1, num_hidden=10)
net1 = mx.FullyConnected(net1, name=:fc2, num_hidden=100)

net2 = mx.FullyConnected(name=:fc3, num_hidden=10)
net2 = mx.Activation(data=net2, act_type=:relu)
net2 = mx.FullyConnected(data=net2, name=:fc4, num_hidden=20)
net2 = mx.FullyConnected(mx.SymbolicNode, name=:fc3, num_hidden=10)
net2 = mx.Activation(net2, act_type=:relu)
net2 = mx.FullyConnected(net2, name=:fc4, num_hidden=20)

composed = net2(fc3_data=net1, name=:composed)
multi_out = mx.Group(composed, net1)
Expand Down Expand Up @@ -96,14 +96,13 @@ function test_attrs()
data2 = mx.Variable(:data2, attrs = Dict(:test => "hallo!"))
@test get(mx.get_attr(data2, :test)) == "hallo!"

conv = mx.Convolution(data = data2, kernel = (1,1), num_filter = 1, attrs = Dict(:a => "a", => "π"))
conv = mx.Convolution(data2, kernel = (1,1), num_filter = 1, attrs = Dict(:a => "a", => "π"))
@test isnull(mx.get_attr(conv, :b))
@test get(mx.get_attr(conv, :a)) == "a"
@test get(mx.get_attr(conv, )) == "π"
@test mx.list_attr(conv) == Dict(:a => "a", => "π")

@test_throws MethodError mx.Variable(:data3, attrs = Dict(:test => "1.0", :test2 => 1.0))
@test_throws MethodError mx.Convolution(data=data2, kernel = (1,1), num_filter = 1, attrs = Dict(:test => "1.0", :test2 => 1.0))
@test_throws MethodError mx.Convolution(data2, kernel = (1,1), num_filter = 1, attrs = Dict(:test => "1.0", :test2 => 1.0))
end

function test_functions()
Expand All @@ -117,7 +116,7 @@ function test_dot()
x = mx.Variable(:x)
y = mx.Variable(:y)
z = mx.dot(x, y)
z_exec = mx.bind(z, context=mx.cpu(),
z_exec = mx.bind(z, context=mx.cpu(),
args=Dict(:x=>mx.ones((100, 2)), :y=>mx.ones((2, 200))))
mx.forward(z_exec)

Expand Down

0 comments on commit 623fbaf

Please sign in to comment.