
Commit 14c8129

fix: ddp
1 parent 43a642d commit 14c8129

File tree

3 files changed: +53 -33 lines changed

Lines changed: 27 additions & 17 deletions

@@ -1,20 +1,31 @@
 import logging
 import os
 import sys
+import torch
 from transformers import AutoTokenizer
 from transformers import (
     HfArgumentParser,
     set_seed,
 )
+from torch.nn.parallel import DistributedDataParallel as DDP
+import torch.distributed as dist
 from tevatron.reranker.arguments import ModelArguments, DataArguments, TevatronTrainingArguments
 from tevatron.reranker.modeling import RerankerModel
 from tevatron.reranker.dataset import RerankerTrainDataset
 from tevatron.reranker.collator import RerankerTrainCollator
-from tevatron.reranker.trainer import RerankerTrainer  # Make sure this is your updated RerankerTrainer
+from tevatron.reranker.trainer import RerankerTrainer

 logger = logging.getLogger(__name__)


+def setup_ddp():
+    if not dist.is_initialized():
+        dist.init_process_group(backend="nccl")
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    torch.cuda.set_device(local_rank)
+    return local_rank
+
+
 def main():
     parser = HfArgumentParser((ModelArguments, DataArguments, TevatronTrainingArguments))

@@ -23,29 +34,23 @@ def main():
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()

-    if (
-        os.path.exists(training_args.output_dir)
-        and os.listdir(training_args.output_dir)
-        and training_args.do_train
-        and not training_args.overwrite_output_dir
-    ):
-        raise ValueError(
-            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
-        )
+    local_rank = -1
+    if training_args.local_rank != -1:
+        local_rank = setup_ddp()

     # Setup logging
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
         datefmt="%m/%d/%Y %H:%M:%S",
-        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
+        level=logging.INFO if local_rank in [-1, 0] else logging.WARN,
     )
     logger.warning(
         "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
-        training_args.local_rank,
+        local_rank,
         training_args.device,
         training_args.n_gpu,
-        bool(training_args.local_rank != -1),
-        training_args.fp16,
+        bool(local_rank != -1),
+        training_args.fp16 or training_args.bf16,
     )
     logger.info("Training/evaluation parameters %s", training_args)
     logger.info("MODEL parameters %s", model_args)
@@ -67,11 +72,16 @@ def main():
         cache_dir=model_args.cache_dir,
     )

+    # Move model to GPU
+    if local_rank != -1:
+        model = model.to(local_rank)
+        model = DDP(model, device_ids=[local_rank], output_device=local_rank)
+
     train_dataset = RerankerTrainDataset(data_args)
     train_collator = RerankerTrainCollator(data_args, tokenizer)

-    # Add GradCache-specific arguments to training_args
     training_args.gc_chunk_size = getattr(training_args, 'gc_chunk_size', 2)
+    training_args.grad_cache = getattr(training_args, 'grad_cache', False)

     trainer = RerankerTrainer(
         model=model,
@@ -81,11 +91,11 @@ def main():
     )
     train_dataset.trainer = trainer

-    trainer.train()  # TODO: resume training
+    trainer.train()
     trainer.save_model()
     if trainer.is_world_process_zero():
         tokenizer.save_pretrained(training_args.output_dir)


 if __name__ == "__main__":
-    main()
+    main()
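
Note on the driver change: setup_ddp() reads LOCAL_RANK from the environment and binds each process to one GPU, so the script is expected to be launched with one process per GPU (e.g. via torchrun, which sets LOCAL_RANK/RANK/WORLD_SIZE). The following is a minimal standalone sketch of the same pattern, not the actual Tevatron driver; it assumes NCCL-capable GPUs and uses a toy nn.Linear as a stand-in for RerankerModel.

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_ddp():
    # torchrun sets LOCAL_RANK, RANK and WORLD_SIZE for every worker process
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)
    return local_rank

if __name__ == "__main__":
    local_rank = setup_ddp()
    model = torch.nn.Linear(8, 1).to(local_rank)  # stand-in for RerankerModel
    model = DDP(model, device_ids=[local_rank], output_device=local_rank)
    loss = model(torch.randn(4, 8, device=local_rank)).sum()
    loss.backward()  # DDP all-reduces gradients across ranks here
    dist.destroy_process_group()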

src/tevatron/reranker/modeling.py

Lines changed: 4 additions & 0 deletions

@@ -30,6 +30,10 @@ def __init__(self, hf_model: PreTrainedModel):
         self.hf_model = hf_model
         logger.info(f"RerankerModel initialized with config: {self.config}")

+    def gradient_checkpointing_enable(self, **kwargs):
+        return False
+        # self.hf_model.base_model.model.gradient_checkpointing_enable(**kwargs)
+
     def forward(self, input_ids: Tensor = None, attention_mask: Tensor = None, **kwargs):
         logger.debug(f"Forward pass with input shape: {input_ids.shape if input_ids is not None else 'None'}")
         outputs = self.hf_model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
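
Note on modeling.py: transformers' Trainer calls model.gradient_checkpointing_enable() when gradient checkpointing is enabled, and the stub above satisfies that call while effectively turning checkpointing into a no-op (the delegation line is left commented out in the commit). If delegation were wanted instead, one hedged sketch (not part of this commit) would be to forward the call to the wrapped PreTrainedModel:

# Hypothetical alternative to the stub above (sketch only): delegate to the
# underlying PreTrainedModel instead of returning False.
def gradient_checkpointing_enable(self, **kwargs):
    if hasattr(self.hf_model, "gradient_checkpointing_enable"):
        self.hf_model.gradient_checkpointing_enable(**kwargs)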

src/tevatron/reranker/trainer.py

Lines changed: 22 additions & 16 deletions

@@ -7,7 +7,6 @@
 from transformers.trainer_utils import PredictionOutput

 from grad_cache import GradCache
-
 from grad_cache.functional import cached, cat_input_tensor
 from torch.cuda.amp import autocast

@@ -39,22 +38,26 @@ def split_inputs(model_input, chunk_size):
 class RerankerTrainer(Trainer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        logger.info("Initializing RerankerTrainer with GradCache")
+        logger.info("Initializing RerankerTrainer")
         self.args: TrainingArguments

-        # Add these lines to include the necessary parameters
-        self.gc_chunk_size = getattr(self.args, 'gc_chunk_size', 4)  # default to 4 if not provided
+        self.gc_chunk_size = getattr(self.args, 'gc_chunk_size', 4)
+        self.use_grad_cache = getattr(self.args, 'grad_cache', False)
+
+        if self.use_grad_cache:
+            # If the model is wrapped in DDP, we need to use the .module attribute
+            model_for_gc = self.model.module if hasattr(self.model, 'module') else self.model

-        self.gc = GradCache(
-            models=[self.model],
-            chunk_sizes=self.gc_chunk_size,
-            loss_fn=contrastive_loss,
-            split_input_fn=split_inputs,
-            get_rep_fn=lambda x: x.scores,
-            fp16=self.args.fp16,
-            scaler=self.scaler if self.args.fp16 else None
-        )
-        logger.info(f"GradCache initialized with chunk size: {self.gc_chunk_size}")
+            self.gc = GradCache(
+                models=[model_for_gc],
+                chunk_sizes=self.gc_chunk_size,
+                loss_fn=contrastive_loss,
+                split_input_fn=split_inputs,
+                get_rep_fn=lambda x: x.scores,
+                fp16=self.args.fp16,
+                # scaler: GradScaler = None,
+            )
+            logger.info(f"GradCache initialized with chunk size: {self.gc_chunk_size}")

     def compute_loss(self, model, inputs, return_outputs=False):
         logger.debug(f"Computing loss with inputs: {inputs.keys()}")
@@ -68,8 +71,11 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
         logger.debug("Entering training step")
         model.train()
         inputs = self._prepare_inputs(inputs)
-        _distributed = self.args.local_rank > -1
-        loss = self.gc(inputs, no_sync_except_last=_distributed)
+        if self.use_grad_cache:
+            _distributed = self.args.local_rank != -1
+            loss = self.gc(inputs, no_sync_except_last=_distributed)
+        else:
+            loss = self.compute_loss(model, inputs)
         logger.debug(f"Training step loss: {loss.item()}")
         return loss

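Note on the trainer change: no_sync_except_last only matters when the model is wrapped in DDP; it tells GradCache to skip the gradient all-reduce for every chunk except the final one, so cross-rank communication happens once per step instead of once per chunk. The mechanism this relies on is DDP's no_sync() context manager; the sketch below is an illustration of that mechanism, not GradCache's internals, and backward_in_chunks/loss_fn are hypothetical names.

import contextlib

def backward_in_chunks(ddp_model, chunks, loss_fn):
    # Accumulate gradients chunk by chunk; DDP's gradient all-reduce is
    # suppressed by no_sync() on every chunk except the last one.
    for i, chunk in enumerate(chunks):
        last = i == len(chunks) - 1
        ctx = contextlib.nullcontext() if last else ddp_model.no_sync()
        with ctx:
            loss_fn(ddp_model(chunk)).backward()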
