Skip to content

Commit 72edc94

Browse files
author
Richardfan
committed
change a better way to adjust learning rate
1 parent dba622c commit 72edc94

File tree

4 files changed

+24
-16
lines changed

4 files changed

+24
-16
lines changed

my_863_corpus/steps/cnn_lstm_ctc.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ def main():
175175
count = 0
176176
learning_rate = init_lr
177177
acc_best = -100
178+
acc_best_true = -100
178179
adjust_rate_flag = False
179180
stop_train = False
180181
adjust_time = 0
@@ -216,8 +217,8 @@ def main():
216217
op_state = copy.deepcopy(optimizer.state_dict())
217218
elif (acc > acc_best - end_adjust_acc):
218219
adjust_rate_count += 1
219-
if acc > acc_best and adjust_rate_count == 10:
220-
acc_best = acc
220+
if acc > acc_best and acc > acc_best_true:
221+
acc_best_true = acc
221222
model_state = copy.deepcopy(model.state_dict())
222223
op_state = copy.deepcopy(optimizer.state_dict())
223224
else:
@@ -232,10 +233,11 @@ def main():
232233
adjust_rate_flag = True
233234
adjust_time += 1
234235
adjust_rate_count = 0
235-
236-
if adjust_time == 8:
236+
acc_best = acc_best_true
237237
model.load_state_dict(model_state)
238238
optimizer.load_state_dict(op_state)
239+
240+
if adjust_time == 8:
239241
stop_train = True
240242

241243
time_used = (time.time() - start_time) / 60

my_863_corpus/steps/lstm_ctc.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ def main():
183183
count = 0
184184
learning_rate = init_lr
185185
acc_best = -100
186+
acc_best_true = -100
186187
adjust_rate_count = 0
187188
adjust_rate_flag = False
188189
stop_train = False
@@ -225,8 +226,8 @@ def main():
225226
op_state = copy.deepcopy(optimizer.state_dict())
226227
elif (acc > acc_best - end_adjust_acc):
227228
adjust_rate_count += 1
228-
if acc > acc_best and adjust_rate_count == 10:
229-
acc_best = acc
229+
if acc > acc_best and acc > acc_best_true:
230+
acc_best_true = acc
230231
model_state = copy.deepcopy(model.state_dict())
231232
op_state = copy.deepcopy(optimizer.state_dict())
232233
else:
@@ -241,10 +242,11 @@ def main():
241242
adjust_rate_flag = True
242243
adjust_time += 1
243244
adjust_rate_count = 0
244-
245-
if adjust_time == 8:
245+
acc_best = acc_best_true
246246
model.load_state_dict(model_state)
247247
optimizer.load_state_dict(op_state)
248+
249+
if adjust_time == 8:
248250
stop_train = True
249251

250252
time_used = (time.time() - start_time) / 60

timit/steps/cnn_lstm_ctc.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ def main():
175175
count = 0
176176
learning_rate = init_lr
177177
acc_best = -100
178+
acc_best_true = -100
178179
adjust_rate_flag = False
179180
stop_train = False
180181
adjust_time = 0
@@ -216,8 +217,8 @@ def main():
216217
op_state = copy.deepcopy(optimizer.state_dict())
217218
elif (acc > acc_best - end_adjust_acc):
218219
adjust_rate_count += 1
219-
if acc > acc_best and adjust_rate_count == 10:
220-
acc_best = acc
220+
if acc > acc_best and acc > acc_best_true:
221+
acc_best_true = acc
221222
model_state = copy.deepcopy(model.state_dict())
222223
op_state = copy.deepcopy(optimizer.state_dict())
223224
else:
@@ -233,10 +234,11 @@ def main():
233234
adjust_rate_flag = True
234235
adjust_time += 1
235236
adjust_rate_count = 0
236-
237-
if adjust_time == 8:
237+
acc_best = acc_best_true
238238
model.load_state_dict(model_state)
239239
optimizer.load_state_dict(op_state)
240+
241+
if adjust_time == 8:
240242
stop_train = True
241243

242244
time_used = (time.time() - start_time) / 60

timit/steps/lstm_ctc.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ def main():
183183
count = 0
184184
learning_rate = init_lr
185185
acc_best = -100
186+
acc_best_true = -100
186187
adjust_rate_flag = False
187188
stop_train = False
188189
adjust_time = 0
@@ -224,8 +225,8 @@ def main():
224225
op_state = copy.deepcopy(optimizer.state_dict())
225226
elif (acc > acc_best - end_adjust_acc):
226227
adjust_rate_count += 1
227-
if acc > acc_best and adjust_rate_count == 10:
228-
acc_best = acc
228+
if acc > acc_best and acc > acc_best_true:
229+
acc_best_true = acc
229230
model_state = copy.deepcopy(model.state_dict())
230231
op_state = copy.deepcopy(optimizer.state_dict())
231232
else:
@@ -241,10 +242,11 @@ def main():
241242
adjust_rate_flag = True
242243
adjust_time += 1
243244
adjust_rate_count = 0
244-
245-
if adjust_time == 8:
245+
acc_best = acc_best_true
246246
model.load_state_dict(model_state)
247247
optimizer.load_state_dict(op_state)
248+
249+
if adjust_time == 8:
248250
stop_train = True
249251

250252
time_used = (time.time() - start_time) / 60

0 commit comments

Comments
 (0)