
Speech (#3620)
* For bucketing, set the learning rate per sample.

* Change default clipping to 0.

* Add peephole connections; fix the softmax error.
pluskid authored and piiswrong committed Oct 25, 2016
1 parent 8413b28 commit 6ba8224
Showing 3 changed files with 24 additions and 14 deletions.
example/speech-demo/default.cfg: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ optimizer = speechSGD
momentum = 0.9

# set to 0 to disable gradient clipping
-clip_gradient = 1
+clip_gradient = 0

# uniform, normal, xavier
initializer = Uniform
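
The only functional change in this file is the clip_gradient default. As the comment in the config notes, 0 disables clipping, while a positive value bounds every gradient element. A minimal NumPy sketch of that element-wise behaviour (an illustration, not the demo's speechSGD optimizer):

```python
# Minimal NumPy sketch of element-wise gradient clipping, the behaviour a positive
# clip_gradient requests; 0 (the new default) lets the gradient pass through unchanged.
import numpy as np

def clip_grad(grad, clip_gradient):
    if clip_gradient > 0:
        return np.clip(grad, -clip_gradient, clip_gradient)
    return grad

grad = np.array([-3.0, 0.4, 2.5])
print(clip_grad(grad, 1.0))   # [-1.   0.4  1. ]
print(clip_grad(grad, 0.0))   # [-3.   0.4  2.5], clipping disabled
```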
example/speech-demo/lstm_proj.py: 18 additions & 8 deletions
@@ -6,8 +6,8 @@
LSTMState = namedtuple("LSTMState", ["c", "h"])
LSTMParam = namedtuple("LSTMParam", ["i2h_weight", "i2h_bias",
"h2h_weight", "h2h_bias",
"ph2h_weight"
])
"ph2h_weight",
"c2i_bias", "c2f_bias", "c2o_bias"])
LSTMModel = namedtuple("LSTMModel", ["rnn_exec", "symbol",
"init_states", "last_states",
"seq_data", "seq_labels", "seq_outputs",
@@ -32,11 +32,18 @@ def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0., nu
gates = i2h + h2h
slice_gates = mx.sym.SliceChannel(gates, num_outputs=4,
name="t%d_l%d_slice" % (seqidx, layeridx))
-in_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid")

+Wcidc = mx.sym.broadcast_mul(param.c2i_bias, prev_state.c) + slice_gates[0]
+in_gate = mx.sym.Activation(Wcidc, act_type="sigmoid")
in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh")
-forget_gate = mx.sym.Activation(slice_gates[2], act_type="sigmoid")
-out_gate = mx.sym.Activation(slice_gates[3], act_type="sigmoid")

+Wcfdc = mx.sym.broadcast_mul(param.c2f_bias, prev_state.c) + slice_gates[2]
+forget_gate = mx.sym.Activation(Wcfdc, act_type="sigmoid")
next_c = (forget_gate * prev_state.c) + (in_gate * in_transform)

+Wcoct = mx.sym.broadcast_mul(param.c2o_bias, next_c) + slice_gates[3]
+out_gate = mx.sym.Activation(Wcoct, act_type="sigmoid")

next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh")

if num_hidden_proj > 0:
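
This hunk implements the peephole connections from the commit message: the input and forget gates additionally see the previous cell state, and the output gate sees the freshly computed cell state, each through a learned per-unit weight (c2i_bias, c2f_bias, c2o_bias). A NumPy sketch of one such step, for illustration only; the actual code stays symbolic in MXNet:

```python
# NumPy sketch of the peephole LSTM step built symbolically above; names mirror the
# diff (Wcidc, Wcfdc, Wcoct) but this is an illustration, not the project's code.
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def peephole_lstm_step(pre_gates, prev_c, c2i, c2f, c2o):
    # pre_gates: (batch, 4 * num_hidden), the i2h + h2h pre-activations before slicing
    # prev_c:    (batch, num_hidden) previous cell state
    # c2i/c2f/c2o: (1, num_hidden) peephole weights, broadcast over the batch
    g_i, g_c, g_f, g_o = np.split(pre_gates, 4, axis=1)
    in_gate      = sigmoid(g_i + c2i * prev_c)      # peephole on the previous cell state
    in_transform = np.tanh(g_c)
    forget_gate  = sigmoid(g_f + c2f * prev_c)      # peephole on the previous cell state
    next_c = forget_gate * prev_c + in_gate * in_transform
    out_gate     = sigmoid(g_o + c2o * next_c)      # peephole on the new cell state
    next_h = out_gate * np.tanh(next_c)
    return next_c, next_h
```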
@@ -62,7 +69,10 @@ def lstm_unroll(num_lstm_layer, seq_len, input_size,
i2h_bias = mx.sym.Variable("l%d_i2h_bias" % i),
h2h_weight = mx.sym.Variable("l%d_h2h_weight" % i),
h2h_bias = mx.sym.Variable("l%d_h2h_bias" % i),
-ph2h_weight = mx.sym.Variable("l%d_ph2h_weight" % i)
+ph2h_weight = mx.sym.Variable("l%d_ph2h_weight" % i),
+c2i_bias = mx.sym.Variable("l%d_c2i_bias" % i, shape=(1,num_hidden)),
+c2f_bias = mx.sym.Variable("l%d_c2f_bias" % i, shape=(1,num_hidden)),
+c2o_bias = mx.sym.Variable("l%d_c2o_bias" % i, shape=(1, num_hidden))
))
state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
h=mx.sym.Variable("l%d_init_h" % i))
@@ -102,8 +112,8 @@ def lstm_unroll(num_lstm_layer, seq_len, input_size,
hidden_final = mx.sym.Reshape(hidden_concat, target_shape=(0, num_hidden))
pred = mx.sym.FullyConnected(data=hidden_final, num_hidden=num_label,
weight=cls_weight, bias=cls_bias, name='pred')
-pred = mx.sym.Reshape(pred, target_shape=(0, seq_len, num_label))

+pred = mx.sym.Reshape(pred, shape=(-1, num_label))
+label = mx.sym.Reshape(label, shape=(-1,))
if take_softmax:
sm = mx.sym.SoftmaxOutput(data=pred, label=label, ignore_label=0,
use_ignore=True, name='softmax')
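
This is the softmax fix: pred used to be reshaped back to (batch, seq_len, num_label), but the default SoftmaxOutput mode treats its input as (instances, classes), so predictions are now flattened to 2-D and labels to 1-D, letting ignore_label skip padding frames (label 0). A shortened symbolic sketch of the shapes involved; the variable names here are illustrative, not the exact ones in lstm_unroll:

```python
# Hedged sketch of the flattened layout the fix relies on.
import mxnet as mx

num_label = 1000                               # e.g. number of senone targets (illustrative)

pred  = mx.sym.Variable("pred")                # frame scores from the FullyConnected layer
label = mx.sym.Variable("label")               # (batch, seq_len) frame labels, 0 = padding

pred_2d  = mx.sym.Reshape(pred,  shape=(-1, num_label))   # (batch*seq_len, num_label)
label_1d = mx.sym.Reshape(label, shape=(-1,))             # (batch*seq_len,)

sm = mx.sym.SoftmaxOutput(data=pred_2d, label=label_1d,
                          ignore_label=0, use_ignore=True, name="softmax")
```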
example/speech-demo/train_lstm_proj.py: 5 additions & 5 deletions
@@ -57,7 +57,7 @@ def prepare_data(args):

def CrossEntropy(labels, preds):
labels = labels.reshape((-1,))
-preds = preds.reshape((-1, preds.shape[2]))
+preds = preds.reshape((-1, preds.shape[1]))
loss = 0.
num_inst = 0
for i in range(preds.shape[0]):
@@ -70,7 +70,7 @@ def CrossEntropy(labels, preds):

def Acc_exclude_padding(labels, preds):
labels = labels.reshape((-1,))
-preds = preds.reshape((-1, preds.shape[2]))
+preds = preds.reshape((-1, preds.shape[1]))
sum_metric = 0
num_inst = 0
for i in range(preds.shape[0]):
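
Both metric fixes follow from the reshape in lstm_proj.py: preds now reaches the metric as a 2-D (frames, num_label) array, so the class dimension is preds.shape[1] rather than preds.shape[2]. A self-contained sketch of the padding-excluding cross-entropy under that assumption; the loop body is truncated in this view, so the label == 0 test below is inferred from the function names and ignore_label=0:

```python
# Hedged reimplementation sketch of CrossEntropy with 2-D preds; the padding test
# (label 0) is assumed from Acc_exclude_padding and ignore_label=0 above.
import numpy as np

def cross_entropy(labels, preds, eps=1e-12):
    labels = labels.reshape((-1,))
    preds = preds.reshape((-1, preds.shape[1]))   # shape[1]: preds are (frames, num_label)
    loss, num_inst = 0.0, 0
    for i in range(preds.shape[0]):
        lab = int(labels[i])
        if lab == 0:                              # assumed padding label, skipped
            continue
        loss -= np.log(max(float(preds[i, lab]), eps))
        num_inst += 1
    return loss / max(num_inst, 1)
```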
@@ -163,7 +163,7 @@ def do_training(training_method, args, module, data_train, data_val):

def reset_optimizer():
if optimizer == "sgd" or optimizer == "speechSGD":
-module.init_optimizer(kvstore='local',
+module.init_optimizer(kvstore='device',
optimizer=args.config.get('train', 'optimizer'),
optimizer_params={'lr_scheduler': lr_scheduler,
'momentum': momentum,
@@ -172,7 +172,7 @@ def reset_optimizer():
'wd': weight_decay},
force_init=True)
else:
-module.init_optimizer(kvstore='local',
+module.init_optimizer(kvstore='device',
optimizer=args.config.get('train', 'optimizer'),
optimizer_params={'lr_scheduler': lr_scheduler,
'rescale_grad': 1.0,
@@ -191,7 +191,7 @@ def reset_optimizer():
lr_scheduler.momentum = np.power(np.power(momentum, 1.0/(data_train.batch_size * truncate_len)), data_batch.effective_sample_count)
else:
if data_batch.effective_sample_count is not None:
-lr_scheduler.effective_sample_count = data_batch.effective_sample_count
+lr_scheduler.effective_sample_count = 1#data_batch.effective_sample_count

module.forward_backward(data_batch)
module.update()
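
This is the "learning rate per sample" change from the commit message: in the non-speechSGD branch the scheduler's effective_sample_count is now pinned to 1, so the scheduler no longer sees each bucketed batch's frame count; presumably that makes the configured learning rate apply uniformly per sample (the scheduler internals are not shown in this hunk). The speechSGD branch keeps its per-sample momentum scaling, whose arithmetic the sketch below simply checks:

```python
# Numeric check of the momentum scaling in the speechSGD branch above:
# momentum ** (1/(batch_size*truncate_len)) raised to effective_sample_count equals
# momentum ** (effective_sample_count / (batch_size * truncate_len)).
import numpy as np

momentum, batch_size, truncate_len = 0.9, 40, 20   # illustrative values, not the config's
full = batch_size * truncate_len

for effective_sample_count in (full, full // 2):
    scaled = np.power(np.power(momentum, 1.0 / full), effective_sample_count)
    print(effective_sample_count, round(float(scaled), 4))
# 800 0.9     -> a full batch keeps the configured momentum
# 400 0.9487  -> a half-full batch uses momentum ** 0.5
```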
