
Commit b51e8c7

Fixes for DDP setting
- Only download nltk data on rank zero.
- For some reason, in DDP I got "AttributeError: Can't pickle local object 'TorchGraph.create_forward_hook.<locals>.after_forward_hook'" when using W&B to track gradients. So, I removed this completely for now.
1 parent 165e4a3 commit b51e8c7
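
The rank-zero guard added in this commit keeps every DDP worker from writing to the same nltk_data directory at once. Below is a minimal standalone sketch of the pattern, assuming the LOCAL_RANK environment variable that torchrun/Lightning set per process; the barrier is an extra precaution not present in the commit, so non-zero ranks don't try to load WordNet before the download finishes.

import os

import nltk
import torch.distributed as dist

# Rank-zero-only download, as in the commit: one process per node writes
# to nltk_data instead of all DDP workers racing on the same files.
local_rank = int(os.environ.get("LOCAL_RANK", 0))
if local_rank == 0:
    nltk.download("wordnet")

# Assumption, not in the commit: block the other ranks until rank zero
# has finished downloading, so none of them loads WordNet too early.
if dist.is_available() and dist.is_initialized():
    dist.barrier()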

File tree: 1 file changed (+2, -5 lines)

train.py: 2 additions & 5 deletions

@@ -19,8 +19,6 @@
 from src.model import CMCModule
 from src.utils import WandbOrganizer
 
-nltk.download("wordnet")
-
 
 def get_world_size(accelerator: str, devices: Any) -> int:
     if accelerator == "cpu":
@@ -91,6 +89,8 @@ def main(cfg: TrainConfig) -> None:
     )
 
     if local_rank == 0:
+        nltk.download("wordnet")
+
         if cfg.logger.use_wandb and cfg.model.configuration == "race":
             # download model checkpoint
             artifact = wandb.use_artifact(
@@ -155,9 +155,6 @@ def main(cfg: TrainConfig) -> None:
     )
     cfg.optimizer.learning_rate = model.learning_rate
 
-    if cfg.logger.use_wandb:
-        trainer_logger.watch(model, log="gradients", log_freq=250)
-
     # callbacks
     lr_logger = LearningRateMonitor(logging_interval="step")
     checkpoint_callback = ModelCheckpoint(
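
The removed trainer_logger.watch call is what attached the unpicklable TorchGraph forward hooks that DDP then failed to pickle. If gradient tracking is wanted back later, one possible sketch is below. It assumes (not confirmed by this commit) that the installed wandb/Lightning versions expose a log_graph flag on watch(), which skips attaching the graph hooks; this is a guess at a workaround, not the author's fix.

# Hypothetical re-enable of gradient logging, NOT part of this commit.
# Assumes trainer_logger is a pytorch_lightning WandbLogger whose watch()
# accepts log_graph; with log_graph=False the TorchGraph forward hooks
# that broke pickling under DDP are never attached.
if cfg.logger.use_wandb and local_rank == 0:
    trainer_logger.watch(model, log="gradients", log_freq=250, log_graph=False)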
