Skip to content

Commit

Permalink
fix special handling of dot for SymbolicNode (apache#123)
Browse files Browse the repository at this point in the history
  • Loading branch information
pluskid committed Sep 7, 2016
1 parent 8949dbb commit f06d994
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 3 deletions.
3 changes: 3 additions & 0 deletions src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1036,6 +1036,9 @@ function _get_function_expressions(handle :: MX_handle, name)
_use_vars.args[2:end] = flipdim(_use_vars.args[2:end], 1)
end

# XXX: hacky way of solving the semantic difference of the axes parameter in Julia
# and in libmxnet.
# See https://github.com/dmlc/MXNet.jl/pull/123
if name == :transpose
transform = quote
kwargs = Any[key != :axes ? (key, arg) : (key, reverse(map(i->length(arg)-i, arg))) for (key, arg) in kwargs]
Expand Down
12 changes: 10 additions & 2 deletions src/symbolic-node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -629,12 +629,20 @@ function _define_atomic_symbol_creator(hdr :: MX_handle)
else
name = ""
end

# XXX: hacky way of solving the problem that the arguments of `dot` should be swapped
# See https://github.com/dmlc/MXNet.jl/issues/55
if $func_name_s == "dot"
args = reverse(args)
end

if $func_name == :transpose
# XXX: hacky way of solving the semantic difference of the axes parameter in Julia
# and in libmxnet.
# See https://github.com/dmlc/MXNet.jl/pull/123
if $func_name_s == "transpose"
kwargs = Any[key != :axes ? (key, arg) : (key, reverse(map(i->length(arg)-i, arg))) for (key, arg) in kwargs]
end


param_keys = String[]
param_vals = String[]
symbol_kws = Dict{Symbol, SymbolicNode}()
Expand Down
17 changes: 16 additions & 1 deletion test/unittest/symbolic-node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module TestSymbolicNode
using MXNet
using Base.Test

using ..Main: mlp2
using ..Main: mlp2, reldiff

################################################################################
# Test Implementations
Expand Down Expand Up @@ -112,6 +112,20 @@ function test_functions()
typeof(mx.sum(data)) == mx.SymbolicNode
end

function test_dot()
info("SymbolicNode::dot")
x = mx.Variable(:x)
y = mx.Variable(:y)
z = mx.dot(x, y)
z_exec = mx.bind(z, context=mx.cpu(),
args=Dict(:x=>mx.ones((100, 2)), :y=>mx.ones((2, 200))))
mx.forward(z_exec)

ret = copy(z_exec.outputs[1])
@test size(ret) == (100, 200)
@test reldiff(ret, 2*ones(100, 200)) < 1e-6
end

################################################################################
# Run tests
################################################################################
Expand All @@ -123,5 +137,6 @@ test_infer_shape_error()
test_saveload()
test_attrs()
test_functions()
test_dot()

end

0 comments on commit f06d994

Please sign in to comment.