change examples for new symbolic calling conventions (nnvm, apache#146)
pluskid committed Sep 26, 2016
1 parent d718cfc commit 590055b
Showing 4 changed files with 23 additions and 23 deletions.
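For context, the nnvm-based symbolic API takes the input symbol as the first positional argument instead of through the data= keyword, which is the whole substance of this commit. A minimal before/after sketch (the layer choice and num_hidden value here are illustrative, not taken from any one example below):

    using MXNet

    data = mx.Variable(:data)

    # old convention: input passed via the data= keyword
    # fc = mx.FullyConnected(data=data, num_hidden=10)

    # new convention: input passed positionally
    fc = mx.FullyConnected(data, num_hidden=10)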
examples/char-lstm/lstm.jl: 8 changes (4 additions & 4 deletions)

@@ -25,9 +25,9 @@ function lstm_cell(data::mx.SymbolicNode, prev_state::LSTMState, param::LSTMPara
     data = mx.Dropout(data, p=dropout)
   end

-  i2h = mx.FullyConnected(data=data, weight=param.i2h_W, bias=param.i2h_b,
+  i2h = mx.FullyConnected(data, weight=param.i2h_W, bias=param.i2h_b,
                           num_hidden=4num_hidden, name=symbol(name, "_i2h"))
-  h2h = mx.FullyConnected(data=prev_state.h, weight=param.h2h_W, bias=param.h2h_b,
+  h2h = mx.FullyConnected(prev_state.h, weight=param.h2h_W, bias=param.h2h_b,
                           num_hidden=4num_hidden, name=symbol(name, "_h2h"))

   gates = mx.SliceChannel(i2h + h2h, num_outputs=4, name=symbol(name, "_gates"))

@@ -71,7 +71,7 @@ function LSTM(n_layer::Int, seq_len::Int, dim_hidden::Int, dim_embed::Int, n_cla
   for t = 1:seq_len
     data = mx.Variable(symbol(name, "_data_$t"))
     label = mx.Variable(symbol(name, "_label_$t"))
-    hidden = mx.FullyConnected(data=data, weight=embed_W, num_hidden=dim_embed,
+    hidden = mx.FullyConnected(data, weight=embed_W, num_hidden=dim_embed,
                                no_bias=true, name=symbol(name, "_embed_$t"))

     # stack LSTM cells

@@ -88,7 +88,7 @@ function LSTM(n_layer::Int, seq_len::Int, dim_hidden::Int, dim_embed::Int, n_cla
     if dropout > 0
       hidden = mx.Dropout(hidden, p=dropout)
     end
-    pred = mx.FullyConnected(data=hidden, weight=pred_W, bias=pred_b, num_hidden=n_class,
+    pred = mx.FullyConnected(hidden, weight=pred_W, bias=pred_b, num_hidden=n_class,
                              name=symbol(name, "_pred_$t"))
     smax = mx.SoftmaxOutput(pred, label, name=symbol(name, "_softmax_$t"))
     push!(outputs, smax)
examples/cifar10/cifar10.jl: 16 changes (8 additions & 8 deletions)

@@ -5,9 +5,9 @@ using MXNet

 # basic Conv + BN + ReLU factory
 function conv_factory(data, num_filter, kernel; stride=(1,1), pad=(0,0), act_type=:relu)
-  conv = mx.Convolution(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad)
-  bn = mx.BatchNorm(data=conv)
-  act = mx.Activation(data=bn, act_type=act_type)
+  conv = mx.Convolution(data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad)
+  bn = mx.BatchNorm(conv)
+  act = mx.Activation(bn, act_type=act_type)
   return act
 end

@@ -16,7 +16,7 @@ function downsample_factory(data, ch_3x3)
   # conv 3x3
   conv = conv_factory(data, ch_3x3, (3,3), stride=(2,2), pad=(1,1))
   # pool
-  pool = mx.Pooling(data=data, kernel=(3,3), stride=(2,2), pool_type=:max)
+  pool = mx.Pooling(data, kernel=(3,3), stride=(2,2), pool_type=:max)
   # concat
   concat = mx.Concat(conv, pool)
   return concat

@@ -48,10 +48,10 @@ in4d = simple_factory(in4b, 48, 96)
 in4e = downsample_factory(in4d, 96)
 in5a = simple_factory(in4e, 176, 160)
 in5b = simple_factory(in5a, 176, 160)
-pool = mx.Pooling(data=in5b, pool_type=:avg, kernel=(7,7), name=:global_pool)
-flatten = mx.Flatten(data=pool, name=:flatten1)
-fc = mx.FullyConnected(data=flatten, num_hidden=10, name=:fc1)
-softmax = mx.SoftmaxOutput(data=fc, name=:loss)
+pool = mx.Pooling(in5b, pool_type=:avg, kernel=(7,7), name=:global_pool)
+flatten = mx.Flatten(pool, name=:flatten1)
+fc = mx.FullyConnected(flatten, num_hidden=10, name=:fc1)
+softmax = mx.SoftmaxOutput(fc, name=:loss)


 #--------------------------------------------------------------------------------
examples/mnist/lenet.jl: 10 changes (5 additions & 5 deletions)

@@ -7,25 +7,25 @@ using MXNet
 data = mx.Variable(:data)

 # first conv
-conv1 = @mx.chain mx.Convolution(data=data, kernel=(5,5), num_filter=20) =>
+conv1 = @mx.chain mx.Convolution(data, kernel=(5,5), num_filter=20) =>
         mx.Activation(act_type=:tanh) =>
         mx.Pooling(pool_type=:max, kernel=(2,2), stride=(2,2))

 # second conv
-conv2 = @mx.chain mx.Convolution(data=conv1, kernel=(5,5), num_filter=50) =>
+conv2 = @mx.chain mx.Convolution(conv1, kernel=(5,5), num_filter=50) =>
         mx.Activation(act_type=:tanh) =>
         mx.Pooling(pool_type=:max, kernel=(2,2), stride=(2,2))

 # first fully-connected
-fc1 = @mx.chain mx.Flatten(data=conv2) =>
+fc1 = @mx.chain mx.Flatten(conv2) =>
       mx.FullyConnected(num_hidden=500) =>
       mx.Activation(act_type=:tanh)

 # second fully-connected
-fc2 = mx.FullyConnected(data=fc1, num_hidden=10)
+fc2 = mx.FullyConnected(fc1, num_hidden=10)

 # softmax loss
-lenet = mx.SoftmaxOutput(data=fc2, name=:softmax)
+lenet = mx.SoftmaxOutput(fc2, name=:softmax)


 #--------------------------------------------------------------------------------
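A note on the @mx.chain usage above (inferred from the diff itself, not from the macro's documentation): => feeds each stage's output in as the input of the next, which is why only the first call in a chain names its input explicitly, and under the new convention that input is positional as well. Written out by hand, the fc1 chain is roughly equivalent to the following (intermediate names are hypothetical):

    flat   = mx.Flatten(conv2)        # conv2 from the second conv block above
    fc1_fc = mx.FullyConnected(flat, num_hidden=500)
    fc1    = mx.Activation(fc1_fc, act_type=:tanh)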
examples/mnist/mlp.jl: 12 changes (6 additions & 6 deletions)

@@ -6,12 +6,12 @@ using MXNet

 #-- Option 1: explicit composition
 # data = mx.Variable(:data)
-# fc1 = mx.FullyConnected(data = data, name=:fc1, num_hidden=128)
-# act1 = mx.Activation(data = fc1, name=:relu1, act_type=:relu)
-# fc2 = mx.FullyConnected(data = act1, name=:fc2, num_hidden=64)
-# act2 = mx.Activation(data = fc2, name=:relu2, act_type=:relu)
-# fc3 = mx.FullyConnected(data = act2, name=:fc3, num_hidden=10)
-# mlp = mx.SoftmaxOutput(data = fc3, name=:softmax)
+# fc1 = mx.FullyConnected(data, name=:fc1, num_hidden=128)
+# act1 = mx.Activation(fc1, name=:relu1, act_type=:relu)
+# fc2 = mx.FullyConnected(act1, name=:fc2, num_hidden=64)
+# act2 = mx.Activation(fc2, name=:relu2, act_type=:relu)
+# fc3 = mx.FullyConnected(act2, name=:fc3, num_hidden=10)
+# mlp = mx.SoftmaxOutput(fc3, name=:softmax)

 #-- Option 2: using the mx.chain macro
 # mlp = @mx.chain mx.Variable(:data) =>
