 import logging
 import os
 import sys
+import torch
 from transformers import AutoTokenizer
 from transformers import (
     HfArgumentParser,
     set_seed,
 )
+from torch.nn.parallel import DistributedDataParallel as DDP
+import torch.distributed as dist
 from tevatron.reranker.arguments import ModelArguments, DataArguments, TevatronTrainingArguments
 from tevatron.reranker.modeling import RerankerModel
 from tevatron.reranker.dataset import RerankerTrainDataset
 from tevatron.reranker.collator import RerankerTrainCollator
-from tevatron.reranker.trainer import RerankerTrainer  # Make sure this is your updated RerankerTrainer
+from tevatron.reranker.trainer import RerankerTrainer
 
 logger = logging.getLogger(__name__)
 
 
+def setup_ddp():
+    if not dist.is_initialized():
+        dist.init_process_group(backend="nccl")
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    torch.cuda.set_device(local_rank)
+    return local_rank
+
+
 def main():
     parser = HfArgumentParser((ModelArguments, DataArguments, TevatronTrainingArguments))
@@ -23,29 +34,23 @@ def main():
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
-    if (
-        os.path.exists(training_args.output_dir)
-        and os.listdir(training_args.output_dir)
-        and training_args.do_train
-        and not training_args.overwrite_output_dir
-    ):
-        raise ValueError(
-            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
-        )
+    local_rank = -1
+    if training_args.local_rank != -1:
+        local_rank = setup_ddp()
 
     # Setup logging
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
         datefmt="%m/%d/%Y %H:%M:%S",
-        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
+        level=logging.INFO if local_rank in [-1, 0] else logging.WARN,
     )
     logger.warning(
         "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
-        training_args.local_rank,
+        local_rank,
         training_args.device,
         training_args.n_gpu,
-        bool(training_args.local_rank != -1),
-        training_args.fp16,
+        bool(local_rank != -1),
+        training_args.fp16 or training_args.bf16,
     )
     logger.info("Training/evaluation parameters %s", training_args)
     logger.info("MODEL parameters %s", model_args)
@@ -67,11 +72,16 @@ def main():
         cache_dir=model_args.cache_dir,
     )
 
+    # Move model to GPU
+    if local_rank != -1:
+        model = model.to(local_rank)
+        model = DDP(model, device_ids=[local_rank], output_device=local_rank)
+
     train_dataset = RerankerTrainDataset(data_args)
     train_collator = RerankerTrainCollator(data_args, tokenizer)
 
-    # Add GradCache-specific arguments to training_args
     training_args.gc_chunk_size = getattr(training_args, 'gc_chunk_size', 2)
+    training_args.grad_cache = getattr(training_args, 'grad_cache', False)
 
     trainer = RerankerTrainer(
         model=model,
@@ -81,11 +91,11 @@ def main():
     )
     train_dataset.trainer = trainer
 
-    trainer.train()  # TODO: resume training
+    trainer.train()
     trainer.save_model()
     if trainer.is_world_process_zero():
         tokenizer.save_pretrained(training_args.output_dir)
 
 
 if __name__ == "__main__":
-    main()
+    main()
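
For context, the sketch below shows the bare torch.distributed pattern this patch follows: process-group initialization keyed off LOCAL_RANK, per-rank device selection, wrapping the model in DistributedDataParallel, and teardown. It is a minimal illustration outside the Tevatron code, not part of the change; the run() function and the torch.nn.Linear stand-in for RerankerModel are assumptions, and it presumes a launch such as `torchrun --nproc_per_node=<num_gpus> script.py`, which exports LOCAL_RANK, RANK, and WORLD_SIZE.

# Minimal DDP sketch (illustrative only, not part of the patch).
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def run():
    # torchrun exports LOCAL_RANK; fall back to 0 for a single-process run.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)

    # Stand-in module; the patch wraps RerankerModel the same way.
    model = torch.nn.Linear(16, 1).to(local_rank)
    model = DDP(model, device_ids=[local_rank], output_device=local_rank)

    # ... training loop over the DDP-wrapped model goes here ...

    dist.destroy_process_group()


if __name__ == "__main__":
    run()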