Commit 0380e4e

fix: ddp
1 parent 43a642d commit 0380e4e

3 files changed, +53 -36 lines changed
Lines changed: 31 additions & 17 deletions
@@ -1,20 +1,36 @@
 import logging
 import os
 import sys
+import torch
 from transformers import AutoTokenizer
 from transformers import (
     HfArgumentParser,
     set_seed,
 )
+from torch.nn.parallel import DistributedDataParallel as DDP
+import torch.distributed as dist
 from tevatron.reranker.arguments import ModelArguments, DataArguments, TevatronTrainingArguments
 from tevatron.reranker.modeling import RerankerModel
 from tevatron.reranker.dataset import RerankerTrainDataset
 from tevatron.reranker.collator import RerankerTrainCollator
-from tevatron.reranker.trainer import RerankerTrainer # Make sure this is your updated RerankerTrainer
+from tevatron.reranker.trainer import RerankerTrainer
 
 logger = logging.getLogger(__name__)
 
 
+def setup_ddp():
+    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+        # We're running in a distributed environment
+        import torch.distributed as dist
+        rank = int(os.environ['RANK'])
+        world_size = int(os.environ['WORLD_SIZE'])
+        dist.init_process_group(backend="nccl")
+        return rank
+    else:
+        # We're not running in a distributed environment
+        return -1
+
+
 def main():
     parser = HfArgumentParser((ModelArguments, DataArguments, TevatronTrainingArguments))
 
@@ -23,29 +39,22 @@ def main():
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
-    if (
-        os.path.exists(training_args.output_dir)
-        and os.listdir(training_args.output_dir)
-        and training_args.do_train
-        and not training_args.overwrite_output_dir
-    ):
-        raise ValueError(
-            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
-        )
+    local_rank = setup_ddp()
+    training_args.local_rank = local_rank
 
     # Setup logging
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
         datefmt="%m/%d/%Y %H:%M:%S",
-        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
+        level=logging.INFO if local_rank in [-1, 0] else logging.WARN,
     )
     logger.warning(
         "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
-        training_args.local_rank,
+        local_rank,
         training_args.device,
         training_args.n_gpu,
-        bool(training_args.local_rank != -1),
-        training_args.fp16,
+        bool(local_rank != -1),
+        training_args.fp16 or training_args.bf16,
     )
     logger.info("Training/evaluation parameters %s", training_args)
     logger.info("MODEL parameters %s", model_args)
@@ -67,11 +76,16 @@ def main():
         cache_dir=model_args.cache_dir,
     )
 
+    # Move model to GPU
+    if local_rank != -1:
+        model = model.to(local_rank)
+        model = DDP(model, device_ids=[local_rank], output_device=local_rank)
+
     train_dataset = RerankerTrainDataset(data_args)
     train_collator = RerankerTrainCollator(data_args, tokenizer)
 
-    # Add GradCache-specific arguments to training_args
     training_args.gc_chunk_size = getattr(training_args, 'gc_chunk_size', 2)
+    training_args.grad_cache = getattr(training_args, 'grad_cache', False)
 
     trainer = RerankerTrainer(
         model=model,
@@ -81,11 +95,11 @@ def main():
     )
     train_dataset.trainer = trainer
 
-    trainer.train() # TODO: resume training
+    trainer.train()
     trainer.save_model()
     if trainer.is_world_process_zero():
         tokenizer.save_pretrained(training_args.output_dir)
 
 
 if __name__ == "__main__":
-    main()
+    main()
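
For context: the new setup_ddp() relies on the RANK and WORLD_SIZE environment variables that launchers such as torchrun export for each worker; when they are absent the script falls back to single-process mode (local_rank = -1). Below is a minimal, self-contained sketch of that bootstrap pattern with a toy model in place of RerankerModel. It is an illustration only; using LOCAL_RANK for device placement and the explicit process-group teardown are assumptions, not code from this commit.

# Hypothetical sketch, not part of the commit. Assumes a launch such as
#   torchrun --nproc_per_node=2 sketch.py
# which exports RANK, WORLD_SIZE and LOCAL_RANK, with one CUDA device per process.
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def main():
    distributed = 'RANK' in os.environ and 'WORLD_SIZE' in os.environ
    if distributed:
        dist.init_process_group(backend="nccl")
        # LOCAL_RANK indexes the GPU on this node; RANK is the global rank.
        local_rank = int(os.environ['LOCAL_RANK'])
        torch.cuda.set_device(local_rank)
    else:
        local_rank = -1

    model = torch.nn.Linear(16, 1)
    if distributed:
        model = model.to(local_rank)
        # DDP synchronizes gradients across processes during backward().
        model = DDP(model, device_ids=[local_rank], output_device=local_rank)

    # ... build dataset, collator and trainer here ...

    if distributed:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()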

src/tevatron/reranker/modeling.py

Lines changed: 0 additions & 3 deletions
@@ -38,9 +38,6 @@ def forward(self, input_ids: Tensor = None, attention_mask: Tensor = None, **kwa
             scores=outputs.logits
         )
 
-    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs = None):
-        return False
-
     @classmethod
     def build(
             cls,
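
Note on the removal above: gradient_checkpointing_enable() is the hook Hugging Face's Trainer calls when gradient checkpointing is requested, and the deleted stub simply returned False. For reference only, a wrapper that forwards the call to an underlying transformers backbone might look like the sketch below; self.hf_model and the delegation itself are assumptions for illustration, not code from this commit.

    # Hypothetical delegation sketch (not part of this commit). Assumes the
    # reranker stores its transformers backbone as self.hf_model.
    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        # Let the backbone enable activation checkpointing on its own layers.
        self.hf_model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs=gradient_checkpointing_kwargs
        )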

src/tevatron/reranker/trainer.py

Lines changed: 22 additions & 16 deletions
@@ -7,7 +7,6 @@
 from transformers.trainer_utils import PredictionOutput
 
 from grad_cache import GradCache
-
 from grad_cache.functional import cached, cat_input_tensor
 from torch.cuda.amp import autocast
 
@@ -39,22 +38,26 @@ def split_inputs(model_input, chunk_size):
 class RerankerTrainer(Trainer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        logger.info("Initializing RerankerTrainer with GradCache")
+        logger.info("Initializing RerankerTrainer")
         self.args: TrainingArguments
 
-        # Add these lines to include the necessary parameters
-        self.gc_chunk_size = getattr(self.args, 'gc_chunk_size', 4) # default to 4 if not provided
+        self.gc_chunk_size = getattr(self.args, 'gc_chunk_size', 4)
+        self.use_grad_cache = getattr(self.args, 'grad_cache', False)
+
+        if self.use_grad_cache:
+            # If the model is wrapped in DDP, we need to use the .module attribute
+            model_for_gc = self.model.module if hasattr(self.model, 'module') else self.model
 
-        self.gc = GradCache(
-            models=[self.model],
-            chunk_sizes=self.gc_chunk_size,
-            loss_fn=contrastive_loss,
-            split_input_fn=split_inputs,
-            get_rep_fn=lambda x: x.scores,
-            fp16=self.args.fp16,
-            scaler=self.scaler if self.args.fp16 else None
-        )
-        logger.info(f"GradCache initialized with chunk size: {self.gc_chunk_size}")
+            self.gc = GradCache(
+                models=[model_for_gc],
+                chunk_sizes=self.gc_chunk_size,
+                loss_fn=contrastive_loss,
+                split_input_fn=split_inputs,
+                get_rep_fn=lambda x: x.scores,
+                fp16=self.args.fp16,
+                # scaler: GradScaler = None,
+            )
+            logger.info(f"GradCache initialized with chunk size: {self.gc_chunk_size}")
 
     def compute_loss(self, model, inputs, return_outputs=False):
         logger.debug(f"Computing loss with inputs: {inputs.keys()}")
@@ -68,8 +71,11 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
         logger.debug("Entering training step")
         model.train()
         inputs = self._prepare_inputs(inputs)
-        _distributed = self.args.local_rank > -1
-        loss = self.gc(inputs, no_sync_except_last=_distributed)
+        if self.use_grad_cache:
+            _distributed = self.args.local_rank != -1
+            loss = self.gc(inputs, no_sync_except_last=_distributed)
+        else:
+            loss = self.compute_loss(model, inputs)
         logger.debug(f"Training step loss: {loss.item()}")
         return loss
 
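For context on the toggle above: when grad_cache is enabled, training_step hands the chunked inputs to the GradCache object; otherwise it falls back to the plain compute_loss path. The sketch below illustrates the gradient-caching idea itself: score the chunks without autograd, cache the loss gradient with respect to the scores, then replay each chunk with autograd. It is a conceptual, hypothetical sketch, not the grad_cache library's implementation, and it ignores details such as RNG/dropout state replay, mixed precision, and DDP's no_sync handling.

# Conceptual sketch of gradient caching for a scoring model (hypothetical code,
# not taken from this commit or from the grad_cache package).
import torch


def grad_cached_step(model, chunks, loss_fn):
    """chunks: list of input dicts; model(**chunk) returns a 1-D tensor of scores."""
    # Pass 1: score every chunk without building graphs, keeping activation memory flat.
    with torch.no_grad():
        cached_scores = [model(**chunk) for chunk in chunks]

    # Full-batch loss on detached scores; cache d(loss)/d(scores).
    scores = torch.cat(cached_scores).detach().requires_grad_(True)
    loss = loss_fn(scores)
    loss.backward()
    score_grads = scores.grad.split([s.numel() for s in cached_scores])

    # Pass 2: re-run each chunk with autograd and push the cached gradients through
    # a surrogate dot product. (Under DDP, all but the last backward would normally
    # run inside model.no_sync(), which is what no_sync_except_last arranges.)
    for chunk, grad in zip(chunks, score_grads):
        chunk_scores = model(**chunk)
        surrogate = torch.dot(chunk_scores.flatten(), grad.flatten())
        surrogate.backward()

    return loss.detach()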