Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions deep_speech_2/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from model import deep_speech2
from audio_data_utils import DataGenerator
import numpy as np
import os

#TODO: add WER metric

Expand Down Expand Up @@ -78,6 +79,11 @@
default='data/eng_vocab.txt',
type=str,
help="Vocabulary filepath. (default: %(default)s)")
parser.add_argument(
"--init_model_path",
default='models/params.tar.gz',
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Set default to None (training from scratch).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

type=str,
help="Model path for initialization. (default: %(default)s)")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add e.g. "If set None, the training will start from scratch. Otherwise, the training will resume from the existing model of this path".

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

args = parser.parse_args()


Expand Down Expand Up @@ -114,8 +120,13 @@ def train():
rnn_size=args.rnn_layer_size,
is_inference=False)

# create parameters and optimizer
parameters = paddle.parameters.create(cost)
# create/load parameters and optimizer
if args.init_model_path is None:
parameters = paddle.parameters.create(cost)
else:
assert os.path.isfile(args.init_model_path), "Invalid model."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be better to use "raise IOError" ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree and done.

parameters = paddle.parameters.Parameters.from_tar(
gzip.open(args.init_model_path))
optimizer = paddle.optimizer.Adam(
learning_rate=args.adam_learning_rate, gradient_clipping_threshold=400)
trainer = paddle.trainer.SGD(
Expand Down