34 changes: 20 additions & 14 deletions autoparallel/collective_runtime_estimation.py
@@ -15,7 +15,7 @@
)
from torch.distributed.tensor.placement_types import Partial, Shard

from .compute_estimation import _get_device_gmem_bandwidth
from .compute_estimation import compute_read_write_time


def all_to_all_cost(bytes_gb: float, mesh_topo: MeshTopoInfo, mesh_dim: int) -> float:
@@ -65,16 +65,11 @@ def redistribute_cost(
comm_bytes_gb = (
spec_to_bytes(current_spec) / current_spec.num_shards / 1024 / 1024 / 1024
)
gpu_memory_bandwidth = _get_device_gmem_bandwidth() / 1024**3 # GB/s
# Transformations considered for redistribute cost:
# 1. allgather 2. alltoall
# 3. allreduce 4. reduce_scatter
curr_placements = [current_spec.placements[i] for i in order]
tgt_placements = [target_spec.placements[i] for i in order]

# suppose 70% efficiency for the non-collective operators
read_write_efficiency = 0.70
kernel_launch_overhead = 7 # us
for i, current, target in zip(order, curr_placements, tgt_placements):
if current == target:
continue
@@ -90,15 +85,29 @@
# which corresponds to reshuffling the whole output tensor
# we multiply the cost by 2 because we need to count input and output
# reads for the reshuffle
compute_cost = comm_bytes_gb * 2 / gpu_memory_bandwidth * 1e6 # us
compute_cost = max(
compute_cost / read_write_efficiency, kernel_launch_overhead
)
compute_cost = compute_read_write_time(comm_bytes_gb * 2 * 1024**3)
cost += compute_cost
elif current.is_shard() and target.is_shard():
current = cast(Shard, current)
target = cast(Shard, target)
# should be an alltoall comm; since we haven't implemented it yet, add a penalty
# to favor allgather instead
cost += all_to_all_cost(comm_bytes_gb, mesh_topo, i) # us

num_copies = 0
is_contiguous = False
if not is_contiguous:
num_copies += 1

if current.dim != 0:
num_copies += 1

if target.dim != 0:
num_copies += 1

compute_cost = compute_read_write_time(comm_bytes_gb * 2 * 1024**3)
cost += num_copies * compute_cost

elif current.is_partial() and target.is_replicate():
# add up allreduce comm cost
cost += allreduce_cost(comm_bytes_gb, mesh_topo, i)
@@ -111,10 +120,7 @@
# which corresponds to reshuffling the whole input tensor
# we multiply the cost by 2 because we need to count input and output
# reads for the reshuffle
compute_cost = comm_bytes_gb * 2 / gpu_memory_bandwidth * 1e6 # us
compute_cost = max(
compute_cost / read_write_efficiency, kernel_launch_overhead
)
compute_cost = compute_read_write_time(comm_bytes_gb * 2 * 1024**3)
cost += compute_cost
# after reduce_scatter the comm bytes for further collectives are halved.
comm_bytes_gb /= num_devices_on_mesh_dim
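To make the new Shard-to-Shard branch concrete, the sketch below mirrors the copy-counting heuristic as I read it from the diff above; it is illustrative only. The 2 TB/s bandwidth and the helper names rw_time_us and shard_to_shard_copy_cost_us are assumptions for this example; the real code calls compute_read_write_time, which queries the actual device bandwidth.

# Illustrative sketch only: the copy-counting heuristic from the Shard -> Shard
# (alltoall) branch, with an assumed bandwidth instead of a device query.

KERNEL_LAUNCH_OVERHEAD_US = 7.0     # us, matching the diff
READ_WRITE_EFFICIENCY = 0.70        # 70% of peak, matching the diff
ASSUMED_BANDWIDTH_B_PER_S = 2.0e12  # ~2 TB/s, stand-in for _get_device_gmem_bandwidth()


def rw_time_us(read_write_bytes: float) -> float:
    # Same shape as compute_read_write_time, but using the assumed bandwidth.
    t = read_write_bytes / ASSUMED_BANDWIDTH_B_PER_S * 1e6  # us
    return max(t / READ_WRITE_EFFICIENCY, KERNEL_LAUNCH_OVERHEAD_US)


def shard_to_shard_copy_cost_us(
    comm_bytes_gb: float,
    current_shard_dim: int,
    target_shard_dim: int,
    is_contiguous: bool = False,  # the diff currently hard-codes this to False
) -> float:
    # Extra non-collective copy cost (us) added on top of the alltoall cost.
    num_copies = 0
    if not is_contiguous:
        num_copies += 1  # one copy to restore contiguity
    if current_shard_dim != 0:
        num_copies += 1  # one copy when the input is sharded on a non-leading dim
    if target_shard_dim != 0:
        num_copies += 1  # one copy when the output is sharded on a non-leading dim
    # each copy reads and writes the full tensor, hence the factor of 2
    return num_copies * rw_time_us(comm_bytes_gb * 2 * 1024**3)


# Example: 1 GB moved from Shard(1) to Shard(2) counts three copies.
print(shard_to_shard_copy_cost_us(1.0, current_shard_dim=1, target_shard_dim=2))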
29 changes: 17 additions & 12 deletions autoparallel/compute_estimation.py
@@ -281,6 +281,21 @@ def _compute_flops(fn, *args, **kwargs):
return flop_counter.get_total_flops(), out


def compute_read_write_time(read_write_bytes):
gpu_memory_bandwidth = _get_device_gmem_bandwidth()
read_write_time = read_write_bytes / gpu_memory_bandwidth * 1e6 # us

# suppose 70% efficiency for the operator
read_write_efficiency = 0.70

kernel_launch_overhead = 7 # us

read_write_time = max(
read_write_time / read_write_efficiency, kernel_launch_overhead
)
return read_write_time


def estimate_strategy_runtime_cost(node, strategy):
"""
This function estimates the runtime cost of a given strategy
@@ -297,17 +312,7 @@ def estimate_strategy_runtime_cost(node, strategy):
flops, out = _compute_flops(node.target, *args, **kwargs)

read_write_bytes = compute_memory_cost(node.target, args, out)
gpu_memory_bandwidth = _get_device_gmem_bandwidth()
read_write_time = read_write_bytes / gpu_memory_bandwidth * 1e6 # us

# suppose 70% efficiency for the operator
read_write_efficiency = 0.70

kernel_launch_overhead = 7 # us

read_write_time = max(
read_write_time / read_write_efficiency, kernel_launch_overhead
)
read_write_time = compute_read_write_time(read_write_bytes)

if flops == 0:
return read_write_time
@@ -320,7 +325,7 @@
# suppose 70% efficiency for the operator
compute_efficiency = 0.70
compute_time = flops / gpu_flops * 1e6 # us
compute_time = max(compute_time / compute_efficiency, kernel_launch_overhead)
Contributor Author:
I removed this, but it is functionally the same as before, because we already apply max(..., kernel_launch_overhead) inside compute_read_write_time.
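A quick illustrative check of that claim (not part of the PR): because the final estimate is max(compute_time, read_write_time) and read_write_time is already floored at the kernel-launch overhead inside compute_read_write_time, also flooring compute_time cannot change the result.

# Illustrative only: flooring compute_time is redundant once read_write_time is
# floored, since max(max(ct, oh), max(rw, oh)) == max(ct, max(rw, oh)) for any ct, rw.
import itertools

oh = 7.0  # us, kernel_launch_overhead
for ct, rw in itertools.product([0.1, 3.0, 7.0, 42.0], repeat=2):
    before = max(max(ct, oh), max(rw, oh))  # both terms floored, as before this change
    after = max(ct, max(rw, oh))            # only read_write_time floored, as after it
    assert before == after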

compute_time = compute_time / compute_efficiency

return max(compute_time, read_write_time)

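For reference, a minimal usage sketch of the extracted helper. It assumes the package imports as autoparallel (matching the file paths above) and that a GPU is available for _get_device_gmem_bandwidth to query; the argument is a byte count and the result is in microseconds.

from autoparallel.compute_estimation import compute_read_write_time

# Time to stream 256 MiB through global memory at 70% of peak bandwidth,
# floored at the 7 us kernel-launch overhead.
t_us = compute_read_write_time(256 * 1024**2)
print(f"estimated read/write time: {t_us:.1f} us")

# Tiny transfers hit the launch-overhead floor.
assert compute_read_write_time(1) == 7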