Update doc from commit fa311ec
torchxlabot2 committed Oct 17, 2024
1 parent 7a49baf commit 41606a5
Showing 48 changed files with 964 additions and 44 deletions.
2 changes: 1 addition & 1 deletion master/_modules/index.html
@@ -265,7 +265,7 @@
 <div class="version">
-master (2.6.0+gita0f81e5 )
+master (2.6.0+gitfa311ec )
 </div>
2 changes: 1 addition & 1 deletion master/_modules/torch_xla/core/xla_model.html
@@ -265,7 +265,7 @@
 <div class="version">
-master (2.6.0+gita0f81e5 )
+master (2.6.0+gitfa311ec )
 </div>
2 changes: 1 addition & 1 deletion master/_modules/torch_xla/debug/metrics.html
@@ -265,7 +265,7 @@
 <div class="version">
-master (2.6.0+gita0f81e5 )
+master (2.6.0+gitfa311ec )
 </div>
2 changes: 1 addition & 1 deletion master/_modules/torch_xla/distributed/parallel_loader.html
@@ -265,7 +265,7 @@
 <div class="version">
-master (2.6.0+gita0f81e5 )
+master (2.6.0+gitfa311ec )
 </div>
@@ -265,7 +265,7 @@
 <div class="version">
-master (2.6.0+gita0f81e5 )
+master (2.6.0+gitfa311ec )
 </div>
@@ -265,7 +265,7 @@
 <div class="version">
-master (2.6.0+gita0f81e5 )
+master (2.6.0+gitfa311ec )
 </div>
2 changes: 1 addition & 1 deletion master/_modules/torch_xla/experimental/eager.html
@@ -265,7 +265,7 @@
 <div class="version">
-master (2.6.0+gita0f81e5 )
+master (2.6.0+gitfa311ec )
 </div>
2 changes: 1 addition & 1 deletion master/_modules/torch_xla/runtime.html
@@ -265,7 +265,7 @@
 <div class="version">
-master (2.6.0+gita0f81e5 )
+master (2.6.0+gitfa311ec )
 </div>
2 changes: 1 addition & 1 deletion master/_modules/torch_xla/torch_xla.html
@@ -265,7 +265,7 @@
 <div class="version">
-master (2.6.0+gita0f81e5 )
+master (2.6.0+gitfa311ec )
 </div>
89 changes: 89 additions & 0 deletions master/_sources/features/distop.md.txt
@@ -0,0 +1,89 @@
# Support for Torch Distributed API in PyTorch/XLA
Before the 2.5 release, PyTorch/XLA only supported collective ops through its custom API, `torch_xla.core.xla_model.*`. Starting with the 2.5 release, PyTorch/XLA adopts `torch.distributed.*` for both Dynamo and non-Dynamo cases.
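
For orientation, here is a minimal sketch contrasting the two styles. It is illustrative only: the `xm.all_reduce`/`xm.REDUCE_SUM` usage reflects the pre-2.5 custom API as commonly documented, and the `dist.all_reduce` call assumes the `xla` backend has been initialized as described later in this doc.

```Python
import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm

t = torch.ones(4, device=xm.xla_device())

# Pre-2.5 style: PyTorch/XLA's custom collective API.
reduced = xm.all_reduce(xm.REDUCE_SUM, t)

# 2.5+ style: the standard torch.distributed API, usable once
# dist.init_process_group("xla", init_method='xla://') has been called.
dist.all_reduce(t, op=dist.ReduceOp.SUM)
```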
## Collective ops lowering
### Collective ops lowering stack
With the introduction of the [traceable collective communication APIs](https://github.com/pytorch/pytorch/issues/93173), Dynamo can support collective ops by reimplementing their lowering in PyTorch/XLA. Collective ops are only traceable through `torch.ops._c10d_functional` calls. The figure below shows how a collective op, `all_reduce` in this case, is lowered between torch and torch_xla:


<img src="../_static/img/dist_op_stack.png" alt="Alt Text" width="500" height="400">

_<span style="text-decoration:underline;">Figure 1. Collective ops lowering stack</span>_

### Non-Dynamo case
Collective ops are lowered by registering `ProcessGroupXla`, which is derived from `ProcessGroup`:

```Python
# torch_xla/distributed/xla_backend.py
def _create_xla_process_group(prefix_store, rank, size, timeout):
  assert not xr.is_spmd(
  ), "XLA backend is not supported with SPMD. Please use a CPU process group instead."
  return ProcessGroupXla(prefix_store, rank, size, timeout)


def _register_xla_backend():
  dist.Backend.register_backend('xla', _create_xla_process_group, devices='xla')


class ProcessGroupXla(ProcessGroup):
  ...

  def allreduce(self, tensors, all_reduce_options):
    ...

  def allgather(self, output_tensors_list, input_tensors, opts=None):
    ...
```

The corresponding `xla` distributed backend is initialized when we call:
```Python
def _mp_fn(rank):
  dist.init_process_group("xla", init_method='xla://')
```

In this way, collective ops will be called through the process group instance:

```Python
# E.g., pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
@_exception_logger
def all_gather(tensor_list, tensor, group=None, async_op=False):
  ...
  group = group or _get_default_group()
  work = group.allgather([tensor_list], [tensor])  # uses ProcessGroupXla.allgather instead
```
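
To make the flow concrete, here is a minimal, hedged end-to-end sketch of the non-Dynamo path. It assumes a multi-device XLA runtime (e.g. a TPU VM) and uses `xmp.spawn` to launch one process per device; exact launch details may differ across releases.

```Python
import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(rank):
  # Registers and initializes ProcessGroupXla as the default process group.
  dist.init_process_group("xla", init_method='xla://')
  device = xm.xla_device()

  # Each replica contributes its ordinal; after all_reduce every replica
  # holds the sum over all replicas.
  t = torch.tensor([float(xm.get_ordinal())], device=device)
  dist.all_reduce(t, op=dist.ReduceOp.SUM)
  xm.mark_step()  # materialize the lazily recorded collective
  print(f"rank {rank}: {t.cpu()}")


if __name__ == '__main__':
  xmp.spawn(_mp_fn)
```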

### Dynamo case
For the Dynamo case, certain collective ops are remapped to new functions in [pytorch/torch/distributed/_functional_collectives.py](https://github.com/pytorch/pytorch/blob/v2.5.0-rc10/torch/distributed/_functional_collectives.py#L1129-L1150). For example, `all_reduce()` is mapped to `all_reduce_inplace()`, which eventually calls `torch.ops._c10d_functional.all_reduce()`. Once we reach `_c10d_functional`, we can rewrite the op through the PyTorch/XLA lowering:


```C++
at::Tensor all_reduce(const at::Tensor& self, std::string reduceOp,
                      std::string /*group_name*/) {...}

TORCH_LIBRARY_IMPL(_c10d_functional, XLA, m) {
  m.impl("all_reduce", all_reduce);
}
```
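
As a hedged illustration of the Dynamo path (not the project's own test code), the sketch below compiles a function containing `dist.all_reduce` with the `openxla` backend; under Dynamo the call is remapped to `torch.ops._c10d_functional.all_reduce` and picked up by the lowering registered above. Launch and setup details are assumptions and may differ by release.

```Python
import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def allreduce_fn(t):
  # Remapped by Dynamo to torch.ops._c10d_functional.all_reduce.
  dist.all_reduce(t, op=dist.ReduceOp.SUM)
  return t


def _mp_fn(rank):
  dist.init_process_group("xla", init_method='xla://')
  device = xm.xla_device()
  compiled = torch.compile(allreduce_fn, backend="openxla")
  t = torch.ones(4, device=device)
  out = compiled(t)
  print(f"rank {rank}: {out.cpu()}")


if __name__ == '__main__':
  xmp.spawn(_mp_fn)
```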


## API description

As of release 2.5, we support four collective operations for both Dynamo and non-Dynamo cases. Our goal is to align the distributed operation (dist op) APIs with PyTorch's upstream implementation. While the function signatures remain consistent, certain input restrictions still apply.
For instance, specifying multiple groups for distributed collective operations is not yet supported. For usage examples, refer to [test_collective_ops_tpu.py](https://github.com/pytorch/xla/blob/v2.5.0-rc10/test/pjrt/test_collective_ops_tpu.py), which demonstrates the use of dist ops in both Dynamo and non-Dynamo scenarios.
Below are the details for each operation:
```Python
dist.all_reduce(input: torch.Tensor, op: dist.ReduceOp = ReduceOp.SUM)
```
`all_reduce` performs an in-place reduction on the `input` tensor by aggregating data from all nodes.

```Python
dist.all_gather_into_tensor(output, input)
```
`all_gather_into_tensor` gathers the input tensor from all nodes and updates the `output` tensor in-place. It also returns an alias of the output.

```Python
dist.reduce_scatter_tensor(output, input, op: dist.ReduceOp = ReduceOp.SUM)
```
`reduce_scatter_tensor` reduces the input tensor across all nodes and distributes the result to the `output` tensor in-place. It returns an alias of the output.

```Python
dist.all_to_all_single(output, input, output_split_sizes=None, input_split_sizes=None)
```
The `all_to_all_single` function performs an all-to-all communication, updating the `output` tensor in-place and returning its alias.

Note: Although `output_split_sizes` and `input_split_sizes` are accepted as arguments, they must be either None or set to all 1s. This limitation reflects a compromise between maintaining PyTorch’s API signature and the constraints of the XLA AllToAll operation.
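
As a rough, combined usage sketch of the four ops (assuming the `xla://` process group has been initialized inside each replica as in the earlier examples; shapes and values are chosen only for illustration):

```Python
import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr


def run_collectives():
  device = xm.xla_device()
  world_size = xr.world_size()
  rank = xr.global_ordinal()

  # all_reduce: in-place sum across all replicas.
  t = torch.full((4,), float(rank), device=device)
  dist.all_reduce(t, op=dist.ReduceOp.SUM)

  # all_gather_into_tensor: output holds world_size * input elements.
  gathered = torch.zeros(world_size * 4, device=device)
  dist.all_gather_into_tensor(gathered, t)

  # reduce_scatter_tensor: input holds world_size * output elements.
  scattered = torch.zeros(4, device=device)
  dist.reduce_scatter_tensor(scattered, gathered, op=dist.ReduceOp.SUM)

  # all_to_all_single: split sizes left as None, per the note above.
  a2a_in = torch.arange(world_size, dtype=torch.float32, device=device)
  a2a_out = torch.zeros(world_size, device=device)
  dist.all_to_all_single(a2a_out, a2a_in)

  xm.mark_step()
  return t, gathered, scattered, a2a_out
```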
Binary file added master/_static/img/dist_op_stack.png
2 changes: 1 addition & 1 deletion master/accelerators/gpu.html
@@ -267,7 +267,7 @@
 <div class="version">
-master (2.6.0+gita0f81e5 )
+master (2.6.0+gitfa311ec )
 </div>
2 changes: 1 addition & 1 deletion master/accelerators/tpu.html
@@ -267,7 +267,7 @@
 <div class="version">
-master (2.6.0+gita0f81e5 )
+master (2.6.0+gitfa311ec )
 </div>
2 changes: 1 addition & 1 deletion master/contribute/bazel.html
@@ -266,7 +266,7 @@
 <div class="version">
-master (2.6.0+gita0f81e5 )
+master (2.6.0+gitfa311ec )
 </div>
2 changes: 1 addition & 1 deletion master/contribute/codegen_migration.html
@@ -267,7 +267,7 @@
 <div class="version">
-master (2.6.0+gita0f81e5 )
+master (2.6.0+gitfa311ec )
 </div>
2 changes: 1 addition & 1 deletion master/contribute/configure-environment.html
@@ -267,7 +267,7 @@
 <div class="version">
-master (2.6.0+gita0f81e5 )
+master (2.6.0+gitfa311ec )
 </div>
2 changes: 1 addition & 1 deletion master/contribute/op_lowering.html
@@ -267,7 +267,7 @@
 <div class="version">
-master (2.6.0+gita0f81e5 )
+master (2.6.0+gitfa311ec )
 </div>
2 changes: 1 addition & 1 deletion master/contribute/plugins.html
@@ -267,7 +267,7 @@
 <div class="version">
-master (2.6.0+gita0f81e5 )
+master (2.6.0+gitfa311ec )
 </div>