Commit 716c19b

Improve compute cost for all2all (#159)
* Add compute cost for all_to_all. The all2all implementation performs additional input/output copies depending on the in_shard / out_shard dims; see https://github.com/pytorch/pytorch/blob/afdd4247a2251b3f4c2f4b402cb625f46d6784ba/torch/csrc/distributed/c10d/Functional.cpp#L597-L617 for more details, and the sketch below.
* Add the .contiguous() cost as well. We still need a way of deciding whether the input is contiguous or not.
* Refactor the read/write time estimate into a helper function.
* Cleanup.
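
A minimal standalone sketch of the copy counting described above (the helper name and the fixed dim-0 convention are illustrative only, mirroring the diff below; this is not code from the commit):

def estimate_all_to_all_copies(in_shard_dim, out_shard_dim, is_contiguous):
    # all_to_all_single exchanges data along dim 0, so a shard placed on any
    # other dim needs an extra reshuffle copy on that side; a non-contiguous
    # input additionally pays for a .contiguous() copy
    num_copies = 0
    if not is_contiguous:
        num_copies += 1
    if in_shard_dim != 0:
        num_copies += 1
    if out_shard_dim != 0:
        num_copies += 1
    return num_copies

# e.g. Shard(1) -> Shard(2) with a (conservatively assumed) non-contiguous
# input gives 3 extra read/write passes over the tensor
assert estimate_all_to_all_copies(1, 2, is_contiguous=False) == 3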
1 parent ff9c574 commit 716c19b

2 files changed: +37 −26 lines changed

autoparallel/collective_runtime_estimation.py

Lines changed: 20 additions & 14 deletions

@@ -15,7 +15,7 @@
 )
 from torch.distributed.tensor.placement_types import Partial, Shard

-from .compute_estimation import _get_device_gmem_bandwidth
+from .compute_estimation import compute_read_write_time


 def all_to_all_cost(bytes_gb: float, mesh_topo: MeshTopoInfo, mesh_dim: int) -> float:
@@ -65,16 +65,11 @@ def redistribute_cost(
     comm_bytes_gb = (
         spec_to_bytes(current_spec) / current_spec.num_shards / 1024 / 1024 / 1024
     )
-    gpu_memory_bandwidth = _get_device_gmem_bandwidth() / 1024**3  # GB/s
     # Transformation that considered for redistribute cost:
     # 1. allgather 2. alltoall
     # 3. allreduce 4. reduce_scatter
     curr_placements = [current_spec.placements[i] for i in order]
     tgt_placements = [target_spec.placements[i] for i in order]
-
-    # suppose 70% efficiency for the non-collective operators
-    read_write_efficiency = 0.70
-    kernel_launch_overhead = 7  # us
     for i, current, target in zip(order, curr_placements, tgt_placements):
         if current == target:
             continue
@@ -90,15 +85,29 @@ def redistribute_cost(
             # which corresponds to reshuffling the whole output tensor
             # we multiply the cost by 2 because we need to count input and output
             # reads for the reshuffle
-            compute_cost = comm_bytes_gb * 2 / gpu_memory_bandwidth * 1e6  # us
-            compute_cost = max(
-                compute_cost / read_write_efficiency, kernel_launch_overhead
-            )
+            compute_cost = compute_read_write_time(comm_bytes_gb * 2 * 1024**3)
             cost += compute_cost
         elif current.is_shard() and target.is_shard():
+            current = cast(Shard, current)
+            target = cast(Shard, target)
             # should be alltoall comm, since we haven't implement it yet, add penalty
             # to favor allgather instead
             cost += all_to_all_cost(comm_bytes_gb, mesh_topo, i)  # us
+
+            num_copies = 0
+            is_contiguous = False
+            if not is_contiguous:
+                num_copies += 1
+
+            if current.dim != 0:
+                num_copies += 1
+
+            if target.dim != 0:
+                num_copies += 1
+
+            compute_cost = compute_read_write_time(comm_bytes_gb * 2 * 1024**3)
+            cost += num_copies * compute_cost
+
         elif current.is_partial() and target.is_replicate():
             # add up allreduce comm cost
             cost += allreduce_cost(comm_bytes_gb, mesh_topo, i)
@@ -111,10 +120,7 @@ def redistribute_cost(
             # which corresponds to reshuffling the whole input tensor
             # we multiply the cost by 2 because we need to count input and output
             # reads for the reshuffle
-            compute_cost = comm_bytes_gb * 2 / gpu_memory_bandwidth * 1e6  # us
-            compute_cost = max(
-                compute_cost / read_write_efficiency, kernel_launch_overhead
-            )
+            compute_cost = compute_read_write_time(comm_bytes_gb * 2 * 1024**3)
             cost += compute_cost
             # after reduce_scatter the comm bytes for further collectives halved.
             comm_bytes_gb /= num_devices_on_mesh_dim

autoparallel/compute_estimation.py

Lines changed: 17 additions & 12 deletions

@@ -281,6 +281,21 @@ def _compute_flops(fn, *args, **kwargs):
     return flop_counter.get_total_flops(), out


+def compute_read_write_time(read_write_bytes):
+    gpu_memory_bandwidth = _get_device_gmem_bandwidth()
+    read_write_time = read_write_bytes / gpu_memory_bandwidth * 1e6  # us
+
+    # suppose 70% efficiency for the operator
+    read_write_efficiency = 0.70
+
+    kernel_launch_overhead = 7  # us
+
+    read_write_time = max(
+        read_write_time / read_write_efficiency, kernel_launch_overhead
+    )
+    return read_write_time
+
+
 def estimate_strategy_runtime_cost(node, strategy):
     """
     This function estimates the runtime cost of a given strategy
@@ -297,17 +312,7 @@ def estimate_strategy_runtime_cost(node, strategy):
     flops, out = _compute_flops(node.target, *args, **kwargs)

     read_write_bytes = compute_memory_cost(node.target, args, out)
-    gpu_memory_bandwidth = _get_device_gmem_bandwidth()
-    read_write_time = read_write_bytes / gpu_memory_bandwidth * 1e6  # us
-
-    # suppose 70% efficiency for the operator
-    read_write_efficiency = 0.70
-
-    kernel_launch_overhead = 7  # us
-
-    read_write_time = max(
-        read_write_time / read_write_efficiency, kernel_launch_overhead
-    )
+    read_write_time = compute_read_write_time(read_write_bytes)

     if flops == 0:
         return read_write_time
@@ -320,7 +325,7 @@ def estimate_strategy_runtime_cost(node, strategy):
     # suppose 70% efficiency for the operator
     compute_efficiency = 0.70
     compute_time = flops / gpu_flops * 1e6  # us
-    compute_time = max(compute_time / compute_efficiency, kernel_launch_overhead)
+    compute_time = compute_time / compute_efficiency

     return max(compute_time, read_write_time)
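
For reference, the new helper's behavior on example sizes (a standalone re-implementation for illustration; the real bandwidth comes from _get_device_gmem_bandwidth() and is device dependent, so the 2 TB/s figure below is only an assumption):

def compute_read_write_time_sketch(read_write_bytes, gpu_memory_bandwidth=2.0e12):
    # bytes / (bytes/s) -> seconds, * 1e6 -> us, derated by the assumed 70%
    # efficiency, and floored at the 7 us kernel launch overhead
    read_write_time = read_write_bytes / gpu_memory_bandwidth * 1e6  # us
    return max(read_write_time / 0.70, 7.0)

# tiny transfer: the bandwidth term (~0.07 us) is below the floor -> 7 us
print(compute_read_write_time_sketch(100 * 1024))
# 1 GiB moved means 2 GiB of reads + writes (hence comm_bytes_gb * 2 * 1024**3
# at the call sites); at the assumed 2 TB/s this is ~1074 us / 0.7 ~= 1534 us
print(compute_read_write_time_sketch(2 * 1024**3))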

0 commit comments
