From a81a1ef8e9e839c9c50bdc5fae69afbeffb46036 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Sat, 10 Nov 2018 16:11:14 +0100 Subject: [PATCH] fixing learning rate schedule when using gradient_accumulation_steps --- run_classifier.py | 2 +- run_squad.py | 18 +++++++++++++++++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/run_classifier.py b/run_classifier.py index ab5251b1c06071..b9aafce645793c 100644 --- a/run_classifier.py +++ b/run_classifier.py @@ -464,7 +464,7 @@ def main(): if args.do_train: train_examples = processor.get_train_examples(args.data_dir) num_train_steps = int( - len(train_examples) / args.train_batch_size * args.num_train_epochs) + len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs) model = BertForSequenceClassification(bert_config, len(label_list)) if args.init_checkpoint is not None: diff --git a/run_squad.py b/run_squad.py index e44044f9a08887..9a9fbb61d59938 100644 --- a/run_squad.py +++ b/run_squad.py @@ -742,6 +742,10 @@ def main(): default=False, action='store_true', help="Whether to perform optimization and keep the optimizer averages on CPU") + parser.add_argument('--fp16', + default=False, + action='store_true', + help="Whether to use 16-bit float precision instead of 32-bit") args = parser.parse_args() @@ -801,11 +805,13 @@ def main(): train_examples = read_squad_examples( input_file=args.train_file, is_training=True) num_train_steps = int( - len(train_examples) / args.train_batch_size * args.num_train_epochs) + len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs) model = BertForQuestionAnswering(bert_config) if args.init_checkpoint is not None: model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')) + if args.fp16: + model.half() if not args.optimize_on_cpu: model.to(device) @@ -847,6 +853,12 @@ def main(): all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long) all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long) + if args.fp16: + (all_input_ids, all_input_mask, + all_segment_ids, all_start_positions, + all_end_positions) = tuple(t.half() for t in (all_input_ids, all_input_mask, all_segment_ids, + all_start_positions, all_end_positions)) + train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_start_positions, all_end_positions) if args.local_rank == -1: @@ -895,6 +907,10 @@ def main(): all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long) all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) + if args.fp16: + (all_input_ids, all_input_mask, + all_segment_ids, all_example_index) = tuple(t.half() for t in (all_input_ids, all_input_mask, + all_segment_ids, all_example_index)) eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index) if args.local_rank == -1: