70 changes: 64 additions & 6 deletions benchmarks/communication/all_reduce.py
@@ -48,16 +48,51 @@ def run_all_reduce(local_rank, args):
elif args.dist == 'deepspeed':
import deepspeed.comm as dist

-    # Prepare benchmark header
-    print_header(args, 'all_reduce')
+    # Prepare benchmark header unless validating
+    if not args.validate:
+        print_header(args, 'all_reduce')
+    else:
+        print_rank_0("Running Allreduce validation")

world_size = dist.get_world_size()
global_rank = dist.get_rank()

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

-    if args.scan:
+    if args.single:
+        sync_all()
+        M = 2 ** (args.maxsize - 1)
+        try:
+            mat = torch.ones(world_size, M,
+                             dtype=getattr(torch, args.dtype)).cuda(local_rank)
+            sync_all()
+            input = ((mat.mul_(float(global_rank))).view(-1))
+            del mat
+            torch.cuda.empty_cache()
+        except RuntimeError as e:
+            if 'out of memory' in str(e):
+                if dist.get_rank() == 0:
+                    print('WARNING: Ran out of GPU memory. Exiting comm op.')
+                sync_all()
+                return
+            else:
+                raise e
+        sync_all()
+        if args.validate:
+            passes = 0
+            for _ in range(args.trials):
+                if validate_allreduce(input.clone(), args):
+                    passes += 1
+            size = input.element_size() * input.nelement()
+            if not args.raw:
+                size = convert_size(size)
+            desc = f"validation ({passes}/{args.trials})"
+            print_rank_0(f"{size:<20} {desc:25s} {'PASS' if passes == args.trials else 'FAIL'}")
+        else:
+            timed_all_reduce(input, start_event, end_event, args)
+
+    elif args.scan:
M_LIST = []
for x in (2**p for p in range(1, args.maxsize)):
M_LIST.append(x)
@@ -76,13 +111,24 @@ def run_all_reduce(local_rank, args):
except RuntimeError as e:
if 'out of memory' in str(e):
if dist.get_rank() == 0:
-                        print('WARNING: Ran out of GPU memory. Exiting comm op.')
+                        print('WARNING: Ran out of GPU memory.')
sync_all()
break
else:
raise e
sync_all()
-            timed_all_reduce(input, start_event, end_event, args)
+            if args.validate:
+                passes = 0
+                for _ in range(args.trials):
+                    if validate_allreduce(input.clone(), args):
+                        passes += 1
+                size = input.element_size() * input.nelement()
+                if not args.raw:
+                    size = convert_size(size)
+                desc = f"validation ({passes}/{args.trials})"
+                print_rank_0(f"{size:<20} {desc:25s} {'PASS' if passes == args.trials else 'FAIL'}")
+            else:
+                timed_all_reduce(input, start_event, end_event, args)
else:
# Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
# Don't need output tensor, so we double mem_factor
@@ -104,7 +150,19 @@ def run_all_reduce(local_rank, args):
else:
raise e
sync_all()
-        timed_all_reduce(input, start_event, end_event, args)
+        if args.validate:
+            passes = 0
+            for _ in range(args.trials):
+                if validate_allreduce(input.clone(), args):
+                    passes += 1
+            size = input.element_size() * input.nelement()
+            if not args.raw:
+                size = convert_size(size)
+            desc = f"validation ({passes}/{args.trials})"
+            print_rank_0(f"{size:<20} {desc:25s} {'PASS' if passes == args.trials else 'FAIL'}")
+        else:
+            timed_all_reduce(input, start_event, end_event, args)



if __name__ == "__main__":
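Note that the validate-or-time dispatch added above is repeated verbatim in the --single, --scan, and default branches of run_all_reduce. A small helper along the following lines would capture the shared pattern. This is a refactoring sketch only, not part of the PR; the name validate_or_time is hypothetical, and the sketch assumes it lives inside all_reduce.py where validate_allreduce, convert_size, print_rank_0, and timed_all_reduce are already in scope.

    # Hypothetical helper (not in this PR): factor out the block that the
    # --single, --scan, and default branches all repeat. Assumes the same
    # in-scope helpers used by run_all_reduce().
    def validate_or_time(input, start_event, end_event, args):
        if args.validate:
            # Run --trials validation rounds on a clone so every round starts
            # from the original rank-valued tensor.
            passes = 0
            for _ in range(args.trials):
                if validate_allreduce(input.clone(), args):
                    passes += 1
            size = input.element_size() * input.nelement()
            if not args.raw:
                size = convert_size(size)
            desc = f"validation ({passes}/{args.trials})"
            print_rank_0(f"{size:<20} {desc:25s} {'PASS' if passes == args.trials else 'FAIL'}")
        else:
            # Otherwise keep the existing latency/bandwidth measurement.
            timed_all_reduce(input, start_event, end_event, args)

With the new flags in place, an invocation along the lines of "deepspeed all_reduce.py --single --maxsize 24 --validate" (assuming the launcher these benchmarks already use) would print a single size / PASS-FAIL line instead of the timing table.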
17 changes: 16 additions & 1 deletion benchmarks/communication/utils.py
@@ -207,6 +207,9 @@ def benchmark_parser():
parser.add_argument("--trials", type=int, default=DEFAULT_TRIALS, help='Number of timed iterations')
parser.add_argument("--warmups", type=int, default=DEFAULT_WARMUPS, help='Number of warmup (non-timed) iterations')
parser.add_argument("--maxsize", type=int, default=24, help='Max message size as a power of 2')
+    group = parser.add_mutually_exclusive_group()
+    group.add_argument("--scan", action="store_true", help='Enables scanning all message sizes')
+    group.add_argument("--single", action="store_true", help='Run only 2^maxsize; mutually exclusive with --scan')
parser.add_argument("--async-op", action="store_true", help='Enables non-blocking communication')
parser.add_argument("--bw-unit", type=str, default=DEFAULT_UNIT, choices=['Gbps', 'GBps'])
parser.add_argument("--backend",
@@ -219,7 +222,6 @@ def benchmark_parser():
default=DEFAULT_DIST,
choices=['deepspeed', 'torch'],
help='Distributed DL framework to use')
parser.add_argument("--scan", action="store_true", help='Enables scanning all message sizes')
parser.add_argument("--raw", action="store_true", help='Print the message size and latency without units')
parser.add_argument("--all-reduce", action="store_true", help='Run all_reduce')
parser.add_argument("--reduce-scatter", action="store_true", help='Run reduce_scatter')
@@ -235,4 +237,17 @@ def benchmark_parser():
parser.add_argument("--debug", action="store_true", help='Enables all_to_all debug prints')
parser.add_argument('--all-to-all-v', action='store_true',
help='Use alltoallv instead of alltoall. This will run the all_to_all benchmark with vector variant. Use with --all-to-all or alone to run just this benchmark.')
parser.add_argument("--validate", action="store_true", help='Validate collective results')
return parser

+def validate_allreduce(input, args):
+    if args.dist == 'torch':
+        import torch.distributed as dist
+    elif args.dist == 'deepspeed':
+        import deepspeed.comm as dist
+
+    dist.all_reduce(input, async_op=False)
+    sync_all()
+    n = dist.get_world_size()
+    expected = float(n * (n - 1) / 2)
+    return torch.allclose(input, torch.full_like(input, expected))
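Since run_all_reduce fills each rank's input with its own global rank, a sum all-reduce over n ranks should leave every element equal to 0 + 1 + ... + (n-1) = n(n-1)/2, which is exactly the constant validate_allreduce compares against. A minimal single-process sketch of that arithmetic follows; it is illustrative only, not part of the patch, and simulates the reduction without torch.distributed.

    # Single-process illustration of the check performed by validate_allreduce.
    import torch

    world_size = 4          # pretend we have 4 ranks
    numel = 8               # tiny stand-in for the 2**(maxsize-1) elements per rank

    # Each rank r contributes a tensor filled with float(r), exactly like
    # mat.mul_(float(global_rank)) in run_all_reduce().
    rank_tensors = [torch.full((numel,), float(r)) for r in range(world_size)]

    # A sum all-reduce leaves every rank holding the elementwise sum.
    reduced = torch.stack(rank_tensors).sum(dim=0)

    # 0 + 1 + ... + (n-1) = n(n-1)/2, the value validate_allreduce expects.
    expected = float(world_size * (world_size - 1) / 2)
    assert torch.allclose(reduced, torch.full_like(reduced, expected))
    print(f"every element equals {expected}")   # 6.0 for 4 ranks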