diff --git a/distributed/minGPT-ddp/mingpt/main.py b/distributed/minGPT-ddp/mingpt/main.py index 861a69e1e1..dae03853ae 100644 --- a/distributed/minGPT-ddp/mingpt/main.py +++ b/distributed/minGPT-ddp/mingpt/main.py @@ -1,10 +1,12 @@ +import torch +from torch.utils.data import random_split +from torch.distributed import init_process_group, destroy_process_group from model import GPT, GPTConfig, OptimizerConfig, create_optimizer from trainer import Trainer, TrainerConfig from char_dataset import CharDataset, DataConfig -from torch.utils.data import random_split from omegaconf import DictConfig import hydra -from torch.distributed import init_process_group, destroy_process_group + def ddp_setup(): init_process_group(backend="nccl")