import logging
import os
import sys
+ import torch
from transformers import AutoTokenizer
from transformers import (
    HfArgumentParser,
    set_seed,
)
+ from torch.nn.parallel import DistributedDataParallel as DDP
+ import torch.distributed as dist
from tevatron.reranker.arguments import ModelArguments, DataArguments, TevatronTrainingArguments
from tevatron.reranker.modeling import RerankerModel
from tevatron.reranker.dataset import RerankerTrainDataset
from tevatron.reranker.collator import RerankerTrainCollator
- from tevatron.reranker.trainer import RerankerTrainer  # Make sure this is your updated RerankerTrainer
+ from tevatron.reranker.trainer import RerankerTrainer

logger = logging.getLogger(__name__)


+ def setup_ddp():
+     if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+         # We're running in a distributed environment
+         import torch.distributed as dist
+         rank = int(os.environ['RANK'])
+         world_size = int(os.environ['WORLD_SIZE'])
+         dist.init_process_group(backend="nccl")
+         return rank
+     else:
+         # We're not running in a distributed environment
+         return -1
+
+
def main():
    parser = HfArgumentParser((ModelArguments, DataArguments, TevatronTrainingArguments))

@@ -23,29 +39,22 @@ def main():
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

-     if (
-         os.path.exists(training_args.output_dir)
-         and os.listdir(training_args.output_dir)
-         and training_args.do_train
-         and not training_args.overwrite_output_dir
-     ):
-         raise ValueError(
-             f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
-         )
+     local_rank = setup_ddp()
+     training_args.local_rank = local_rank


    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
-         level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
+         level=logging.INFO if local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
-         training_args.local_rank,
+         local_rank,
        training_args.device,
        training_args.n_gpu,
-         bool(training_args.local_rank != -1),
-         training_args.fp16,
+         bool(local_rank != -1),
+         training_args.fp16 or training_args.bf16,
    )
    logger.info("Training/evaluation parameters %s", training_args)
    logger.info("MODEL parameters %s", model_args)
@@ -67,11 +76,16 @@ def main():
        cache_dir=model_args.cache_dir,
    )

+     # Move model to GPU
+     if local_rank != -1:
+         model = model.to(local_rank)
+         model = DDP(model, device_ids=[local_rank], output_device=local_rank)
+
    train_dataset = RerankerTrainDataset(data_args)
    train_collator = RerankerTrainCollator(data_args, tokenizer)

-     # Add GradCache-specific arguments to training_args
    training_args.gc_chunk_size = getattr(training_args, 'gc_chunk_size', 2)
+     training_args.grad_cache = getattr(training_args, 'grad_cache', False)

    trainer = RerankerTrainer(
        model=model,
@@ -81,11 +95,11 @@ def main():
    )
    train_dataset.trainer = trainer

-     trainer.train()  # TODO: resume training
+     trainer.train()
    trainer.save_model()
    if trainer.is_world_process_zero():
        tokenizer.save_pretrained(training_args.output_dir)


if __name__ == "__main__":
-     main()
+     main()
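
A note on the device mapping in this change: setup_ddp() returns the global RANK exported by the launcher (torchrun sets RANK, WORLD_SIZE, and LOCAL_RANK for each process), and main() reuses that value as the CUDA device index in model.to(...) and DDP(device_ids=[...]). The two coincide only on a single node. The following is a minimal, hypothetical sketch (setup_ddp_multinode is not part of this commit) of how a multi-node-safe variant could key the device off LOCAL_RANK instead:

import os
import torch
import torch.distributed as dist

def setup_ddp_multinode():
    # Hypothetical variant, not in this PR: use LOCAL_RANK for the device index
    # so ranks on a second node still map to that node's GPUs 0..N-1.
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        local_rank = int(os.environ.get('LOCAL_RANK', 0))
        torch.cuda.set_device(local_rank)        # bind this process to its local GPU
        dist.init_process_group(backend="nccl")  # rank and world size are read from the env
        return local_rank
    return -1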