
Commit 98e024a

zhiqiu authored and co63oc committed
refine pir dist to_static to support master weight (PaddlePaddle#65089)
* refine pir dist to_static
* fix bug
* fix partial
1 parent 4864404 commit 98e024a

10 files changed: +118, -89 lines


paddle/fluid/pybind/auto_parallel_py.cc

Lines changed: 11 additions & 9 deletions
@@ -490,9 +490,10 @@ void BindAutoParallel(py::module *m) {
       .def(py::self == py::self)   // NOLINT
       .def(py::self != py::self);  // NOLINT
 
-  auto Partial = py::class_<phi::distributed::Partial,
-                            std::shared_ptr<phi::distributed::Partial>>(
-      *m, "Partial", Placement, R"DOC(
+  auto Partial =
+      py::class_<phi::distributed::Partial,
+                 std::shared_ptr<phi::distributed::Partial>>(
+          *m, "Partial", Placement, R"DOC(
   The `Partial` describes `Tensor` across multiple devices, this type of tensor has the same shape but only a fraction of the value, which can be further reduce (e.g. sum/min/max) to obtain dist_tensor, often used as an intermediate representation.
 
   Parameters:
@@ -510,12 +511,13 @@ void BindAutoParallel(py::module *m) {
           >>> d_tensor = dist.shard_tensor(a, mesh, [dist.Partial()])
 
   )DOC")
-      .def(py::init<phi::ReduceType>(),
-           py::arg("reduce_type") = phi::ReduceType::kRedSum)
-      .def("__hash__", &phi::distributed::Partial::hash)
-      .def("__str__", &phi::distributed::Partial::to_string)
-      .def(py::self == py::self)   // NOLINT
-      .def(py::self != py::self);  // NOLINT
+          .def(py::init<phi::ReduceType>(),
+               py::arg("reduce_type") = phi::ReduceType::kRedSum)
+          .def("reduce_type", &phi::distributed::Partial::get_reduce_type)
+          .def("__hash__", &phi::distributed::Partial::hash)
+          .def("__str__", &phi::distributed::Partial::to_string)
+          .def(py::self == py::self)   // NOLINT
+          .def(py::self != py::self);  // NOLINT
 
   g_placement_shard_pytype = reinterpret_cast<PyTypeObject *>(Shard.ptr());
   g_placement_replicated_pytype =
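For reference, a minimal Python sketch of what the new `reduce_type` binding exposes; the mesh and tensor below are illustrative assumptions, and the `dist.shard_tensor` call assumes the script is launched on the two ranks of the mesh.

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
a = paddle.ones([4, 8])

# Each rank holds a partial value; reducing it (sum by default) yields the full tensor.
d_tensor = dist.shard_tensor(a, mesh, [dist.Partial()])

p = dist.Partial()
print(p)                # existing __str__ binding
print(p.reduce_type())  # accessor added in this hunk; defaults to kRedSum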

paddle/fluid/pybind/dist_static_op_function.h

Lines changed: 5 additions & 59 deletions
@@ -43,38 +43,11 @@ static PyObject *static_api_shard_tensor(PyObject *self,
     auto placements = CastPyArg2VectorOfPlacement(placements_obj, 2);
 
     int64_t ndim = GetValueDims(input).size();
-    std::vector<int64_t> dim_map(ndim, -1);
-    for (size_t i = 0; i < placements.size(); i++) {
-      auto &placement = placements[i];
-      if (placement->is_shard()) {
-        auto shard_dim =
-            dynamic_cast<const phi::distributed::Shard &>(*placement).get_dim();
-        PADDLE_ENFORCE_EQ(
-            dim_map[shard_dim],
-            -1,
-            common::errors::InvalidArgument(
-                "Tensor dim %lld is already sharded on mesh dim %lld,"
-                " DistTensor operator implementation does not support things "
-                "like hybrid"
-                " sharding strategies yet (i.e. [Shard(0), Shard(0)])",
-                shard_dim,
-                dim_map[shard_dim]));
-        dim_map[shard_dim] = i;
-      }
-    }
-    paddle::flat_hash_map<int64_t, phi::ReduceType> partial_status;
-    for (size_t i = 0; i < placements.size(); ++i) {
-      auto &p = placements[i];
-      if (p->is_partial()) {
-        partial_status.insert(
-            {i,
-             dynamic_cast<phi::distributed::Partial &>(*p).get_reduce_type()});
-      }
-    }
+    auto res = CvtPlacements(placements, ndim);
 
     // Call ir static api
     auto static_api_out = paddle::dialect::shard_tensor(
-        input, process_mesh, dim_map, partial_status);
+        input, process_mesh, std::get<0>(res), std::get<1>(res));
 
     return ToPyObject(static_api_out);
   } catch (...) {
@@ -101,38 +74,11 @@ static PyObject *static_api_reshard(PyObject *self,
     auto placements = CastPyArg2VectorOfPlacement(placements_obj, 2);
 
     int64_t ndim = GetValueDims(input).size();
-    std::vector<int64_t> dim_map(ndim, -1);
-    for (size_t i = 0; i < placements.size(); i++) {
-      auto &placement = placements[i];
-      if (placement->is_shard()) {
-        auto shard_dim =
-            dynamic_cast<const phi::distributed::Shard &>(*placement).get_dim();
-        PADDLE_ENFORCE_EQ(
-            dim_map[shard_dim],
-            -1,
-            common::errors::InvalidArgument(
-                "Tensor dim %lld is already sharded on mesh dim %lld,"
-                " DistTensor operator implementation does not support things "
-                "like hybrid"
-                " sharding strategies yet (i.e. [Shard(0), Shard(0)])",
-                shard_dim,
-                dim_map[shard_dim]));
-        dim_map[shard_dim] = i;
-      }
-    }
-    paddle::flat_hash_map<int64_t, phi::ReduceType> partial_status;
-    for (size_t i = 0; i < placements.size(); ++i) {
-      auto &p = placements[i];
-      if (p->is_partial()) {
-        partial_status.insert(
-            {i,
-             dynamic_cast<phi::distributed::Partial &>(*p).get_reduce_type()});
-      }
-    }
+    auto res = CvtPlacements(placements, ndim);
 
     // Call ir static api
-    auto static_api_out =
-        paddle::dialect::reshard(input, process_mesh, dim_map, partial_status);
+    auto static_api_out = paddle::dialect::reshard(
+        input, process_mesh, std::get<0>(res), std::get<1>(res));
 
     return ToPyObject(static_api_out);
   } catch (...) {
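For context, a hedged sketch of the eager-style calls that reach these two bindings when tracing a PIR static program; the 2x2 mesh and placements are illustrative assumptions (a 4-rank launch is assumed). Both bindings now delegate placement parsing to `CvtPlacements`, defined in the next file.

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])
x = paddle.ones([8, 8])

# dist.shard_tensor / dist.reshard lower to static_api_shard_tensor /
# static_api_reshard under to_static with PIR enabled.
d_x = dist.shard_tensor(x, mesh, [dist.Shard(0), dist.Partial()])
d_y = dist.reshard(d_x, mesh, [dist.Replicate(), dist.Shard(1)])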

paddle/fluid/pybind/eager_utils.cc

Lines changed: 33 additions & 0 deletions
@@ -2867,4 +2867,37 @@ void BindEagerUtils(PyObject* module) {
   }
 }
 
+std::tuple<std::vector<int64_t>,
+           paddle::flat_hash_map<int64_t, phi::ReduceType>>
+CvtPlacements(Placements placements, int ndim) {
+  std::vector<int64_t> dim_map(ndim, -1);
+  for (size_t i = 0; i < placements.size(); i++) {
+    auto& placement = placements[i];
+    if (placement->is_shard()) {
+      auto shard_dim =
+          dynamic_cast<const phi::distributed::Shard&>(*placement).get_dim();
+      PADDLE_ENFORCE_EQ(
+          dim_map[shard_dim],
+          -1,
+          common::errors::InvalidArgument(
+              "Tensor dim %lld is already sharded on mesh dim %lld,"
+              " DistTensor operator implementation does not support things "
+              "like hybrid"
+              " sharding strategies yet (i.e. [Shard(0), Shard(0)])",
+              shard_dim,
+              dim_map[shard_dim]));
+      dim_map[shard_dim] = i;
+    }
+  }
+  paddle::flat_hash_map<int64_t, phi::ReduceType> partial_status;
+  for (size_t i = 0; i < placements.size(); ++i) {
+    auto& p = placements[i];
+    if (p->is_partial()) {
+      partial_status.insert(
+          {i, dynamic_cast<phi::distributed::Partial&>(*p).get_reduce_type()});
+    }
+  }
+  return {dim_map, partial_status};
+}
+
 }  // namespace paddle::pybind
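For readers who prefer Python, an illustrative mirror (not code from the patch) of what `CvtPlacements` returns: a per-tensor-dim map of sharding mesh dims plus a {mesh_dim: reduce_type} map for `Partial` placements.

# Illustrative Python mirror of CvtPlacements; reduce_type() relies on the
# binding added in auto_parallel_py.cc above.
def cvt_placements(placements, ndim):
    dim_map = [-1] * ndim  # dim_map[tensor_dim] = mesh dim sharding it, or -1
    partial_status = {}    # partial_status[mesh_dim] = reduce type
    for mesh_dim, p in enumerate(placements):
        if p.is_shard():
            shard_dim = p.get_dim()
            if dim_map[shard_dim] != -1:
                # Same guard as the PADDLE_ENFORCE_EQ above: sharding one tensor
                # dim over two mesh dims (e.g. [Shard(0), Shard(0)]) is rejected.
                raise ValueError(
                    f"Tensor dim {shard_dim} is already sharded on mesh dim "
                    f"{dim_map[shard_dim]}"
                )
            dim_map[shard_dim] = mesh_dim
        elif p.is_partial():
            partial_status[mesh_dim] = p.reduce_type()
    return dim_map, partial_status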

paddle/fluid/pybind/eager_utils.h

Lines changed: 4 additions & 0 deletions
@@ -507,5 +507,9 @@ void ConvertAllInputsToDistTensor(const phi::distributed::ProcessMesh* mesh,
 void ConvertToDistTensor(Tensor* x, const phi::distributed::ProcessMesh* mesh);
 void BindEagerUtils(PyObject* module);
 
+std::tuple<std::vector<int64_t>,
+           paddle::flat_hash_map<int64_t, phi::ReduceType>>
+CvtPlacements(phi::distributed::Placements placements, int ndim);
+
 }  // namespace pybind
 }  // namespace paddle

paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.cc

Lines changed: 4 additions & 0 deletions
@@ -47,5 +47,9 @@ const distributed::TensorDistAttr& DistMetaTensor::dist_attr() const {
   }
 }
 
+bool DistMetaTensor::initialized() const {
+  return tensor_ != nullptr || dist_attr_ != TensorDistAttr();
+}
+
 }  // namespace distributed
 }  // namespace phi

paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h

Lines changed: 2 additions & 0 deletions
@@ -48,6 +48,8 @@ class DistMetaTensor : public MetaTensor {
 
   const distributed::TensorDistAttr& dist_attr() const;
 
+  bool initialized() const override;
+
  private:
   /**
    * Note: When using the semi-automatic parallel segmentation derivation rules

python/paddle/distributed/auto_parallel/placement_type.py

Lines changed: 5 additions & 2 deletions
@@ -76,6 +76,7 @@ def to_dim_map(placements, tensor_dims):
         List[int]: a list of integer that represents sharding on each tensor dimension.
     """
     dim_map = [-1] * tensor_dims
+    partial_status = {}
     for i, placement in enumerate(placements):
         if placement.is_shard():
             shard_dim = cast(Shard, placement).get_dim()
@@ -85,13 +86,15 @@ def to_dim_map(placements, tensor_dims):
             )
 
             dim_map[shard_dim] = i
+        if placement.is_partial():
+            partial_status[i] = cast(Partial, placement).reduce_type()
 
-    return dim_map
+    return dim_map, partial_status
 
 
 def get_shard_spec(mesh, placements, tensor_dims):
     """to get shard_spec for construct DistAttr for static API."""
-    dim_map = to_dim_map(placements, tensor_dims)
+    dim_map, _ = to_dim_map(placements, tensor_dims)
     mesh_dim_names = mesh.dim_names
     shard_spec = [None] * len(dim_map)
     for i, d in enumerate(dim_map):
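A small usage sketch of the updated `to_dim_map`; the placements are illustrative, and the two-value return is what `python/paddle/pir/core.py` below relies on.

from paddle.distributed import Partial, Shard
from paddle.distributed.auto_parallel.placement_type import to_dim_map

# 2-D tensor on a 2-D mesh: shard tensor dim 1 over mesh dim 0 and keep a
# partial (default sum) state on mesh dim 1.
dim_map, partial_status = to_dim_map([Shard(1), Partial()], 2)
# dim_map == [-1, 0]            -> tensor dim 1 is sharded along mesh dim 0
# list(partial_status) == [1]   -> mesh dim 1 holds a partial (sum) value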

python/paddle/jit/pir_dy2static/parameter_recorder.py

Lines changed: 2 additions & 15 deletions
@@ -48,23 +48,10 @@ def get(self, program, tensor):
                 name=tensor.name,
                 initializer=non_used_initializer,
                 trainable=(not tensor.stop_gradient),
+                placements=tensor.placements,
+                process_mesh=tensor.process_mesh,
             )
 
-            if tensor.placements is not None:  # import for shard tensor api
-                import paddle.distributed as dist
-
-                dist_value = dist.shard_tensor(
-                    value,
-                    tensor.process_mesh,
-                    tensor.placements,
-                    stop_gradient=value.stop_gradient,
-                )
-                value.set_type(dist_value.type())
-                value.get_defining_op().dist_attr = (
-                    dist_value.get_defining_op().dist_attr
-                )
-                dist_value.block.remove_op(dist_value.get_defining_op())
-
             if isinstance(tensor, paddle.Tensor):
                 params.add(tensor)
             mappings[id(tensor)] = value
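The two kwargs forwarded here come straight from the eager dist parameter being recorded; a brief illustration of where those attributes live (mesh and shapes are assumptions, 2-rank launch assumed):

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
w = paddle.create_parameter(shape=[8, 8], dtype='float32')
d_w = dist.shard_tensor(w, mesh, [dist.Shard(0)])

# These are the fields the recorder now passes to create_parameter directly,
# instead of inserting and then removing a shard_tensor op in the program.
print(d_w.process_mesh)
print(d_w.placements)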

python/paddle/optimizer/optimizer.py

Lines changed: 17 additions & 4 deletions
@@ -835,14 +835,27 @@ def get_param_from_startup(startup, name):
                 startup_param = get_param_from_startup(
                     startup_program, param.name
                 )
-                var = paddle.cast(startup_param, 'float32')
-                var.persistable = True
-                paddle._pir_ops.set_persistable_value(var, var_name)
+                startup_var = paddle.cast(startup_param, 'float32')
+                startup_var.persistable = True
+                paddle._pir_ops.set_persistable_value(startup_var, var_name)
                 with paddle.static.program_guard(main_program):
                     paddle.pir.reset_insertion_point_to_start()
                     var = paddle.static.data(
-                        var_name, var.shape, var.dtype, core.Place()
+                        var_name,
+                        startup_var.shape,
+                        startup_var.dtype,
+                        core.Place(),
                     )
+                    if startup_var.is_dist():
+                        var.set_type(startup_var.type())
+                        op_dist_attr = (
+                            paddle.base.libpaddle.pir.create_op_dist_attribute(
+                                startup_var.dist_attr().process_mesh,
+                                [],
+                                [startup_var.dist_attr()],
+                            )
+                        )
+                        var.get_defining_op().dist_attr = op_dist_attr
                     var.persistable = True
             elif framework.in_dygraph_mode():
                 var = paddle.cast(param, 'float32')
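This is the hunk behind the commit title: when a fp16/bf16 dist parameter gets a float32 master weight under PIR to_static, the master weight's data var in the main program now inherits the parameter's dist type and op dist_attr. A hedged end-to-end sketch of the scenario; the model, mesh, AMP level, and data are illustrative assumptions (2-rank GPU launch assumed).

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

class MLP(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc = paddle.nn.Linear(8, 8)

    def forward(self, x):
        return self.fc(x)

def shard_fn(name, layer, process_mesh):
    if isinstance(layer, paddle.nn.Linear):
        layer.weight = dist.shard_tensor(
            layer.weight, process_mesh, [dist.Shard(1)]
        )

model = dist.shard_layer(MLP(), mesh, shard_fn)
opt = paddle.optimizer.AdamW(
    parameters=model.parameters(), multi_precision=True
)
# AMP O2 keeps fp16 parameters, so the optimizer creates fp32 master weights.
model, opt = paddle.amp.decorate(models=model, optimizers=opt, level='O2')

loader = paddle.io.DataLoader(
    paddle.io.TensorDataset([paddle.randn([16, 8]), paddle.randn([16, 8])]),
    batch_size=4,
)

# Under dist.to_static, the master weights created by the hunk above are
# expected to carry the sharded parameters' dist attributes.
dist_model = dist.to_static(model, loader, paddle.nn.MSELoss(), opt)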

python/paddle/pir/core.py

Lines changed: 35 additions & 0 deletions
@@ -317,12 +317,44 @@ def create_parameter(
     main_program = default_main_program()
     parameter_meta = ParameterMeta(shape, dtype)
 
+    is_dist = False
+    if (
+        'placements' in kwargs
+        and kwargs['placements']
+        and 'process_mesh' in kwargs
+        and kwargs['process_mesh']
+    ):
+        is_dist = True
+
+    def to_dist(value):
+        import paddle
+        import paddle.distributed as dist
+
+        process_mesh = kwargs['process_mesh']
+        dim_map, partial_status = dist.auto_parallel.placement_type.to_dim_map(
+            kwargs['placements'], len(shape)
+        )
+        dist_attr = paddle.base.libpaddle.pir.create_tensor_dist_attribute(
+            process_mesh, dim_map, partial_status
+        )
+        dist_type = paddle.base.libpaddle.pir.cvt_to_dist_type(
+            value.type(), dist_attr
+        )
+        value.set_type(dist_type)
+        op_dist_attr = paddle.base.libpaddle.pir.create_op_dist_attribute(
+            process_mesh, [], [dist_attr]
+        )
+        value.get_defining_op().dist_attr = op_dist_attr
+
     with program_guard(startup_program):
         initializer = kwargs['initializer']
         init_result = initializer(
             parameter_meta, startup_program.global_block()
         )
         init_result.persistable = True
+        if is_dist:
+            to_dist(init_result)
+
         set_parameter(init_result, value_name)
 
     main_program.set_parameters_from(startup_program)
@@ -331,6 +363,9 @@ def create_parameter(
         param = parameter(value_name)
         param.persistable = True
 
+        if is_dist:
+            to_dist(param)
+
     param.trainable = kwargs.get('trainable', True)
     param.stop_gradient = not param.trainable
     param.optimize_attr = kwargs.get('optimize_attr', {'learning_rate': 1.0})
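In practice `create_parameter` receives `placements`/`process_mesh` from the dy2static parameter recorder above rather than from user code. Purely for illustration (the names, mesh, and initializer are assumptions, and PIR static mode must be active), a direct call would look like this:

import paddle
import paddle.distributed as dist
from paddle.pir.core import create_parameter

paddle.enable_static()  # assumes PIR is the active static IR (e.g. FLAGS_enable_pir_api=1)
mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

w = create_parameter(
    dtype='float32',
    shape=[8, 8],
    name='w_example',
    initializer=paddle.nn.initializer.Constant(0.0),
    trainable=True,
    placements=[dist.Shard(0)],
    process_mesh=mesh,
)
# With both kwargs set, is_dist is True and to_dist() stamps the dist type and
# op dist_attr onto the startup init result and the main-program parameter.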
