Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update DDP for torch.distributed.run with gloo backend #3680

Merged
merged 35 commits into from
Jun 19, 2021
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
007902e
Update DDP for `torch.distributed.run`
glenn-jocher Jun 18, 2021
9bcb4ad
Add LOCAL_RANK
glenn-jocher Jun 18, 2021
b32bae0
remove opt.local_rank
glenn-jocher Jun 18, 2021
b467501
backend="gloo|nccl"
glenn-jocher Jun 18, 2021
c886538
print
glenn-jocher Jun 18, 2021
5d847dc
print
glenn-jocher Jun 18, 2021
26d0ecf
debug
glenn-jocher Jun 18, 2021
832ba4c
debug
glenn-jocher Jun 18, 2021
9a1bb01
os.getenv
glenn-jocher Jun 18, 2021
0e912df
gloo
glenn-jocher Jun 18, 2021
5f5e428
gloo
glenn-jocher Jun 18, 2021
e8493c6
gloo
glenn-jocher Jun 18, 2021
fb342fc
cleanup
glenn-jocher Jun 18, 2021
382ce4f
fix getenv
glenn-jocher Jun 18, 2021
b09b415
cleanup
glenn-jocher Jun 18, 2021
9c4ac05
cleanup destroy
glenn-jocher Jun 18, 2021
8ae9ea1
try nccl
glenn-jocher Jun 18, 2021
a18f933
merge master
glenn-jocher Jun 19, 2021
2435775
return opt
glenn-jocher Jun 19, 2021
56a4ab4
add --local_rank
glenn-jocher Jun 19, 2021
c4d839b
add timeout
glenn-jocher Jun 19, 2021
0584e7e
add init_method
glenn-jocher Jun 19, 2021
d917341
gloo
glenn-jocher Jun 19, 2021
6a1cc64
move destroy
glenn-jocher Jun 19, 2021
3581c76
move destroy
glenn-jocher Jun 19, 2021
5f5d122
move print(opt) under if RANK
glenn-jocher Jun 19, 2021
5451fc2
destroy only RANK 0
glenn-jocher Jun 19, 2021
9aa229e
move destroy inside train()
glenn-jocher Jun 19, 2021
94363ce
restore destroy outside train()
glenn-jocher Jun 19, 2021
9647379
update print(opt)
glenn-jocher Jun 19, 2021
cb8395d
merge master
glenn-jocher Jun 19, 2021
96686fd
cleanup
glenn-jocher Jun 19, 2021
446c610
nccl
glenn-jocher Jun 19, 2021
49bb0b7
gloo with 60 second timeout
glenn-jocher Jun 19, 2021
b5decde
update namespace printing
glenn-jocher Jun 19, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 41 additions & 44 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,17 @@
from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume

logger = logging.getLogger(__name__)
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv('RANK', -1))
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))


def train(hyp, # path/to/hyp.yaml or hyp dictionary
opt,
device,
):
save_dir, epochs, batch_size, total_batch_size, weights, rank, single_cls = \
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, \
opt.single_cls
save_dir, epochs, batch_size, total_batch_size, weights, single_cls = \
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.single_cls

# Directories
wdir = save_dir / 'weights'
Expand All @@ -69,13 +71,13 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
# Configure
plots = not opt.evolve # create plots
cuda = device.type != 'cpu'
init_seeds(2 + rank)
init_seeds(2 + RANK)
with open(opt.data) as f:
data_dict = yaml.safe_load(f) # data dict

# Loggers
loggers = {'wandb': None, 'tb': None} # loggers dict
if rank in [-1, 0]:
if RANK in [-1, 0]:
# TensorBoard
if not opt.evolve:
prefix = colorstr('tensorboard: ')
Expand All @@ -99,7 +101,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
# Model
pretrained = weights.endswith('.pt')
if pretrained:
with torch_distributed_zero_first(rank):
with torch_distributed_zero_first(RANK):
weights = attempt_download(weights) # download if not found locally
ckpt = torch.load(weights, map_location=device) # load checkpoint
model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
Expand All @@ -110,7 +112,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
else:
model = Model(opt.cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
with torch_distributed_zero_first(rank):
with torch_distributed_zero_first(RANK):
check_dataset(data_dict) # check
train_path = data_dict['train']
test_path = data_dict['val']
Expand Down Expand Up @@ -158,7 +160,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
# plot_lr_scheduler(optimizer, scheduler, epochs)

# EMA
ema = ModelEMA(model) if rank in [-1, 0] else None
ema = ModelEMA(model) if RANK in [-1, 0] else None

# Resume
start_epoch, best_fitness = 0, 0.0
Expand Down Expand Up @@ -194,28 +196,28 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples

# DP mode
if cuda and rank == -1 and torch.cuda.device_count() > 1:
if cuda and RANK == -1 and torch.cuda.device_count() > 1:
model = torch.nn.DataParallel(model)

# SyncBatchNorm
if opt.sync_bn and cuda and rank != -1:
if opt.sync_bn and cuda and RANK != -1:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
logger.info('Using SyncBatchNorm()')

# Trainloader
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, single_cls,
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
world_size=opt.world_size, workers=opt.workers,
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=RANK,
workers=opt.workers,
image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
nb = len(dataloader) # number of batches
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)

# Process 0
if rank in [-1, 0]:
if RANK in [-1, 0]:
testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, single_cls,
hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
world_size=opt.world_size, workers=opt.workers,
workers=opt.workers,
pad=0.5, prefix=colorstr('val: '))[0]

if not opt.resume:
Expand All @@ -234,8 +236,8 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
model.half().float() # pre-reduce anchor precision

# DDP mode
if cuda and rank != -1:
model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank,
if cuda and RANK != -1:
model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK,
# nn.MultiheadAttention incompatibility with DDP https://github.com/pytorch/pytorch/issues/26698
find_unused_parameters=any(isinstance(layer, nn.MultiheadAttention) for layer in model.modules()))

Expand Down Expand Up @@ -269,27 +271,27 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
# Update image weights (optional)
if opt.image_weights:
# Generate indices
if rank in [-1, 0]:
if RANK in [-1, 0]:
cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights
iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
# Broadcast if DDP
if rank != -1:
indices = (torch.tensor(dataset.indices) if rank == 0 else torch.zeros(dataset.n)).int()
if RANK != -1:
indices = (torch.tensor(dataset.indices) if RANK == 0 else torch.zeros(dataset.n)).int()
dist.broadcast(indices, 0)
if rank != 0:
if RANK != 0:
dataset.indices = indices.cpu().numpy()

# Update mosaic border
# b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
# dataset.mosaic_border = [b - imgsz, -b] # height, width borders

mloss = torch.zeros(4, device=device) # mean losses
if rank != -1:
if RANK != -1:
dataloader.sampler.set_epoch(epoch)
pbar = enumerate(dataloader)
logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'labels', 'img_size'))
if rank in [-1, 0]:
if RANK in [-1, 0]:
pbar = tqdm(pbar, total=nb) # progress bar
optimizer.zero_grad()
for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
Expand Down Expand Up @@ -319,8 +321,8 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
with amp.autocast(enabled=cuda):
pred = model(imgs) # forward
loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
if rank != -1:
loss *= opt.world_size # gradient averaged between devices in DDP mode
if RANK != -1:
loss *= WORLD_SIZE # gradient averaged between devices in DDP mode
if opt.quad:
loss *= 4.

Expand All @@ -336,7 +338,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
ema.update(model)

# Print
if rank in [-1, 0]:
if RANK in [-1, 0]:
mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
s = ('%10s' * 2 + '%10.4g' * 6) % (
Expand All @@ -362,7 +364,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
scheduler.step()

# DDP process 0 or single-GPU
if rank in [-1, 0]:
if RANK in [-1, 0]:
# mAP
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
final_epoch = epoch + 1 == epochs
Expand Down Expand Up @@ -424,7 +426,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary

# end epoch ----------------------------------------------------------------------------------------------------
# end training -----------------------------------------------------------------------------------------------------
if rank in [-1, 0]:
if RANK in [-1, 0]:
logger.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
if plots:
plot_results(save_dir=save_dir) # save as results.png
Expand Down Expand Up @@ -486,7 +488,6 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
parser.add_argument('--project', default='runs/train', help='save to project/name')
parser.add_argument('--entity', default=None, help='W&B entity')
Expand All @@ -501,11 +502,8 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
opt = parser.parse_args()

# Set DDP variables
opt.world_size = int(getattr(os.environ, 'WORLD_SIZE', 1))
opt.global_rank = int(getattr(os.environ, 'RANK', -1))
set_logging(opt.global_rank)
if opt.global_rank in [-1, 0]:
set_logging(RANK)
if RANK in [-1, 0]:
check_git_status()
check_requirements(exclude=['thop'])

Expand All @@ -514,11 +512,9 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
if opt.resume and not wandb_run: # resume an interrupted run
ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
apriori = opt.global_rank, opt.local_rank
with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
opt = argparse.Namespace(**yaml.safe_load(f)) # replace
opt.cfg, opt.weights, opt.resume, opt.batch_size, opt.global_rank, opt.local_rank = \
'', ckpt, True, opt.total_batch_size, *apriori # reinstate
opt.cfg, opt.weights, opt.resume, opt.batch_size = '', ckpt, True, opt.total_batch_size # reinstate
logger.info('Resuming training from %s' % ckpt)
else:
# opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
Expand All @@ -531,14 +527,15 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
# DDP mode
opt.total_batch_size = opt.batch_size
device = select_device(opt.device, batch_size=opt.batch_size)
if opt.local_rank != -1:
assert torch.cuda.device_count() > opt.local_rank
torch.cuda.set_device(opt.local_rank)
device = torch.device('cuda', opt.local_rank)
dist.init_process_group(backend='nccl', init_method='env://') # distributed backend
assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
print({'RANK': RANK, 'LOCAL_RANK': LOCAL_RANK, 'WORLD_SIZE': WORLD_SIZE})
if LOCAL_RANK != -1:
assert torch.cuda.device_count() > LOCAL_RANK, 'too few GPUS for DDP command'
torch.cuda.set_device(LOCAL_RANK)
device = torch.device('cuda', LOCAL_RANK)
dist.init_process_group(backend="gloo") # distributed backend
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nccl should be the faster backend for ddp. I recall that Windows only support gloo however.

assert opt.batch_size % WORLD_SIZE == 0, '--batch-size must be multiple of CUDA device count'
assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
opt.batch_size = opt.total_batch_size // opt.world_size
opt.batch_size = opt.total_batch_size // WORLD_SIZE

# Train
logger.info(opt)
Expand Down Expand Up @@ -579,7 +576,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary

with open(opt.hyp) as f:
hyp = yaml.safe_load(f) # load hyps dict
assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
assert LOCAL_RANK == -1, 'DDP mode not implemented for --evolve'
opt.notest, opt.nosave = True, True # only test/save final epoch
# ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml' # save best result here
Expand Down
4 changes: 2 additions & 2 deletions utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def exif_size(img):


def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
rect=False, rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix=''):
# Make sure only the first process in DDP process the dataset first, and the following others can use the cache
with torch_distributed_zero_first(rank):
dataset = LoadImagesAndLabels(path, imgsz, batch_size,
Expand All @@ -79,7 +79,7 @@ def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=Non
prefix=prefix)

batch_size = min(batch_size, len(dataset))
nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, workers]) # number of workers
sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
# Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
Expand Down
6 changes: 4 additions & 2 deletions utils/wandb_logging/wandb_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Utilities and tools for tracking runs with Weights & Biases."""
import logging
import os
import sys
from contextlib import contextmanager
from pathlib import Path
Expand All @@ -18,6 +19,7 @@
except ImportError:
wandb = None

RANK = int(os.getenv('RANK', -1))
WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'


Expand All @@ -42,10 +44,10 @@ def get_run_info(run_path):


def check_wandb_resume(opt):
process_wandb_config_ddp_mode(opt) if opt.global_rank not in [-1, 0] else None
process_wandb_config_ddp_mode(opt) if RANK not in [-1, 0] else None
if isinstance(opt.resume, str):
if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
if opt.global_rank not in [-1, 0]: # For resuming DDP runs
if RANK not in [-1, 0]: # For resuming DDP runs
entity, project, run_id, model_artifact_name = get_run_info(opt.resume)
api = wandb.Api()
artifact = api.artifact(entity + '/' + project + '/' + model_artifact_name + ':latest')
Expand Down