[dtensor] update readme for prototype release (pytorch#94517)
This PR updates the README for the prototype release, removing some code that is not available yet and using the code that works.

Also renames DistributedTensor to DTensor in most sentences.
Pull Request resolved: pytorch#94517
Approved by: https://github.com/fegin
wanchaol authored and pytorchmergebot committed Feb 9, 2023
1 parent 66bfcd3 commit 09598b6
Showing 1 changed file with 27 additions and 23 deletions: torch/distributed/_tensor/README.md
# PyTorch DTensor (Prototype Release)

This folder contains the DTensor (a.k.a DistributedTensor) implementation in PyTorch.

## Introduction
We propose distributed tensor primitives to allow easier distributed computation authoring in the SPMD (Single Program, Multiple Devices) paradigm. The primitives are simple but powerful when used to express tensor distributions with both sharding and replication parallelism strategies. This could empower native Tensor parallelism among other advanced parallelism explorations. For example, to shard a big tensor across devices with 3 lines of code:
```python
import torch
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

# Create a mesh topology with the available devices:
# 1. We can directly create the mesh using elastic launcher,
# 2. If using mp.spawn, we need to initialize the world process group first,
#    e.g. torch.distributed.init_process_group(backend="nccl", world_size=world_size)
#    (a sketch of this path follows this code block)
mesh = DeviceMesh("cuda", list(range(world_size)))
big_tensor = torch.randn(100000, 88)
# Shard this tensor over the mesh by sharding `big_tensor`'s 0th dimension over the 0th dimension of `mesh`.
my_dtensor = distribute_tensor(big_tensor, mesh, [Shard(dim=0)])
```
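The comments above mention two launch paths. As a hedged sketch of the `mp.spawn` path (the `worker` function, port, and 4-GPU world size are illustrative assumptions, not part of the original example):

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

def worker(rank, world_size):
    # initialize the default process group before building a DeviceMesh
    dist.init_process_group(
        backend="nccl", rank=rank, world_size=world_size,
        init_method="tcp://localhost:29500",  # illustrative rendezvous address
    )
    torch.cuda.set_device(rank)
    mesh = DeviceMesh("cuda", list(range(world_size)))
    big_tensor = torch.randn(100000, 88)
    my_dtensor = distribute_tensor(big_tensor, mesh, [Shard(dim=0)])
    # each rank holds a 25000x88 local shard of the 100000x88 global tensor
    print(rank, my_dtensor.to_local().shape)

if __name__ == "__main__":
    world_size = 4  # assumption: 4 local GPUs
    mp.spawn(worker, args=(world_size,), nprocs=world_size)
```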
Today there are mainly three ways to scale up distributed training: Data Parallel, Tensor Parallel and Pipeline Parallel.

An ideal scenario is that users could build their distributed program just like authoring in a single node/device, without worrying about how to do distributed training in a cluster, and our solutions could help them run distributed training in an efficient manner. For example, researchers just need to build the big transformer model, and PyTorch Distributed automatically figures out how to split the model and run pipeline parallel across different nodes, how to run data parallel and tensor parallel within each node. In order to achieve this, we need some common abstractions to distribute tensor values and distributed computations accordingly.

There are many recent works on tensor-level parallelism that provide common abstractions; see the `Related Works` section for more details. Inspired by [GSPMD](https://arxiv.org/pdf/2105.04663.pdf), [Oneflow](https://arxiv.org/pdf/2110.15032.pdf) and [TF’s DTensor](https://www.tensorflow.org/guide/dtensor_overview), we introduce PyTorch DTensor as the next generation of ShardedTensor to provide basic abstractions for distributing storage and computation. It serves as one of the basic building blocks for distributed program translations and describes the layout of a distributed training program. With the DTensor abstraction, we can seamlessly build parallelism strategies such as tensor parallelism, DDP and FSDP.

## Value Proposition

PyTorch DTensor primarily:
- Offers a uniform way to save/load `state_dict` during checkpointing, even when there are complex tensor storage distribution strategies such as combining tensor parallelism with parameter sharding in FSDP (a sketch follows this list).
- Enables Tensor Parallelism in eager mode. Compared to ShardedTensor, DTensor allows additional flexibility to mix sharding and replication.
- Serves as the entry point of an SPMD programming model and the foundational building block for compiler-based distributed training.
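As a minimal sketch of the checkpointing point above (the sizes and names are illustrative assumptions, not a prescribed API): a sharded parameter can be resharded to `Replicate()` and materialized as a plain `torch.Tensor`, so every rank saves and loads the same full tensor.

```python
import torch
from torch.distributed._tensor import DeviceMesh, Shard, Replicate, distribute_tensor

mesh = DeviceMesh("cuda", [0, 1, 2, 3])
# a parameter sharded row-wise across the 4 ranks
sharded_param = distribute_tensor(torch.randn(1024, 64), mesh, [Shard(0)])

# reshard to replication so every rank holds the full tensor, then
# materialize a plain torch.Tensor for a uniform checkpoint state_dict
full_param = sharded_param.redistribute(mesh, [Replicate()]).to_local()
state_dict = {"param": full_param}  # identical on every rank
```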

## PyTorch DTensor

### DTensor API

We offer both a lower-level DTensor API and a module-level API to create an `nn.Module` with “distributed” parameters.

#### Basic DTensor API Examples

Here are some basic DTensor API examples that showcase:
1. How to construct a DTensor directly, to represent different types of sharding, replication, and sharding + replication strategies.
2. How to create a DTensor from a local `torch.Tensor`.
3. How to “reshard” an existing DTensor to a different DTensor with a modified placement strategy or world size.

```python
import torch
import torch.distributed as distributed
from torch.distributed._tensor import DTensor, DeviceMesh, Shard, Replicate, distribute_tensor, distribute_module

# construct a device mesh with available devices (multi-host or single host)
device_mesh = DeviceMesh("cuda", [0, 1, 2, 3])
# if we want to do row-wise sharding
rowwise_placement=[Shard(0)]
# if we want to do col-wise sharding
colwise_placement=[Shard(1)]

big_tensor = torch.randn(888, 12)
# the distributed tensor returned will be sharded across the dimension specified in placements
rowwise_tensor = distribute_tensor(big_tensor, device_mesh=device_mesh, placements=rowwise_placement)

# if we want to do replication across a certain device list
replica_placement = [Replicate()]
# distributed tensor will be replicated to all four GPUs.
replica_tensor = distribute_tensor(big_tensor, device_mesh=device_mesh, placements=replica_placement)

# if we want to distribute a tensor with both replication and sharding
device_mesh = DeviceMesh("cuda", [[0, 1], [2, 3]])
# replicate across the first dimension of device mesh, then sharding on the second dimension of device mesh
spec=[Replicate(), Shard(0)]
partial_replica = distribute_tensor(big_tensor, device_mesh=device_mesh, placements=spec)

# create a DTensor that shards on dim 0, from a local torch.Tensor
device_mesh = DeviceMesh("cuda", [0, 1, 2, 3])  # back to a 1-D mesh to match the 1-D placements below
local_tensor = torch.randn((8, 8), requires_grad=True)
rowwise_tensor = DTensor.from_local(local_tensor, device_mesh, rowwise_placement)

# reshard the current row-wise tensor to a col-wise or replicated tensor
colwise_tensor = rowwise_tensor.redistribute(device_mesh, colwise_placement)
replica_tensor = colwise_tensor.redistribute(device_mesh, replica_placement)
```
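DTensor is a `torch.Tensor` subclass, so regular tensor ops compose with it. A small hedged sketch (the shapes here are illustrative assumptions): pointwise ops run on each rank's local shard and return another DTensor with matching placements.

```python
import torch
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

mesh = DeviceMesh("cuda", [0, 1, 2, 3])
a = distribute_tensor(torch.randn(16, 8), mesh, [Shard(0)])
b = distribute_tensor(torch.randn(16, 8), mesh, [Shard(0)])

# the pointwise add runs shard-by-shard; the result is again a DTensor
c = a + b
print(type(c), c.to_local().shape)  # each of the 4 ranks holds a 4x8 local shard
```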

#### High level User Facing APIs

Users can use DTensor tensor constructors directly to create a distributed tensor (i.e. `distributed.ones/empty`), but for existing modules like `nn.Linear` that already have `torch.Tensor` parameters, how do we make them distributed parameters? We offer a way to directly distribute a `torch.Tensor` and module-level APIs to distribute module parameters. Below are the high-level APIs we introduce:

```python
def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh=None, placements: List[Placement]=None):
    ...

def distribute_module(
    module: nn.Module,
    device_mesh: DeviceMesh=None,
    partition_fn: Callable[[str, nn.Module, DeviceMesh], None]=None,
    input_fn: Callable[..., None]=None,
    output_fn: Callable[..., None]=None,
):
    ...

# example: shard the `fc1.weight` parameter of `model` on dim 0 of the mesh
def shard_fc(mod_name, mod, mesh):
    rowwise_placement = [Shard(0)]
    if mod_name == "fc1":
        mod.weight = torch.nn.Parameter(distribute_tensor(mod.weight, mesh, rowwise_placement))

sharded_module = distribute_module(model, device_mesh, partition_fn=shard_fc)
```
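For completeness, here is a self-contained sketch of `distribute_module` usage; the toy model, mesh size, and partition function are illustrative assumptions, not part of the original example.

```python
import torch
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor, distribute_module

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 8)
        self.fc2 = nn.Linear(8, 8)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

mesh = DeviceMesh("cuda", [0, 1, 2, 3])

def shard_linear_weights(mod_name, mod, mesh):
    # shard every Linear weight row-wise; other parameters are left as-is
    if isinstance(mod, nn.Linear):
        mod.weight = nn.Parameter(distribute_tensor(mod.weight, mesh, [Shard(0)]))

# partition_fn is called once per submodule with (name, module, mesh)
sharded_module = distribute_module(ToyModel(), mesh, partition_fn=shard_linear_weights)
```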

## Compiler and PyTorch DTensor

PyTorch DTensor provides efficient solutions for cases like Tensor Parallelism. But when using DTensor's replication in a data parallel fashion, it might become observably slower compared to our existing solutions like DDP/FSDP. This is mainly because DDP/FSDP have a global view of the entire model architecture, and thus can optimize for data parallelism specifically, i.e. collective fusion and computation overlap. In contrast, DTensor as a Tensor-like object can only optimize within individual tensor operations.

To improve efficiency of DTensor-based data parallel training, we are exploring a compiler-based solution on top of DTensor, which can extract graph information from user programs to expose more performance optimization opportunities.

## Related Works
