
Commit bce07de

XilunWu authored and pytorchmergebot committed
[dtensor][cp][experiment] add CP experimental API to choose rotate method (pytorch#142093)

**Summary** This PR adds a new experimental API, `set_rotate_method`, for Context Parallel. The API lets users choose the communication method used for shard rotation: all-to-all or all-gather.

**Test** `pytest test/distributed/_tensor/test_attention.py`

Pull Request resolved: pytorch#142093
Approved by: https://github.com/fegin
1 parent eb84788 · commit bce07de
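A minimal usage sketch of the API described above, assuming the import path of the file touched in this commit (the module is experimental and private, so the path may change):

from torch.distributed.tensor.experimental._attention import set_rotate_method

# Rotate kv shards with the all-to-all collective during Context Parallel SDPA.
set_rotate_method("alltoall")

# "allgather" is the default if set_rotate_method is never called.
set_rotate_method("allgather")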

File tree

2 files changed: +50 −4 lines

test/distributed/_tensor/test_attention.py

Lines changed: 13 additions & 3 deletions

@@ -15,6 +15,7 @@
     _RotateMethod,
     context_parallel,
     context_parallel_unshard,
+    set_rotate_method,
 )
 from torch.distributed.tensor.debug import CommDebugMode
 from torch.distributed.tensor.parallel import parallelize_module
@@ -48,6 +49,12 @@
     backends.append(SDPBackend.EFFICIENT_ATTENTION)


+rotater_enum_to_str = {
+    _RotateMethod.ALL_GATHER: "allgather",
+    _RotateMethod.ALL_TO_ALL: "alltoall",
+}  # mapping from _RotateMethod enum to string
+
+
 class RingAttentionTest(DTensorTestBase):
     @property
     def world_size(self) -> int:
@@ -76,7 +83,8 @@ def test_ring_attention_sdpa(
         load_balance: bool,
         rotater: _RotateMethod,
     ) -> None:
-        _cp_options.rotate_method = rotater
+        set_rotate_method(rotater_enum_to_str[rotater])
+        self.assertEqual(_cp_options.rotate_method, rotater)
         device_mesh = DeviceMesh(self.device_type, torch.arange(0, self.world_size))
         dtype = torch.bfloat16
         bs = 8
@@ -230,7 +238,8 @@ def test_ring_attention_native_transformer(
         self, is_causal: bool, rotater: _RotateMethod
     ) -> None:
         _cp_options.enable_load_balance = is_causal
-        _cp_options.rotate_method = rotater
+        set_rotate_method(rotater_enum_to_str[rotater])
+        self.assertEqual(_cp_options.rotate_method, rotater)
         device_mesh = DeviceMesh(
             self.device_type,
             torch.arange(0, self.world_size),
@@ -314,7 +323,8 @@ def test_ring_attention_native_transformer(
     @sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION])
     @parametrize("rotater", [_RotateMethod.ALL_GATHER, _RotateMethod.ALL_TO_ALL])
     def test_ring_attention_custom_transformer(self, rotater: _RotateMethod) -> None:
-        _cp_options.rotate_method = rotater
+        set_rotate_method(rotater_enum_to_str[rotater])
+        self.assertEqual(_cp_options.rotate_method, rotater)
         device_mesh = DeviceMesh(
             self.device_type,
             torch.arange(0, self.world_size),

torch/distributed/tensor/experimental/_attention.py

Lines changed: 37 additions & 1 deletion

@@ -31,7 +31,7 @@
 from torch.distributed.tensor.parallel.style import ParallelStyle


-__all__ = ["context_parallel"]
+__all__ = ["context_parallel", "set_rotate_method"]


 class _CausalBehavior(Enum):
@@ -1284,10 +1284,46 @@ def context_parallel_unshard(
 ) -> List[torch.Tensor]:
     """
     Unshard the tensors (e.g., output) that are sharded due to context parallelism.
+
+    Args:
+        mesh (:class:`DeviceMesh`): the device mesh for the context parallelism.
+        buffers (List[torch.Tensor]): the buffers to be unsharded.
+        seq_dims (List[int]): the sequence dimensions of ``buffers``. This list
+            must have the same length as ``buffers``.
+
+    Returns:
+        List[torch.Tensor]: the unsharded buffers.
     """
     sharder = (
         _RoundRobinLoadBalancer
         if _cp_options.enable_load_balance
         else _SequentialSharder
     )
     return [sharder.unshard(b, mesh, dim) for b, dim in zip(buffers, seq_dims)]
+
+
+def set_rotate_method(rotate_method: str) -> None:
+    """
+    Context Parallel SDPA requires the rotation of kv shards. Users can call this
+    API to specify which rotation method to use. "alltoall" shuffles the kv shards
+    using all-to-all collective. While "allgather" gathers the kv shards using
+    all-gather collective after the first sub-SDPA computation. If this API has not
+    been called, the default rotate method is "allgather".
+
+    Args:
+        rotate_method (str): the rotate method to use. Currently only supports
+            "allgather" and "alltoall". If a different string other than these two
+            is passed in, the function will raise an error.
+
+    Returns:
+        None
+    """
+    if rotate_method == "allgather":
+        _cp_options.rotate_method = _RotateMethod.ALL_GATHER
+    elif rotate_method == "alltoall":
+        _cp_options.rotate_method = _RotateMethod.ALL_TO_ALL
+    else:
+        raise NotImplementedError(
+            "Context Parallel does not support "
+            f"using {rotate_method} for kv shards rotation"
+        )
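As a rough sketch of how the new API interacts with the internal options, mirroring the assertion pattern used in the updated tests above (`_cp_options` and `_RotateMethod` are private internals and may change; "ring" below is just an example of an unsupported value):

from torch.distributed.tensor.experimental._attention import (
    _cp_options,
    _RotateMethod,
    set_rotate_method,
)

# Setting the method updates the internal context-parallel options.
set_rotate_method("alltoall")
assert _cp_options.rotate_method == _RotateMethod.ALL_TO_ALL

# Any string other than "allgather" or "alltoall" raises NotImplementedError.
try:
    set_rotate_method("ring")
except NotImplementedError as exc:
    print(exc)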
