2 changes: 1 addition & 1 deletion paddle/fluid/pybind/eager_utils.cc
@@ -1127,7 +1127,7 @@ PyObject* ToPyObject(const phi::distributed::ProcessMesh* value) {
}

PyObject* ToPyObject(const phi::distributed::Placement& value) {
auto obj = ::pybind11::cast(value);
auto obj = ::pybind11::cast(value, py::return_value_policy::reference);
obj.inc_ref();
return obj.ptr();
}
9 changes: 9 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/dist_tensor.cc
@@ -72,9 +72,18 @@ DistTensor::DistTensor(const std::shared_ptr<phi::DenseTensor>& global_value,
placements,
DenseTensorMeta(global_value->dtype(), global_value->dims()));

std::vector<int64_t> partial_dims;
size_t idx = 0;
for (auto p : placements) {
if (p->is_partial()) {
partial_dims.push_back(idx);
}
idx++;
}
TensorDistAttr dist_attr(vectorize(dist_tensor_meta_.dims()));
dist_attr.set_process_mesh(dist_tensor_meta_.process_mesh());
dist_attr.set_dims_mapping(dist_tensor_meta_.dim_mapping());
dist_attr.set_partial_status(partial_dims);
dist_attr.mark_annotated("process_mesh");
dist_attr.mark_annotated("dims_mapping");
dist_attr_ = dist_attr;
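The new loop above records which mesh dimensions carry a pending partial reduction before the `TensorDistAttr` is built. A minimal Python sketch of the same bookkeeping, assuming the `Placement` bindings added in this PR (placement objects exposing `is_partial()`):

```python
# Sketch only: mirrors the C++ loop that collects partial mesh dimensions.
from paddle.distributed import Partial, Shard


def partial_mesh_dims(placements):
    """Return the mesh-dimension indices whose placement is Partial."""
    return [idx for idx, p in enumerate(placements) if p.is_partial()]


# Mesh dim 0 shards tensor dim 0; mesh dim 1 holds an unreduced partial sum.
print(partial_mesh_dims([Shard(0), Partial()]))  # -> [1]
```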
9 changes: 6 additions & 3 deletions python/paddle/base/framework.py
@@ -7465,10 +7465,13 @@ def from_tensor(cls, tensor, **kwargs):
param = cls(tensor.shape, tensor.dtype, **kwargs)

# 2. transform data if needed
dist_attr = kwargs.get('dist_attr', None)
mesh = kwargs.get("process_mesh", None)
placements = kwargs.get("placements", None)
src_tensor = tensor
if dist_attr is not None:
src_tensor = core.eager.Tensor(tensor, dist_attr=dist_attr)
if mesh is not None and placements is not None:
src_tensor = core.eager.Tensor(
tensor, process_mesh=mesh, placements=placements
)

# 3. set param data
param._set_impl(src_tensor)
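`from_tensor` now forwards `process_mesh` and `placements` to `core.eager.Tensor` instead of a `dist_attr`. A hedged usage sketch of the path that reaches this branch, sharding a layer parameter (an `EagerParamBase`) with the new `shard_tensor` signature; it assumes a launched distributed environment:

```python
# Sketch only: sharding a parameter routes through EagerParamBase.from_tensor.
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
linear = paddle.nn.Linear(4, 4)

# linear.weight is an EagerParamBase, so shard_tensor calls
# EagerParamBase.from_tensor(tensor, process_mesh=mesh, placements=...).
dist_weight = dist.shard_tensor(linear.weight, mesh, [dist.Shard(0)])
```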
13 changes: 13 additions & 0 deletions python/paddle/distributed/__init__.py
@@ -68,6 +68,14 @@

from .auto_parallel.process_mesh import ProcessMesh

from .auto_parallel.placement_type import (
ReduceType,
Placement,
Shard,
Replicate,
Partial,
)

from .auto_parallel import shard_op # noqa: F401

from .auto_parallel.api import (
@@ -144,4 +152,9 @@
"dtensor_from_fn",
"reshard",
"shard_layer",
"ReduceType",
"Placement",
"Shard",
"Replicate",
"Partial",
]
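With these exports, the placement types are importable directly from `paddle.distributed`. A small sketch of the intended spelling:

```python
# The placement types newly re-exported from paddle.distributed.
import paddle.distributed as dist

placements = [dist.Shard(0), dist.Replicate(), dist.Partial()]
print([type(p).__name__ for p in placements])  # ['Shard', 'Replicate', 'Partial']
```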
95 changes: 52 additions & 43 deletions python/paddle/distributed/auto_parallel/api.py
@@ -23,6 +23,8 @@
)
from paddle.framework import core

from .placement_type import get_shard_spec

# These are the auto parallel APIs of the unified dynamic and static mode.
# Some APIs share names with the previous implementation; this is a
# temporary state, and the APIs here will eventually be the ones used.
@@ -92,7 +94,7 @@ def sharding_specs(self):


def shard_tensor(
data, dtype=None, place=None, stop_gradient=True, dist_attr=None
data, mesh, placements, dtype=None, place=None, stop_gradient=True
):
"""
Constructs a ``paddle.Tensor`` with distributed attributes from ``data``,
@@ -103,6 +105,9 @@ def shard_tensor(
Args:
data(scalar|tuple|list|ndarray|Tensor): Initial data for the tensor.
Can be a scalar, list, tuple, numpy.ndarray, paddle.Tensor.
mesh(paddle.distributed.ProcessMesh): The `ProcessMesh` object describes the Cartesian topology of the used processes.
placements(list[paddle.distributed.Placement]): the placements describe how to place the tensor on the ProcessMesh;
each element can be Shard, Replicate or Partial.
dtype(str|np.dtype, optional): The desired data type of returned tensor. Can be 'bool' , 'float16' ,
'float32' , 'float64' , 'int8' , 'int16' , 'int32' , 'int64' , 'uint8',
'complex64' , 'complex128'. Default: None, infers dtype from ``data``
@@ -111,7 +116,6 @@ def shard_tensor(
CPUPlace, CUDAPinnedPlace, CUDAPlace. Default: None, means global place. If ``place`` is
string, it can be ``cpu``, ``gpu:x`` and ``gpu_pinned``, where ``x`` is the index of the GPUs.
stop_gradient(bool, optional): Whether to block the gradient propagation of Autograd. Default: True.
dist_attr(paddle.distributed.DistAttr): Specify how tensors are distributed or sliced on ProcessMesh.

Returns:
Tensor: A Tensor constructed from ``data`` with distributed attributes.
@@ -123,15 +127,14 @@ def shard_tensor(
>>> import paddle.distributed as dist

>>> mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=['x', 'y'])
>>> dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x', 'y'])

>>> # dense tensor
>>> a = paddle.to_tensor([[1,2,3],
... [5,6,7]])

>>> # doctest: +REQUIRES(env:DISTRIBUTED)
>>> # distributed tensor
>>> d_tensor = dist.shard_tensor(a, dist_attr=dist_attr)
>>> d_tensor = dist.shard_tensor(a, mesh, [dist.Shard(0), dist.Shard(1)])

>>> print(d_tensor)

@@ -146,33 +149,34 @@ def shard_tensor(
data, dtype=dtype, place=place, stop_gradient=stop_gradient
)

# 2. create dist tensor
assert len(dist_attr.dims_mapping) == len(
list(tensor.shape)
), "The length of sharding_specs must be same as the shape of the input tensor."

if paddle.in_dynamic_mode():
# here the dist tensor is deep copy constructed
if isinstance(data, EagerParamBase):
return EagerParamBase.from_tensor(
tensor, dist_attr=dist_attr, **tensor.__dict__
tensor,
process_mesh=mesh,
placements=placements,
**tensor.__dict__
)
else:
return paddle.Tensor(tensor, dist_attr=dist_attr, place=place)
return paddle.Tensor(
tensor, process_mesh=mesh, placements=placements, place=place
)
else:
# TODO(zhiqiu): we need to refine the static shard_tensor
return shard_tensor_static(
tensor, dist_attr.process_mesh, dist_attr.sharding_specs
)
sharding_specs = get_shard_spec(mesh, placements, tensor.ndim)
return shard_tensor_static(tensor, mesh, sharding_specs)
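In static mode the new signature is bridged back to the old sharding-spec form through `get_shard_spec` (defined in `placement_type.py` below). A hedged sketch of that conversion in isolation:

```python
# Sketch only: how placements become the sharding_specs used by
# shard_tensor_static in the branch above.
import paddle.distributed as dist
from paddle.distributed.auto_parallel.placement_type import get_shard_spec

mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
# Shard tensor dim 0 along mesh dim "x"; replicate along mesh dim "y".
placements = [dist.Shard(0), dist.Replicate()]
print(get_shard_spec(mesh, placements, tensor_dims=2))  # ['x', None]
```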


def dtensor_from_fn(fn, dist_attr, *args, **kwargs):
def dtensor_from_fn(fn, mesh, placements, *args, **kwargs):
"""
Construct a Distributed Tensor from a function of arguments.

Args:
fn (callable): A callable that takes ``*args`` and ``**kwargs`` and returns the tensor to be distributed.
dist_attr (paddle.distributed.DistAttr): Specify how tensors are distributed or sliced on ProcessMesh.
mesh(paddle.distributed.ProcessMesh): The `ProcessMesh` object describes the Cartesian topology of the used processes.
placements(list[paddle.distributed.Placement]): the placements describe how to place the tensor on the ProcessMesh;
each element can be Shard, Replicate or Partial.
*args (tuple): A tuple of arguments to be passed to the ``fn`` function.
**kwargs (dict): A dict of arguments to be passed to the ``fn`` function.

@@ -186,26 +190,27 @@ def dtensor_from_fn(fn, dist_attr, *args, **kwargs):
>>> import paddle.distributed as dist
>>> # Create a process mesh
>>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
>>> dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=[None])
>>> # Call dtensor_from_fn with the mesh and placements parameters
>>> d_tensor = dist.dtensor_from_fn(paddle.ones, dist_attr=dist_attr, shape=[1])
>>> d_tensor = dist.dtensor_from_fn(paddle.ones, mesh, [dist.Replicate()], shape=[1])
>>> print(d_tensor)

"""
tensor = fn(*args, **kwargs)
return shard_tensor(tensor, dist_attr=dist_attr)
return shard_tensor(tensor, mesh, placements)
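Since `*args` and `**kwargs` are forwarded untouched to `fn`, any tensor factory works. A hedged sketch beyond the docstring example, assuming a launched distributed environment:

```python
# Sketch only: forward factory kwargs through dtensor_from_fn, then shard.
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
d_zeros = dist.dtensor_from_fn(
    paddle.zeros, mesh, [dist.Shard(0)], shape=[8, 4], dtype="float32"
)
```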


# Part3: Data conversion related APIs


def reshard(dist_tensor, dist_attr):
def reshard(dist_tensor, mesh, placements):
"""
Reshard a distributed ``paddle.Tensor`` to the given process mesh and placements.

Args:
dist_tensor(Tensor): the distributed tensor to be resharded.
dist_attr(paddle.distributed.DistAttr): Specify how tensors are distributed or sliced on ProcessMesh.
mesh(paddle.distributed.ProcessMesh): The `ProcessMesh` object describes the Cartesian topology of the used processes.
placements(list[paddle.distributed.Placement]): the placements describe how to place the tensor on the ProcessMesh;
each element can be Shard, Replicate or Partial.

Returns:
Tensor: A Distributed Tensor resharded with the given mesh and placements.
@@ -216,28 +221,33 @@ def reshard(dist_tensor, dist_attr):
>>> import paddle
>>> import paddle.distributed as dist

>>> mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=['x', 'y'])
>>> dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x', 'y'])

>>> out_mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=['x', 'y'])
>>> out_dist_attr = dist.DistAttr(mesh=out_mesh, sharding_specs=[None, None])
>>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

>>> # dense tensor
>>> a = paddle.to_tensor([[1,2,3],
... [5,6,7]])
>>> a = paddle.ones([10, 20])

>>> # doctest: +REQUIRES(env:DISTRIBUTED)
>>> # distributed tensor
>>> d_tensor = dist.shard_tensor(a, dist_attr=dist_attr)
>>> d_tensor = dist.shard_tensor(a, mesh, [dist.Partial()])

>>> out_d_tensor = dist.reshard(d_tensor, out_dist_attr)
>>> out_d_tensor = dist.reshard(d_tensor, mesh, [dist.Replicate()])

>>> print(d_tensor)
>>> print(out_d_tensor)

"""

if paddle.framework.in_dynamic_mode():
# TODO(LiYuRio): the static-mode DistAttr logic is reused here; once reshard
# is reworked to follow the dygraph logic, delete this conversion.
sharding_specs = get_shard_spec(mesh, placements, dist_tensor.ndim)
dist_attr = DistAttr(mesh, sharding_specs)
partial_dims = []
for i, p in enumerate(placements):
if isinstance(p, dist.Partial):
partial_dims.append(i)
if len(partial_dims) > 0:
dist_attr._set_partial_dims(partial_dims)

return paddle.base.core.reshard(dist_tensor, dist_attr)
else:
# TODO(GhostScreaming): Support static DistTensor later.
@@ -312,9 +322,8 @@ def output_fn(outputs, process_mesh) -> list(paddle.Tensor)
... return self.fc2(self.fc1(input))

>>> def shard_fn(layer_name, layer, process_mesh):
... dist_attr = dist.DistAttr(mesh=process_mesh, sharding_specs=['x', None])
... if layer_name == 'fc1':
... layer.weight = dist.shard_tensor(layer.weight, dist_attr=dist_attr)
... layer.weight = dist.shard_tensor(layer.weight, process_mesh, [dist.Shard(0)])

>>> layer = MLP()
>>> layer = dist.shard_layer(layer, mesh, shard_fn)
@@ -339,26 +348,26 @@ def replicate_layer_params_and_buffers(
) -> None:
for key, param in layer._parameters.items():
if param is not None and not param.is_dist():
replicated_dist_attr = dist.DistAttr(
mesh=mesh,
sharding_specs=[None for _ in range(len(param.shape))],
)
placements = [
paddle.distributed.Replicate()
for _ in range(len(param.shape))
]
layer.add_parameter(
key,
shard_tensor(param, dist_attr=replicated_dist_attr),
shard_tensor(param, mesh, placements),
)
else:
# do nothing, the dist parameter has already been sharded by shard_fn
pass
for key, buffer in layer._buffers.items():
if buffer is not None and not buffer.is_dist():
replicated_dist_attr = dist.DistAttr(
mesh=mesh,
sharding_specs=[None for _ in range(len(buffer.shape))],
)
placements = [
paddle.distributed.Replicate()
for _ in range(len(buffer.shape))
]
layer.register_buffer(
key,
shard_tensor(buffer, dist_attr=replicated_dist_attr),
shard_tensor(buffer, mesh, placements),
)
else:
# do nothing, the dist buffer has already been sharded by shard_fn
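Parameters and buffers that `shard_fn` does not touch are now replicated by building one `Replicate()` placement per tensor axis and calling `shard_tensor(param, mesh, placements)`. A hedged sketch, assuming `shard_fn` is optional and the replicate fallback applies when it is omitted:

```python
# Sketch only: rely on the replicate fallback for all parameters/buffers.
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
layer = paddle.nn.Linear(8, 8)

# With no shard_fn, every parameter is expected to come back as a
# dist tensor replicated over the mesh (assumption, see lead-in).
layer = dist.shard_layer(layer, mesh)
```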
89 changes: 89 additions & 0 deletions python/paddle/distributed/auto_parallel/placement_type.py
@@ -0,0 +1,89 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import cast

from paddle.base.core import Partial, Placement, ReduceType, Replicate, Shard

__all__ = ["ReduceType", "Placement", "Replicate", "Shard", "Partial"]
Contributor:
  1. ReduceType is named differently from the other types, with an extra "Type" suffix. Should the naming be unified?
  2. placement_type contains Placement, so the two concepts overlap; what kind of placement_type is a Placement?
  3. These APIs do not need to be in this module's __all__ list; given how they are called, putting them only in the __all__ list of paddle/distributed/__init__.py is enough.

Contributor Author:
OK, I will change this in the next PR.



def to_placements(dim_map, mesh, partial_idx=[]):
"""
convert dim_map to placements.

Args:
dim_map(List[int]): a list with one integer per tensor dimension, giving the mesh dimension it is sharded on, or -1 if it is not sharded.
mesh(paddle.distributed.ProcessMesh): The `ProcessMesh` object describes the Cartesian topology of the used processes.
partial_idx(List[int], Optional): a list of mesh dimension indices on which the DTensor has a pending (partial) sum.

Returns:
List[Placement]: a list of `paddle.distributed.Placement`, one for each mesh dimension.
"""
placements = [Replicate() for _ in range(len(mesh.mesh.shape))]

for s in partial_idx:
placements[s] = Partial()

for i, m in enumerate(dim_map):
if m >= 0:
p = placements[m]
if p.is_shard():
p = cast(Shard, p)
raise Exception(
f"ProcessMesh dimension can not be mapped to two dimension of same tensor: {i} and {p.get_dim()}."
)
elif p.is_partial():
raise Exception(
f"ProcessMesh dimension {m} can not be both shard and partial!"
)
placements[m] = Shard(i)

return placements


def to_dim_map(placements, tensor_dims):
"""
convert placements to dim_map.

Args:
placements(List[Placement]): a list of `paddle.distributed.Placement`, one for each mesh dimension.
tensor_dims(int): the number of dimensions of the dist_tensor.

Returns:
List[int]: a list with one integer per tensor dimension, giving the mesh dimension it is sharded on, or -1 if it is not sharded.
"""
dim_map = [-1] * tensor_dims
for i, placement in enumerate(placements):
if placement.is_shard():
shard_dim = cast(Shard, placement).get_dim()
if dim_map[shard_dim] > -1:
raise Exception(
"Tensor dim {shard_dim} is already sharded on mesh dim {dim_map[shard_dim]}"
)

dim_map[shard_dim] = i

return dim_map


def get_shard_spec(mesh, placements, tensor_dims):
"""to get shard_spec for construct DistAttr for static API."""
dim_map = to_dim_map(placements, tensor_dims)
mesh_dim_names = mesh.dim_names
shard_spec = [None] * len(dim_map)
for i, d in enumerate(dim_map):
if d > -1:
shard_spec[i] = mesh_dim_names[d]

return shard_spec
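A hedged round-trip sketch of the two helpers above; `dim_map` has one entry per tensor dimension, while `placements` has one entry per mesh dimension:

```python
# Sketch only: dim_map -> placements -> dim_map round trip.
import paddle.distributed as dist
from paddle.distributed.auto_parallel.placement_type import (
    to_dim_map,
    to_placements,
)

mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
dim_map = [1, -1]  # tensor dim 0 sharded on mesh dim 1 ("y"), dim 1 replicated

placements = to_placements(dim_map, mesh)     # roughly [Replicate(), Shard(0)]
print(to_dim_map(placements, tensor_dims=2))  # [1, -1]
```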