From f06d9943f51fdb77b8c2841aa70258eb4dbf7f30 Mon Sep 17 00:00:00 2001 From: Chiyuan Zhang Date: Tue, 6 Sep 2016 23:22:48 -0400 Subject: [PATCH] fix special handling of dot for SymbolicNode (#123) --- src/ndarray.jl | 3 +++ src/symbolic-node.jl | 12 ++++++++++-- test/unittest/symbolic-node.jl | 17 ++++++++++++++++- 3 files changed, 29 insertions(+), 3 deletions(-) diff --git a/src/ndarray.jl b/src/ndarray.jl index c3288dc323eb..2e2c806552f7 100644 --- a/src/ndarray.jl +++ b/src/ndarray.jl @@ -1036,6 +1036,9 @@ function _get_function_expressions(handle :: MX_handle, name) _use_vars.args[2:end] = flipdim(_use_vars.args[2:end], 1) end + # XXX: hacky way of solving the semantic difference of the axes parameter in Julia + # and in libmxnet. + # See https://github.com/dmlc/MXNet.jl/pull/123 if name == :transpose transform = quote kwargs = Any[key != :axes ? (key, arg) : (key, reverse(map(i->length(arg)-i, arg))) for (key, arg) in kwargs] diff --git a/src/symbolic-node.jl b/src/symbolic-node.jl index 15ae1d7d0e2d..dfc54c3c3b1c 100644 --- a/src/symbolic-node.jl +++ b/src/symbolic-node.jl @@ -629,12 +629,20 @@ function _define_atomic_symbol_creator(hdr :: MX_handle) else name = "" end + + # XXX: hacky way of solving the problem that the arguments of `dot` should be swapped + # See https://github.com/dmlc/MXNet.jl/issues/55 + if $func_name_s == "dot" + args = reverse(args) + end - if $func_name == :transpose + # XXX: hacky way of solving the semantic difference of the axes parameter in Julia + # and in libmxnet. + # See https://github.com/dmlc/MXNet.jl/pull/123 + if $func_name_s == "transpose" kwargs = Any[key != :axes ? (key, arg) : (key, reverse(map(i->length(arg)-i, arg))) for (key, arg) in kwargs] end - param_keys = String[] param_vals = String[] symbol_kws = Dict{Symbol, SymbolicNode}() diff --git a/test/unittest/symbolic-node.jl b/test/unittest/symbolic-node.jl index 388a74fe644b..d78b0775a983 100644 --- a/test/unittest/symbolic-node.jl +++ b/test/unittest/symbolic-node.jl @@ -2,7 +2,7 @@ module TestSymbolicNode using MXNet using Base.Test -using ..Main: mlp2 +using ..Main: mlp2, reldiff ################################################################################ # Test Implementations @@ -112,6 +112,20 @@ function test_functions() typeof(mx.sum(data)) == mx.SymbolicNode end +function test_dot() + info("SymbolicNode::dot") + x = mx.Variable(:x) + y = mx.Variable(:y) + z = mx.dot(x, y) + z_exec = mx.bind(z, context=mx.cpu(), + args=Dict(:x=>mx.ones((100, 2)), :y=>mx.ones((2, 200)))) + mx.forward(z_exec) + + ret = copy(z_exec.outputs[1]) + @test size(ret) == (100, 200) + @test reldiff(ret, 2*ones(100, 200)) < 1e-6 +end + ################################################################################ # Run tests ################################################################################ @@ -123,5 +137,6 @@ test_infer_shape_error() test_saveload() test_attrs() test_functions() +test_dot() end