Commit da0bbe3
[float8] all-reduce amax on dp mesh instead of global pg (#933)
Squashed commits:

* [float8] all-reduce amax on dp mesh instead of global pg
* linter
* improve comments
* move hp tensor inside if
* linter (five further lint-only follow-ups)
1 parent 72cc27d commit da0bbe3
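In short: tensor_to_amax already all-reduced the local amax with ReduceOp.MAX when reduce_amax=True, but it always did so on the default (global) process group. This commit threads an optional device_mesh down from the FSDP weight cast so the reduction only spans the data-parallel ranks. A minimal standalone sketch of the idea (the helper name reduce_amax_over_mesh is illustrative, not part of the patch):

import torch
import torch.distributed as dist

def reduce_amax_over_mesh(x: torch.Tensor, device_mesh=None) -> torch.Tensor:
    # Local absolute max, then an optional distributed MAX reduction.
    amax = torch.max(torch.abs(x))
    if dist.is_initialized():
        # group=None falls back to the default (global) process group;
        # passing a 1-D mesh's group restricts the collective to that mesh,
        # e.g. the "dp" sub-mesh of a 2D (pp, dp) layout.
        pg = device_mesh.get_group() if device_mesh is not None else None
        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=pg)
    return amax

Under pipeline plus data parallelism this matters: only the ranks that own a given stage cast its weights, so a global all-reduce would stall waiting on ranks that never enter the collective, which is what the new test below exercises.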

File tree: 4 files changed, +44 -6 lines

test/float8/test_fsdp2/test_fsdp2.py (31 additions, 1 deletion)
@@ -17,10 +17,12 @@
 import torch.nn as nn
 from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
 from torchao.float8.float8_linear_utils import convert_to_float8_training
+from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
 from torchao.testing.float8.fsdp2_utils import check_parity_bf16_mp, check_parity_no_mp
 from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
-from torch.distributed._tensor import DTensor
+from torch.distributed._tensor import DTensor, init_device_mesh
+from torchao.float8.float8_tensor import GemmInputRole
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (
@@ -293,6 +295,34 @@ def _get_curr_active_memory_mb(self) -> int:
         return round(mem_stats["active_bytes.all.current"] / 1e6)
 
 
+class Test2DParallelMultiThread(FSDPTestMultiThread, TestFloat8Common):
+    @property
+    def world_size(self) -> int:
+        return 4
+
+    def test_amax_allreduce_device_mesh(self):
+        dp_size = 2
+        pp_size = self.world_size // dp_size
+        global_mesh = init_device_mesh("cuda", (pp_size, dp_size), mesh_dim_names=("pp", "dp"))
+        dp_mesh = global_mesh["dp"]
+        pp_mesh = global_mesh["pp"]
+
+        if self.rank in [0, 1]:
+            # rank 0 and 1 are the 1st stage in the pipeline
+            # rank 2 and 3 are doing nothing but waiting for the 1st stage
+            torch.manual_seed(42 + self.rank)
+            hp_tensor = torch.randn(768, 32, device="cuda")
+            float8_tensor = hp_tensor_to_float8_dynamic(
+                hp_tensor,
+                torch.float8_e4m3fn,
+                Float8LinearConfig(
+                    cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC),
+                ),
+                gemm_input_role=GemmInputRole.WEIGHT,
+                reduce_amax=True,
+                device_mesh=dp_mesh
+            )
+
 class TestFloat8MultiThread(FSDPTestMultiThread, TestFloat8Common):
     @property
     def world_size(self) -> int:

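The topology in the new test, for reference: with mesh_dim_names=("pp", "dp") and shape (2, 2), ranks are laid out as [[0, 1], [2, 3]], so the "dp" groups are {0, 1} and {2, 3}. Only ranks 0 and 1 perform the float8 cast; ranks 2 and 3 never reach the collective, so the test can only pass if the amax all-reduce is scoped to dp_mesh rather than the world group. A small sketch of that mesh slicing (standalone, assuming a 4-rank torchrun launch with GPUs; not part of the patch):

import torch
from torch.distributed._tensor import init_device_mesh

# Build the same 2x2 mesh as the test: dim 0 = "pp", dim 1 = "dp".
# Rank layout: [[0, 1],
#               [2, 3]]
global_mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("pp", "dp"))
dp_mesh = global_mesh["dp"]               # this rank's 1-D data-parallel sub-mesh
pg = dp_mesh.get_group()                  # process group containing only 2 ranks
print(dp_mesh.mesh.tolist(), pg.size())   # e.g. [0, 1] and 2 on ranks 0 and 1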
torchao/float8/float8_scaling_utils.py (2 additions, 1 deletion)
@@ -36,6 +36,7 @@ def hp_tensor_to_float8_dynamic(
     linear_mm_config: LinearMMConfig,
     reduce_amax: bool = False,
     gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
+    device_mesh = None,
 ) -> Float8Tensor:
     """
     Given a high precision tensor `hp_tensor`,
@@ -52,7 +53,7 @@
     """
     if tensor_already_casted_to_fp8(hp_tensor):
        return hp_tensor
-    scale = tensor_to_scale(hp_tensor, float8_dtype, reduce_amax)
+    scale = tensor_to_scale(hp_tensor, float8_dtype, reduce_amax, device_mesh)
     return hp_tensor_and_scale_to_float8(
         hp_tensor,
         scale,

torchao/float8/float8_utils.py (10 additions, 4 deletions)
@@ -98,23 +98,29 @@ def amax_history_to_scale_stack(
 
 
 @torch.no_grad()
-def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor:
+def tensor_to_amax(
+    x: torch.Tensor, reduce_amax: bool = False, device_mesh=None
+) -> torch.Tensor:
     amax = torch.max(torch.abs(x))
 
     # If the user asked for distributed reduction, do it.
     # If the user did not ask for it, assume that it will
     # happen elsewhere.
     if reduce_amax and dist.is_initialized():
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX)
+        pg = device_mesh.get_group() if device_mesh is not None else None
+        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=pg)
 
     return amax
 
 
 @torch.no_grad()
 def tensor_to_scale(
-    x: torch.Tensor, float8_dtype: torch.dtype, reduce_amax: bool = False
+    x: torch.Tensor,
+    float8_dtype: torch.dtype,
+    reduce_amax: bool = False,
+    device_mesh=None,
 ) -> torch.Tensor:
-    amax = tensor_to_amax(x, reduce_amax=reduce_amax)
+    amax = tensor_to_amax(x, reduce_amax=reduce_amax, device_mesh=device_mesh)
     return amax_to_scale(amax, float8_dtype, x.dtype)
 
 

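Both new parameters default to None, so existing callers of tensor_to_amax / tensor_to_scale keep their behavior: with no mesh the all-reduce still runs on the default group, and with no initialized process group it is skipped entirely. A quick single-process sanity check (illustrative only):

import torch
from torchao.float8.float8_utils import tensor_to_amax, tensor_to_scale

x = torch.randn(16, 32)
# No process group is initialized here, so reduce_amax=True is a no-op
# and both values are computed purely locally, exactly as before this change.
amax = tensor_to_amax(x, reduce_amax=True)
scale = tensor_to_scale(x, torch.float8_e4m3fn, reduce_amax=True)
print(amax.item(), scale.item())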
torchao/float8/fsdp_utils.py (1 addition, 0 deletions)
@@ -216,6 +216,7 @@ def fsdp_pre_all_gather(self, mesh):
                 self._linear_mm_config,
                 reduce_amax=True,
                 gemm_input_role=GemmInputRole.WEIGHT,
+                device_mesh=mesh,
             )
         return (float8_tensor._data,), (float8_tensor._scale,)
 

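On the FSDP2 side, fsdp_pre_all_gather receives the mesh the parameter is sharded over, so the weight's dynamic-scaling amax reduction now follows the data-parallel mesh automatically. A hedged sketch of a caller that would end up on this path (this assumes the enable_fsdp_float8_all_gather option on Float8LinearConfig, which is not shown in this diff and routes weights through WeightWithDynamicFloat8CastTensor; model sizes are arbitrary):

import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._tensor import init_device_mesh
from torchao.float8.config import Float8LinearConfig
from torchao.float8.float8_linear_utils import convert_to_float8_training

# 2D layout: FSDP shards over "dp"; "pp" is owned by e.g. pipeline stages.
global_mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("pp", "dp"))
dp_mesh = global_mesh["dp"]

model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024)).cuda()
config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)
convert_to_float8_training(model, config=config)

for layer in model:
    fully_shard(layer, mesh=dp_mesh)
fully_shard(model, mesh=dp_mesh)
# During each all-gather, fsdp_pre_all_gather is invoked with the FSDP mesh
# (dp_mesh here), so the amax all-reduce above is scoped to the dp ranks
# instead of the global process group.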