
Commit ad88003

[BERT/PyT] GLUE (MRPC) fine-tuning with a LAMB-pretrained checkpoint
* LAMB checkpoint compatibility
* AMP training
1 parent 119838f commit ad88003

File tree

2 files changed: +190 -88 lines

PyTorch/LanguageModeling/BERT/run_glue.py

Lines changed: 158 additions & 69 deletions
@@ -1,6 +1,7 @@
 # coding=utf-8
-# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
 # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -12,11 +13,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 """BERT finetuning runner."""

 from __future__ import absolute_import, division, print_function

+import pickle
 import argparse
 import csv
 import logging
@@ -35,12 +36,53 @@
 from modeling import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME
 from tokenization import BertTokenizer
 from optimization import BertAdam, warmup_linear
+from schedulers import LinearWarmUpScheduler
+from apex import amp
+from sklearn.metrics import matthews_corrcoef, f1_score
+from utils import is_main_process

 logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                     datefmt = '%m/%d/%Y %H:%M:%S',
                     level = logging.INFO)
 logger = logging.getLogger(__name__)

+def compute_metrics(task_name, preds, labels):
+    assert len(preds) == len(labels)
+    if task_name == "cola":
+        return {"mcc": matthews_corrcoef(labels, preds)}
+    elif task_name == "sst-2":
+        return {"acc": simple_accuracy(preds, labels)}
+    elif task_name == "mrpc":
+        return acc_and_f1(preds, labels)
+    elif task_name == "sts-b":
+        return pearson_and_spearman(preds, labels)
+    elif task_name == "qqp":
+        return acc_and_f1(preds, labels)
+    elif task_name == "mnli":
+        return {"acc": simple_accuracy(preds, labels)}
+    elif task_name == "mnli-mm":
+        return {"acc": simple_accuracy(preds, labels)}
+    elif task_name == "qnli":
+        return {"acc": simple_accuracy(preds, labels)}
+    elif task_name == "rte":
+        return {"acc": simple_accuracy(preds, labels)}
+    elif task_name == "wnli":
+        return {"acc": simple_accuracy(preds, labels)}
+    else:
+        raise KeyError(task_name)
+
+
+def simple_accuracy(preds, labels):
+    return (preds == labels).mean()
+
+def acc_and_f1(preds, labels):
+    acc = simple_accuracy(preds, labels)
+    f1 = f1_score(y_true=labels, y_pred=preds)
+    return {
+        "acc": acc,
+        "f1": f1,
+        "acc_and_f1": (acc + f1) / 2,
+    }

 class InputExample(object):
     """A single training/test example for simple sequence classification."""
@@ -298,6 +340,30 @@ def accuracy(out, labels):
     outputs = np.argmax(out, axis=1)
     return np.sum(outputs == labels)

+from apex.multi_tensor_apply import multi_tensor_applier
+class GradientClipper:
+    """
+    Clips gradient norm of an iterable of parameters.
+    """
+    def __init__(self, max_grad_norm):
+        self.max_norm = max_grad_norm
+        if multi_tensor_applier.available:
+            import amp_C
+            self._overflow_buf = torch.cuda.IntTensor([0])
+            self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
+            self.multi_tensor_scale = amp_C.multi_tensor_scale
+        else:
+            raise RuntimeError('Gradient clipping requires cuda extensions')
+
+    def step(self, parameters):
+        l = [p.grad for p in parameters if p.grad is not None]
+        total_norm, _ = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [l], False)
+        total_norm = total_norm.item()
+        if (total_norm == float('inf')): return
+        clip_coef = self.max_norm / (total_norm + 1e-6)
+        if clip_coef < 1:
+            multi_tensor_applier(self.multi_tensor_scale, self._overflow_buf, [l, l], clip_coef)
+
 def main():
     parser = argparse.ArgumentParser()

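Note: GradientClipper relies on apex's fused multi-tensor kernels (amp_C) and raises if they are unavailable. For reference, the same clipping logic can be expressed with plain PyTorch; this is an illustrative equivalent, not code added by the commit:

    import torch

    def clip_gradients(parameters, max_grad_norm=1.0):
        # Compute the global L2 norm of all gradients, then rescale them in
        # place if it exceeds max_grad_norm. An overflowed (inf) norm is left
        # untouched, mirroring the early return in GradientClipper.step().
        grads = [p.grad for p in parameters if p.grad is not None]
        total_norm = torch.norm(torch.stack([g.norm(2) for g in grads]), 2).item()
        if total_norm == float('inf'):
            return total_norm
        clip_coef = max_grad_norm / (total_norm + 1e-6)
        if clip_coef < 1:
            for g in grads:
                g.mul_(clip_coef)
        return total_norm

The apex version performs the same reduction and scaling with a handful of fused kernel launches, hence the multi_tensor_applier.available check above.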
@@ -363,6 +429,9 @@ def main():
                         default=3.0,
                         type=float,
                         help="Total number of training epochs to perform.")
+    parser.add_argument("--google_pretrained",
+                        action='store_true',
+                        help="Whether not to use CUDA when available")
     parser.add_argument("--max_steps", default=-1.0, type=float,
                         help="Total number of training steps to perform.")
     parser.add_argument("--warmup_proportion",
@@ -379,7 +448,7 @@
                         help="local_rank for distributed training on gpus")
     parser.add_argument('--seed',
                         type=int,
-                        default=42,
+                        default=1,
                         help="random seed for initialization")
     parser.add_argument('--gradient_accumulation_steps',
                         type=int,
@@ -395,6 +464,16 @@
                              "Positive power of 2: static loss scaling value.\n")
     parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
     parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
+    parser.add_argument("--old", action='store_true', help="use old fp16 optimizer")
+    parser.add_argument('--vocab_file',
+                        type=str, default=None, required=True,
+                        help="Vocabulary mapping/file BERT was pretrainined on")
+    parser.add_argument("--config_file",
+                        default=None,
+                        type=str,
+                        required=True,
+                        help="The BERT model config")
+
     args = parser.parse_args()

     if args.server_ip and args.server_port:
@@ -445,7 +524,7 @@

     if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
         print("WARNING: Output directory ({}) already exists and is not empty.".format(args.output_dir))
-    if not os.path.exists(args.output_dir):
+    if not os.path.exists(args.output_dir) and is_main_process():
         os.makedirs(args.output_dir)

     task_name = args.task_name.lower()
@@ -457,8 +536,9 @@
     num_labels = num_labels_task[task_name]
     label_list = processor.get_labels()

-    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
-
+    #tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
+    tokenizer = BertTokenizer(args.vocab_file, do_lower_case=args.do_lower_case, max_len=512) # for bert large
+
     train_examples = None
     num_train_optimization_steps = None
     if args.do_train:
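Note: the tokenizer is now constructed directly from the vocabulary file given on the command line instead of being downloaded through from_pretrained. A hypothetical usage sketch (placeholder path; tokenize and convert_tokens_to_ids are the standard methods of the repository's BertTokenizer):

    from tokenization import BertTokenizer

    # Placeholder path; in run_glue.py it comes from --vocab_file.
    tokenizer = BertTokenizer("vocab.txt", do_lower_case=True, max_len=512)

    tokens = tokenizer.tokenize("BERT fine-tuning on MRPC")
    input_ids = tokenizer.convert_tokens_to_ids(["[CLS]"] + tokens + ["[SEP]"])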
@@ -469,25 +549,17 @@
             num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()

     # Prepare model
-    cache_dir = args.cache_dir if args.cache_dir else os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank))
-    model = BertForSequenceClassification.from_pretrained(args.bert_model,
-                                                          cache_dir=cache_dir,
-                                                          num_labels = num_labels)
-    model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'), strict=False)
+    config = BertConfig.from_json_file(args.config_file)
+    # Padding for divisibility by 8
+    if config.vocab_size % 8 != 0:
+        config.vocab_size += 8 - (config.vocab_size % 8)

-    if args.fp16:
-        model.half()
-    model.to(device)
-    if args.local_rank != -1:
-        try:
-            from apex.parallel import DistributedDataParallel as DDP
-        except ImportError:
-            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
-
-        model = DDP(model)
-    elif n_gpu > 1:
-        model = torch.nn.DataParallel(model)
+    model = BertForSequenceClassification(config, num_labels=num_labels)
+    print("USING CHECKPOINT from", args.init_checkpoint)
+    model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')["model"], strict=False)
+    print("USED CHECKPOINT from", args.init_checkpoint)

+    model.to(device)
     # Prepare optimizer
     param_optimizer = list(model.named_parameters())
     no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
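Note: this hunk is the heart of the checkpoint-compatibility change: the classifier is built from an explicit config file, the vocabulary size is padded to a multiple of 8 (Tensor Core friendly), and the weights come from the "model" entry of a LAMB pretraining checkpoint. A condensed sketch of the same flow with placeholder paths:

    import torch
    from modeling import BertConfig, BertForSequenceClassification

    config = BertConfig.from_json_file("bert_config.json")       # placeholder path
    if config.vocab_size % 8 != 0:
        # pad the embedding table so its size is divisible by 8
        config.vocab_size += 8 - (config.vocab_size % 8)

    model = BertForSequenceClassification(config, num_labels=2)  # 2 labels for MRPC
    checkpoint = torch.load("lamb_pretrained.pt", map_location='cpu')  # placeholder path
    # Pretraining checkpoints store the weights under the "model" key;
    # strict=False tolerates the randomly initialized classifier head.
    model.load_state_dict(checkpoint["model"], strict=False)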
@@ -496,33 +568,63 @@ def main():
         {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
         ]
     if args.fp16:
+        print("using fp16")
         try:
-            from apex.contrib.optimizers import FP16_Optimizer
             from apex.optimizers import FusedAdam
         except ImportError:
             raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

         optimizer = FusedAdam(optimizer_grouped_parameters,
                               lr=args.learning_rate,
-                              bias_correction=False,
-                              max_grad_norm=1.0)
+                              bias_correction=False)
+
         if args.loss_scale == 0:
-            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
+
+            model, optimizer = amp.initialize(model, optimizer, opt_level="O2", keep_batchnorm_fp32=False,
+                                              loss_scale="dynamic")
         else:
-            optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
+            model, optimizer = amp.initialize(model, optimizer, opt_level="O2", keep_batchnorm_fp32=False, loss_scale=args.loss_scale)
+        scheduler = LinearWarmUpScheduler(optimizer, warmup=args.warmup_proportion, total_steps=num_train_optimization_steps)

+
+
     else:
+        print("using fp32")
         optimizer = BertAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              warmup=args.warmup_proportion,
                              t_total=num_train_optimization_steps)

+    if args.local_rank != -1:
+        try:
+            from apex.parallel import DistributedDataParallel as DDP
+        except ImportError:
+            raise ImportError(
+                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
+
+        model = DDP(model)
+    elif n_gpu > 1:
+        model = torch.nn.DataParallel(model)
+
     global_step = 0
     nb_tr_steps = 0
     tr_loss = 0
     if args.do_train:
-        train_features = convert_examples_to_features(
-            train_examples, label_list, args.max_seq_length, tokenizer)
+        print("data prep")
+        cached_train_features_file = args.data_dir + '_{0}_{1}_{2}'.format(
+            list(filter(None, args.bert_model.split('/'))).pop(), str(args.max_seq_length), str(args.do_lower_case))
+        train_features = None
+
+        try:
+            with open(cached_train_features_file, "rb") as reader:
+                train_features = pickle.load(reader)
+        except:
+            train_features = convert_examples_to_features(train_examples, label_list, args.max_seq_length, tokenizer)
+            if args.local_rank == -1 or torch.distributed.get_rank() == 0:
+                logger.info(" Saving train features into cached file %s", cached_train_features_file)
+                with open(cached_train_features_file, "wb") as writer:
+                    pickle.dump(train_features, writer)
+
         logger.info("***** Running training *****")
         logger.info(" Num examples = %d", len(train_examples))
         logger.info(" Batch size = %d", args.train_batch_size)
@@ -554,42 +656,23 @@ def main():
                     loss = loss / args.gradient_accumulation_steps

                 if args.fp16:
-                    optimizer.backward(loss)
+                    with amp.scale_loss(loss, optimizer) as scaled_loss:
+                        scaled_loss.backward()
                 else:
                     loss.backward()

                 tr_loss += loss.item()
                 nb_tr_examples += input_ids.size(0)
                 nb_tr_steps += 1
                 if (step + 1) % args.gradient_accumulation_steps == 0:
-                    if args.fp16:
-                        # modify learning rate with special warm up BERT uses
-                        # if args.fp16 is False, BertAdam is used that handles this automatically
-                        lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps, args.warmup_proportion)
-                        for param_group in optimizer.param_groups:
-                            param_group['lr'] = lr_this_step
+                    if args.fp16 :
+                        # modify learning rate with special warm up for BERT which FusedAdam doesn't do
+                        scheduler.step()
+
                     optimizer.step()
                     optimizer.zero_grad()
                     global_step += 1

-    if args.do_train:
-        # Save a trained model and the associated configuration
-        model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
-        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
-        torch.save(model_to_save.state_dict(), output_model_file)
-        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
-        with open(output_config_file, 'w') as f:
-            f.write(model_to_save.config.to_json_string())
-
-        # Load a trained model and config that you have fine-tuned
-        config = BertConfig(output_config_file)
-        model = BertForSequenceClassification(config, num_labels=num_labels)
-        model.load_state_dict(torch.load(output_model_file))
-    else:
-        model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
-        model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'), strict=False)
-    model.to(device)
-
     if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
         eval_examples = processor.get_dev_examples(args.data_dir)
         eval_features = convert_examples_to_features(
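Note: with apex AMP at opt_level O2, the backward pass goes through amp.scale_loss and the linear warm-up schedule is stepped explicitly before each optimizer update, since FusedAdam no longer handles either. A stripped-down sketch of one update under those assumptions (gradient-accumulation gating and distributed details omitted):

    from apex import amp

    def training_update(optimizer, scheduler, loss):
        # Backprop on the dynamically scaled loss so FP16 gradients do not underflow.
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        scheduler.step()      # manual warm-up; FusedAdam applies no schedule itself
        optimizer.step()
        optimizer.zero_grad()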
@@ -609,7 +692,8 @@ def main():
         model.eval()
         eval_loss, eval_accuracy = 0, 0
         nb_eval_steps, nb_eval_examples = 0, 0
-
+        preds = None
+        out_label_ids = None
         for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
             input_ids = input_ids.to(device)
             input_mask = input_mask.to(device)
@@ -620,30 +704,35 @@
                 tmp_eval_loss = model(input_ids, segment_ids, input_mask, label_ids)
                 logits = model(input_ids, segment_ids, input_mask)

-            logits = logits.detach().cpu().numpy()
-            label_ids = label_ids.to('cpu').numpy()
-            tmp_eval_accuracy = accuracy(logits, label_ids)
-
-            eval_loss += tmp_eval_loss.mean().item()
-            eval_accuracy += tmp_eval_accuracy

-            nb_eval_examples += input_ids.size(0)
+            eval_loss += tmp_eval_loss.mean().item()
             nb_eval_steps += 1
+            if preds is None:
+                preds = logits.detach().cpu().numpy()
+                out_label_ids = label_ids.detach().cpu().numpy()
+            else:
+                preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
+                out_label_ids = np.append(out_label_ids, label_ids.detach().cpu().numpy(), axis=0)
+
+        eval_loss = eval_loss / nb_eval_steps
+        preds = np.argmax(preds, axis=1)

         eval_loss = eval_loss / nb_eval_steps
-        eval_accuracy = eval_accuracy / nb_eval_examples
         loss = tr_loss/nb_tr_steps if args.do_train else None
-        result = {'eval_loss': eval_loss,
-                  'eval_accuracy': eval_accuracy,
+
+        results = {'eval_loss': eval_loss,
                   'global_step': global_step,
                   'loss': loss}

+        result = compute_metrics(task_name, preds, out_label_ids)
+        results.update(result)
+        print(results)
         output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
         with open(output_eval_file, "w") as writer:
             logger.info("***** Eval results *****")
-            for key in sorted(result.keys()):
-                logger.info(" %s = %s", key, str(result[key]))
-                writer.write("%s = %s\n" % (key, str(result[key])))
+            for key in sorted(results.keys()):
+                logger.info(" %s = %s", key, str(results[key]))
+                writer.write("%s = %s\n" % (key, str(results[key])))

 if __name__ == "__main__":
     main()
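Note: eval_results.txt is written as one "key = value" pair per line, so downstream tooling can read the metrics back with a few lines of Python (a convenience sketch, not part of the commit):

    def read_eval_results(path="eval_results.txt"):
        # Parse the "key = value" lines written at the end of run_glue.py.
        results = {}
        with open(path) as f:
            for line in f:
                key, _, value = line.strip().partition(" = ")
                results[key] = value
        return results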
