# coding=utf-8
-# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""BERT finetuning runner."""

from __future__ import absolute_import, division, print_function

+import pickle
import argparse
import csv
import logging
from modeling import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME
from tokenization import BertTokenizer
from optimization import BertAdam, warmup_linear
+from schedulers import LinearWarmUpScheduler
+from apex import amp
+from sklearn.metrics import matthews_corrcoef, f1_score
+from utils import is_main_process

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)

+def compute_metrics(task_name, preds, labels):
+    assert len(preds) == len(labels)
+    if task_name == "cola":
+        return {"mcc": matthews_corrcoef(labels, preds)}
+    elif task_name == "sst-2":
+        return {"acc": simple_accuracy(preds, labels)}
+    elif task_name == "mrpc":
+        return acc_and_f1(preds, labels)
+    elif task_name == "sts-b":
+        return pearson_and_spearman(preds, labels)
+    elif task_name == "qqp":
+        return acc_and_f1(preds, labels)
+    elif task_name == "mnli":
+        return {"acc": simple_accuracy(preds, labels)}
+    elif task_name == "mnli-mm":
+        return {"acc": simple_accuracy(preds, labels)}
+    elif task_name == "qnli":
+        return {"acc": simple_accuracy(preds, labels)}
+    elif task_name == "rte":
+        return {"acc": simple_accuracy(preds, labels)}
+    elif task_name == "wnli":
+        return {"acc": simple_accuracy(preds, labels)}
+    else:
+        raise KeyError(task_name)
+
+
+def simple_accuracy(preds, labels):
+    return (preds == labels).mean()
+
+def acc_and_f1(preds, labels):
+    acc = simple_accuracy(preds, labels)
+    f1 = f1_score(y_true=labels, y_pred=preds)
+    return {
+        "acc": acc,
+        "f1": f1,
+        "acc_and_f1": (acc + f1) / 2,
+    }

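The "sts-b" branch of compute_metrics above calls pearson_and_spearman, which this commit does not show being defined. A minimal sketch of what that helper presumably looks like, mirroring the style of the other metric helpers and assuming scipy is available (the body below is an assumption, not taken from the commit):

from scipy.stats import pearsonr, spearmanr

def pearson_and_spearman(preds, labels):
    # Correlation metrics for the STS-B regression task.
    pearson_corr = pearsonr(preds, labels)[0]
    spearman_corr = spearmanr(preds, labels)[0]
    return {
        "pearson": pearson_corr,
        "spearmanr": spearman_corr,
        "corr": (pearson_corr + spearman_corr) / 2,
    }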
class InputExample(object):
    """A single training/test example for simple sequence classification."""
@@ -298,6 +340,30 @@ def accuracy(out, labels):
    outputs = np.argmax(out, axis=1)
    return np.sum(outputs == labels)

+from apex.multi_tensor_apply import multi_tensor_applier
+class GradientClipper:
+    """
+    Clips gradient norm of an iterable of parameters.
+    """
+    def __init__(self, max_grad_norm):
+        self.max_norm = max_grad_norm
+        if multi_tensor_applier.available:
+            import amp_C
+            self._overflow_buf = torch.cuda.IntTensor([0])
+            self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
+            self.multi_tensor_scale = amp_C.multi_tensor_scale
+        else:
+            raise RuntimeError('Gradient clipping requires cuda extensions')
+
+    def step(self, parameters):
+        l = [p.grad for p in parameters if p.grad is not None]
+        total_norm, _ = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [l], False)
+        total_norm = total_norm.item()
+        if (total_norm == float('inf')): return
+        clip_coef = self.max_norm / (total_norm + 1e-6)
+        if clip_coef < 1:
+            multi_tensor_applier(self.multi_tensor_scale, self._overflow_buf, [l, l], clip_coef)
+
def main():
    parser = argparse.ArgumentParser()

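The GradientClipper class added above is never invoked in the rest of this diff. A sketch of how it could be wired into the training step, assuming apex's amp.master_params(optimizer) is the way to reach the fp32 master gradients in the fp16 path (this wiring is an assumption, not part of the commit):

grad_clipper = GradientClipper(max_grad_norm=1.0)

# ...inside the training loop, after the backward pass and before optimizer.step():
if args.fp16:
    grad_clipper.step(amp.master_params(optimizer))
else:
    grad_clipper.step(model.parameters())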
@@ -363,6 +429,9 @@ def main():
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
+    parser.add_argument("--google_pretrained",
+                        action='store_true',
+                        help="Whether the checkpoint to load is in the original Google BERT format")
    parser.add_argument("--max_steps", default=-1.0, type=float,
                        help="Total number of training steps to perform.")
    parser.add_argument("--warmup_proportion",
@@ -379,7 +448,7 @@ def main():
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
-                        default=42,
+                        default=1,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
@@ -395,6 +464,16 @@ def main():
395464 "Positive power of 2: static loss scaling value.\n " )
396465 parser .add_argument ('--server_ip' , type = str , default = '' , help = "Can be used for distant debugging." )
397466 parser .add_argument ('--server_port' , type = str , default = '' , help = "Can be used for distant debugging." )
467+ parser .add_argument ("--old" , action = 'store_true' , help = "use old fp16 optimizer" )
+    parser.add_argument('--vocab_file',
+                        type=str, default=None, required=True,
+                        help="Vocabulary mapping/file BERT was pretrained on")
+    parser.add_argument("--config_file",
+                        default=None,
+                        type=str,
+                        required=True,
+                        help="The BERT model config")
+
    args = parser.parse_args()

    if args.server_ip and args.server_port:
@@ -445,7 +524,7 @@ def main():

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
        print("WARNING: Output directory ({}) already exists and is not empty.".format(args.output_dir))
-    if not os.path.exists(args.output_dir):
+    if not os.path.exists(args.output_dir) and is_main_process():
        os.makedirs(args.output_dir)

    task_name = args.task_name.lower()
@@ -457,8 +536,9 @@ def main():
    num_labels = num_labels_task[task_name]
    label_list = processor.get_labels()

-    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
-
+    #tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
+    tokenizer = BertTokenizer(args.vocab_file, do_lower_case=args.do_lower_case, max_len=512)  # for bert large
+
    train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
@@ -469,25 +549,17 @@ def main():
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()

    # Prepare model
-    cache_dir = args.cache_dir if args.cache_dir else os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank))
-    model = BertForSequenceClassification.from_pretrained(args.bert_model,
-                                                           cache_dir=cache_dir,
-                                                           num_labels=num_labels)
-    model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'), strict=False)
+    config = BertConfig.from_json_file(args.config_file)
+    # Padding for divisibility by 8
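+    # (a vocab size that is a multiple of 8 keeps the fp16 embedding GEMMs Tensor Core friendly)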
+    if config.vocab_size % 8 != 0:
+        config.vocab_size += 8 - (config.vocab_size % 8)

-    if args.fp16:
-        model.half()
-    model.to(device)
-    if args.local_rank != -1:
-        try:
-            from apex.parallel import DistributedDataParallel as DDP
-        except ImportError:
-            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
-
-        model = DDP(model)
-    elif n_gpu > 1:
-        model = torch.nn.DataParallel(model)
+    model = BertForSequenceClassification(config, num_labels=num_labels)
+    print("USING CHECKPOINT from", args.init_checkpoint)
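+    # the pretraining checkpoint is a dict whose "model" entry holds the actual state_dict, hence the ["model"] lookup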
+    model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')["model"], strict=False)
+    print("USED CHECKPOINT from", args.init_checkpoint)

+    model.to(device)
    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
@@ -496,33 +568,63 @@ def main():
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
    if args.fp16:
+        print("using fp16")
        try:
-            from apex.contrib.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
-                              bias_correction=False,
-                              max_grad_norm=1.0)
+                              bias_correction=False)
+
        if args.loss_scale == 0:
-            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
+
+            model, optimizer = amp.initialize(model, optimizer, opt_level="O2", keep_batchnorm_fp32=False,
+                                              loss_scale="dynamic")
        else:
-            optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
+            model, optimizer = amp.initialize(model, optimizer, opt_level="O2", keep_batchnorm_fp32=False, loss_scale=args.loss_scale)
+        scheduler = LinearWarmUpScheduler(optimizer, warmup=args.warmup_proportion, total_steps=num_train_optimization_steps)

+
+
    else:
+        print("using fp32")
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

+    if args.local_rank != -1:
+        try:
+            from apex.parallel import DistributedDataParallel as DDP
+        except ImportError:
+            raise ImportError(
+                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
+
+        model = DDP(model)
+    elif n_gpu > 1:
+        model = torch.nn.DataParallel(model)
+
    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    if args.do_train:
-        train_features = convert_examples_to_features(
-            train_examples, label_list, args.max_seq_length, tokenizer)
+        print("data prep")
+        cached_train_features_file = args.data_dir + '_{0}_{1}_{2}'.format(
+            list(filter(None, args.bert_model.split('/'))).pop(), str(args.max_seq_length), str(args.do_lower_case))
+        train_features = None
+
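+        # reuse features cached by an earlier run; on any failure, rebuild them and refresh the cache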
+        try:
+            with open(cached_train_features_file, "rb") as reader:
+                train_features = pickle.load(reader)
+        except:
+            train_features = convert_examples_to_features(train_examples, label_list, args.max_seq_length, tokenizer)
+            if args.local_rank == -1 or torch.distributed.get_rank() == 0:
+                logger.info("  Saving train features into cached file %s", cached_train_features_file)
+                with open(cached_train_features_file, "wb") as writer:
+                    pickle.dump(train_features, writer)
+
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
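Taken together, the fp16 branch above replaces the removed FP16_Optimizer wrapper with apex amp's O2 mode plus an explicit LinearWarmUpScheduler. A self-contained toy example of the same initialize/scale_loss pattern; it assumes apex and a CUDA device are available, and the Linear model and SGD optimizer are stand-ins rather than the script's BERT model:

import torch
from apex import amp

model = torch.nn.Linear(16, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Same pattern as the diff: O2 mixed precision with a dynamic loss scale.
model, optimizer = amp.initialize(model, optimizer, opt_level="O2",
                                  keep_batchnorm_fp32=False, loss_scale="dynamic")

for _ in range(3):
    loss = model(torch.randn(4, 16).cuda()).float().pow(2).mean()
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()  # backward on the scaled loss; fp32 master grads are updated
    optimizer.step()
    optimizer.zero_grad()

In the script itself, the scheduler.step() added to the training loop further below takes over the warm-up that the removed warmup_linear block used to apply by hand.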
@@ -554,42 +656,23 @@ def main():
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
-                    optimizer.backward(loss)
+                    with amp.scale_loss(loss, optimizer) as scaled_loss:
+                        scaled_loss.backward()
                else:
                    loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
-                    if args.fp16:
-                        # modify learning rate with special warm up BERT uses
-                        # if args.fp16 is False, BertAdam is used that handles this automatically
-                        lr_this_step = args.learning_rate * warmup_linear(global_step / num_train_optimization_steps, args.warmup_proportion)
-                        for param_group in optimizer.param_groups:
-                            param_group['lr'] = lr_this_step
+                    if args.fp16:
+                        # modify the learning rate with the special BERT warm-up, which FusedAdam does not handle
+                        scheduler.step()
+
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

-    if args.do_train:
-        # Save a trained model and the associated configuration
-        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
-        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
-        torch.save(model_to_save.state_dict(), output_model_file)
-        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
-        with open(output_config_file, 'w') as f:
-            f.write(model_to_save.config.to_json_string())
-
-        # Load a trained model and config that you have fine-tuned
-        config = BertConfig(output_config_file)
-        model = BertForSequenceClassification(config, num_labels=num_labels)
-        model.load_state_dict(torch.load(output_model_file))
-    else:
-        model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
-        model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'), strict=False)
-    model.to(device)
-
    if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        eval_examples = processor.get_dev_examples(args.data_dir)
        eval_features = convert_examples_to_features(
@@ -609,7 +692,8 @@ def main():
        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0
-
+        preds = None
+        out_label_ids = None
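+        # logits and gold labels are accumulated across batches and scored once with compute_metrics after the loop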
        for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
@@ -620,30 +704,35 @@ def main():
                tmp_eval_loss = model(input_ids, segment_ids, input_mask, label_ids)
                logits = model(input_ids, segment_ids, input_mask)

-            logits = logits.detach().cpu().numpy()
-            label_ids = label_ids.to('cpu').numpy()
-            tmp_eval_accuracy = accuracy(logits, label_ids)
-
-            eval_loss += tmp_eval_loss.mean().item()
-            eval_accuracy += tmp_eval_accuracy

-            nb_eval_examples += input_ids.size(0)
+            eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1
+            if preds is None:
+                preds = logits.detach().cpu().numpy()
+                out_label_ids = label_ids.detach().cpu().numpy()
+            else:
+                preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
+                out_label_ids = np.append(out_label_ids, label_ids.detach().cpu().numpy(), axis=0)
+
+        preds = np.argmax(preds, axis=1)

        eval_loss = eval_loss / nb_eval_steps
-        eval_accuracy = eval_accuracy / nb_eval_examples
        loss = tr_loss / nb_tr_steps if args.do_train else None
-        result = {'eval_loss': eval_loss,
-                  'eval_accuracy': eval_accuracy,
+
+        results = {'eval_loss': eval_loss,
                  'global_step': global_step,
                  'loss': loss}

+        result = compute_metrics(task_name, preds, out_label_ids)
+        results.update(result)
+        print(results)
        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
-            for key in sorted(result.keys()):
-                logger.info("  %s = %s", key, str(result[key]))
-                writer.write("%s = %s\n" % (key, str(result[key])))
+            for key in sorted(results.keys()):
+                logger.info("  %s = %s", key, str(results[key]))
+                writer.write("%s = %s\n" % (key, str(results[key])))

if __name__ == "__main__":
    main()
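For reference, a tiny check of the compute_metrics path exercised in the evaluation block above, using dummy arrays (not part of the commit):

import numpy as np

preds = np.array([1, 0, 1, 1])
labels = np.array([1, 0, 0, 1])
print(compute_metrics("sst-2", preds, labels))  # {'acc': 0.75}
print(compute_metrics("mrpc", preds, labels))   # accuracy plus sklearn F1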