
Commit ae0fe07

zhiqiu, 0x45f, and winter-wang authored
Fit PIR AMP for auto_parallel (#65892)
* [Test]Support pir amp for dist
* Refine code
* refine pir dist to_static
* fix bug
* fix partial
* Fix dist engine code
* fit pir grad_scaler with auto_parallel
* use amp strategy
* update ut
* update ut
* fit for amp o1
* revert changes of grad_scaler
* fix ut and refine code

---------

Co-authored-by: 0x45f <wangzhen45@baidu.com>
Co-authored-by: winter-wang <1030748926@qq.com>
1 parent a8cdd97 commit ae0fe07

File tree

22 files changed: +455 -82 lines changed


paddle/fluid/pir/dialect/distributed/ir/dist_op.cc

Lines changed: 1 addition & 24 deletions
@@ -14,6 +14,7 @@
 
 #include "paddle/fluid/pir/dialect/distributed/ir/dist_op.h"
 #include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h"
+#include "paddle/fluid/pir/dialect/distributed/ir/dist_tools.h"
 #include "paddle/fluid/pir/dialect/distributed/ir/dist_type.h"
 #include "paddle/fluid/pir/dialect/operator/ir/api_builder.h"
 #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
@@ -273,30 +274,6 @@ void ReshardOp::VerifySig() {
   VLOG(4) << "End Verifying for: ShardTensorOp.";
 }
 
-ProcessMeshAttribute MergeMeshes(const ProcessMeshAttribute& mesh1,
-                                 const ProcessMeshAttribute& mesh2) {
-  if (mesh1 == mesh2) return mesh1;
-  // Combine the two ids
-  std::vector<int64_t> merged_ids;
-  std::vector<int64_t> ids1 = mesh1.process_ids();
-  std::vector<int64_t> ids2 = mesh2.process_ids();
-
-  merged_ids.reserve(ids1.size() + ids2.size());
-  merged_ids.insert(merged_ids.end(), ids1.begin(), ids1.end());
-  merged_ids.insert(merged_ids.end(), ids2.begin(), ids2.end());
-
-  // Remove duplicates
-  std::sort(merged_ids.begin(), merged_ids.end());
-  auto last = std::unique(merged_ids.begin(), merged_ids.end());
-  merged_ids.erase(last, merged_ids.end());
-
-  return ProcessMeshAttribute::get(
-      pir::IrContext::Instance(),
-      {static_cast<int64_t>(merged_ids.size())},  // flatten mesh shape
-      merged_ids,
-      {"merged"});
-}
-
 void ReshardOp::Build(pir::Builder& builder,
                       pir::OperationArgument& argument,
                       pir::Value input,

paddle/fluid/pir/dialect/distributed/ir/dist_tools.cc

Lines changed: 158 additions & 0 deletions
@@ -13,12 +13,152 @@
 // limitations under the License.
 
 #include "paddle/fluid/pir/dialect/distributed/ir/dist_tools.h"
+
+#include <unordered_set>
+
 #include "glog/logging.h"
 #include "paddle/common/enforce.h"
+#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h"
 #include "paddle/pir/include/core/operation.h"
 
 namespace paddle::dialect {
 
+ProcessMeshAttribute MergeMeshes(const ProcessMeshAttribute& mesh1,
+                                 const ProcessMeshAttribute& mesh2) {
+  if (mesh1 == mesh2) return mesh1;
+  // Combine the two ids
+  std::vector<int64_t> merged_ids;
+  std::vector<int64_t> ids1 = mesh1.process_ids();
+  std::vector<int64_t> ids2 = mesh2.process_ids();
+
+  merged_ids.reserve(ids1.size() + ids2.size());
+  merged_ids.insert(merged_ids.end(), ids1.begin(), ids1.end());
+  merged_ids.insert(merged_ids.end(), ids2.begin(), ids2.end());
+
+  // Remove duplicates
+  std::sort(merged_ids.begin(), merged_ids.end());
+  auto last = std::unique(merged_ids.begin(), merged_ids.end());
+  merged_ids.erase(last, merged_ids.end());
+
+  return ProcessMeshAttribute::get(
+      pir::IrContext::Instance(),
+      {static_cast<int64_t>(merged_ids.size())},  // flatten mesh shape
+      merged_ids,
+      {"merged"});
+}
+
+ProcessMeshAttribute MergeInputMeshes(const std::vector<pir::Value>& inputs) {
+  auto ctx = pir::IrContext::Instance();
+  auto mesh = ProcessMeshAttribute::get(ctx, {}, {}, {});
+  for (auto value : inputs) {
+    if (auto dist_type = value.type().dyn_cast<DistTypeInterface>()) {
+      mesh = MergeMeshes(mesh, dist_type.process_mesh_attr());
+    } else {
+      auto vec_type = value.type().dyn_cast<pir::VectorType>();
+      if (!vec_type) {
+        continue;
+      }
+      for (size_t idx = 0; idx < vec_type.size(); ++idx) {
+        if (auto dist_type = vec_type[idx].dyn_cast<DistTypeInterface>()) {
+          mesh = MergeMeshes(mesh, dist_type.process_mesh_attr());
+        }
+      }
+    }
+  }
+  return mesh;
+}
+
+ProcessMeshAttribute CreateGlobalMesh(const std::vector<pir::Value>& inputs) {
+  auto ctx = pir::IrContext::Instance();
+  struct MyHash {
+    std::size_t operator()(const ProcessMeshAttribute& obj) const {
+      return obj.hash();
+    }
+  };
+  std::unordered_set<ProcessMeshAttribute, MyHash> meshes;
+  for (auto value : inputs) {
+    if (auto dist_type = value.type().dyn_cast<DistTypeInterface>()) {
+      meshes.insert(dist_type.process_mesh_attr());
+    } else {
+      if (auto vec_type = value.type().dyn_cast<pir::VectorType>()) {
+        for (size_t idx = 0; idx < vec_type.size(); ++idx) {
+          if (auto dist_type = vec_type[idx].dyn_cast<DistTypeInterface>()) {
+            meshes.insert(dist_type.process_mesh_attr());
+          }
+        }
+      }
+    }
+  }
+
+  ProcessMeshAttribute global_mesh;
+  PADDLE_ENFORCE_GT(meshes.size(),
+                    0,
+                    common::errors::InvalidArgument("There is no dist input"));
+  // get mesh that has the most dimensions
+  auto max_ndim_mesh = ProcessMeshAttribute::get(ctx, {}, {}, {});
+  int64_t min_ndim = std::numeric_limits<int64_t>::max();
+  for (const auto& mesh : meshes) {
+    if (mesh.ndim() > max_ndim_mesh.ndim()) {
+      max_ndim_mesh = mesh;
+    }
+    if (mesh.ndim() < min_ndim) {
+      min_ndim = mesh.ndim();
+    }
+  }
+  // min != max, means there are different mesh size
+  // so, the max_ndim_mesh should be the global mesh
+  if (min_ndim != max_ndim_mesh.ndim()) {
+    for (const auto& mesh : meshes) {
+      if (mesh != max_ndim_mesh) {
+        if (!phi::distributed::IsSubMesh(max_ndim_mesh.process_mesh(),
+                                         mesh.process_mesh())) {
+          PADDLE_THROW(common::errors::InvalidArgument(
+              "The small mesh should be the sub mesh of the large mesh, but "
+              "got {%s} vs {%s} ",
+              mesh,
+              max_ndim_mesh));
+        }
+      }
+    }
+    global_mesh = max_ndim_mesh;
+  } else {
+    auto it = meshes.begin();
+    auto first_mesh = *it;
+    if (meshes.size() > 1) {
+      auto global_ids = first_mesh.process_ids();
+      auto global_shape = first_mesh.shape();
+      auto global_names = first_mesh.dim_names();
+      ++it;
+      for (; it != meshes.end(); ++it) {
+        auto mesh = *it;
+        VLOG(4) << (mesh.shape() == first_mesh.shape()) << " "
+                << (mesh.dim_names() == first_mesh.dim_names()) << " "
+                << (mesh.process_ids() != first_mesh.process_ids());
+        if (mesh.shape() == first_mesh.shape() &&
+            mesh.dim_names() == first_mesh.dim_names() &&
+            mesh.process_ids() != first_mesh.process_ids()) {
+          global_ids.insert(global_ids.end(),
+                            mesh.process_ids().begin(),
+                            mesh.process_ids().end());
+        } else {
+          PADDLE_THROW(common::errors::InvalidArgument(
+              "The sub meshes should have same shape and names but different "
+              "process_ids, but got {%s} vs {%s} ",
+              first_mesh,
+              mesh));
+        }
+      }
+      global_shape.emplace(global_shape.begin(), meshes.size());
+      global_names.emplace(global_names.begin(), "global");
+      global_mesh = ProcessMeshAttribute::get(
+          ctx, global_shape, global_ids, global_names);
+    } else {
+      global_mesh = first_mesh;
+    }
+  }
+  return global_mesh;
+}
+
 bool AllInputAreDist(const std::vector<pir::Value>& inputs) {
   for (auto value : inputs) {
     auto type = value.type();
@@ -210,6 +350,22 @@ void CopyLeafOpToMesh(pir::Value value, ProcessMeshAttribute mesh_attr) {
     if (op->num_operands() != 0u || op->num_results() != 1u) {
       return;
     }
+    if (mesh_attr.ndim() > 1 &&
+        phi::distributed::IsSubMesh(
+            mesh_attr.process_mesh(),
+            dist_type.process_mesh_attr().process_mesh())) {
+      auto new_dist_type = dist_type.CopyWithNewMesh(mesh_attr);
+      value.set_type(new_dist_type);
+      op->set_attribute(
+          kAttrOpDistAttr,
+          OperationDistAttribute::get(new_dist_type.ir_context(),
+                                      mesh_attr,
+                                      {},
+                                      {new_dist_type.tensor_dist_attr()}));
+      VLOG(4) << "CopyLeafOpToMesh: change mesh from "
+              << dist_type.process_mesh_attr() << " to " << mesh_attr;
+      return;
+    }
     pir::IrMapping ir_mapping;
     auto new_op = op->Clone(ir_mapping);
     op->GetParent()->insert(*op, new_op);
@@ -222,6 +378,8 @@ void CopyLeafOpToMesh(pir::Value value, ProcessMeshAttribute mesh_attr) {
                                     mesh_attr,
                                     {},
                                     {dist_type.tensor_dist_attr()}));
+    VLOG(4) << "CopyLeafOpToMesh: copy value from "
+            << dist_type.process_mesh_attr() << " to " << mesh_attr;
   }
 }
 }
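
Note (reviewer-style aside, not part of the commit): MergeMeshes flattens the union of two meshes' process ids into a 1-D mesh named "merged", while CreateGlobalMesh either picks the highest-dimensional mesh (when every other mesh is one of its sub-meshes) or stacks equally shaped sub-meshes along a new leading "global" dimension. A minimal sketch of the flattening rule, using only the ProcessMeshAttribute::get overload seen above; the include set is abridged and the helper name is hypothetical:

#include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h"
#include "paddle/fluid/pir/dialect/distributed/ir/dist_tools.h"

// Hypothetical demo, not part of the commit.
paddle::dialect::ProcessMeshAttribute DemoMergeTwoSubMeshes() {
  auto* ctx = pir::IrContext::Instance();
  // Two disjoint 1-D sub-meshes, e.g. ranks {0, 1} and {2, 3}.
  auto mesh_a =
      paddle::dialect::ProcessMeshAttribute::get(ctx, {2}, {0, 1}, {"x"});
  auto mesh_b =
      paddle::dialect::ProcessMeshAttribute::get(ctx, {2}, {2, 3}, {"x"});
  // MergeMeshes concatenates and deduplicates the process ids, so the result
  // has shape {4}, ids {0, 1, 2, 3} and the single dim name "merged".
  // CreateGlobalMesh, given values typed on these two same-shaped sub-meshes,
  // would instead stack them into shape {2, 2} with a leading "global" dim.
  return paddle::dialect::MergeMeshes(mesh_a, mesh_b);
}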

paddle/fluid/pir/dialect/distributed/ir/dist_tools.h

Lines changed: 7 additions & 0 deletions
@@ -21,6 +21,13 @@
 namespace paddle {
 namespace dialect {
 
+ProcessMeshAttribute MergeMeshes(const ProcessMeshAttribute& mesh1,
+                                 const ProcessMeshAttribute& mesh2);
+
+ProcessMeshAttribute MergeInputMeshes(const std::vector<pir::Value>& inputs);
+
+ProcessMeshAttribute CreateGlobalMesh(const std::vector<pir::Value>& inputs);
+
 bool HasDistInput(const std::vector<pir::Value>& inputs,
                   ProcessMeshAttribute* p_mesh_attr = nullptr);
 bool AllInputAreDist(const std::vector<pir::Value>& inputs);

paddle/fluid/pir/dialect/op_generator/op_infermeta_func_gen.py

Lines changed: 11 additions & 1 deletion
@@ -789,6 +789,7 @@ def GenDistBranch(args, op_info):
     // Auto Parallel condition
     ProcessMeshAttribute op_mesh;
     if(HasDistInput(input_values, &op_mesh)) {{
+      {}
       {}
       CvtAllInputsToDist(input_values, op_mesh);
       auto ctx = pir::IrContext::Instance();
@@ -799,7 +800,15 @@ def GenDistBranch(args, op_info):
             if name == "learning_rate":
                 extra_call = "CopyLeafOpToMesh(learning_rate_, op_mesh);"
                 break
-    dist_branch_str = TEMPLATE.format(extra_call)
+    merge_input_meshes = ""
+    if (
+        op_info.class_name == 'CheckFiniteAndUnscale_Op'
+        or op_info.class_name == 'UpdateLossScaling_Op'
+    ):
+        merge_input_meshes = "op_mesh = CreateGlobalMesh(input_values);"
+        if op_info.class_name == 'CheckFiniteAndUnscale_Op':
+            extra_call = "CopyLeafOpToMesh(scale_, op_mesh);"
+    dist_branch_str = TEMPLATE.format(merge_input_meshes, extra_call)
     infer_spmd_args_list = []
     # Prepare inputs_meta_tensor & attributes for infer spmd
     for name in op_info.spmd_params:
@@ -844,6 +853,7 @@ def GenDistBranch(args, op_info):
         spmd_rule_func = "VariadicReplicatedInferSpmdDynamic"
     TEMPLATE = """
   auto spmd_info = phi::distributed::{spmd_func}({args});
+  DebugInfoForInferSpmd("{op_name}", spmd_info);
   PADDLE_ENFORCE_EQ(spmd_info.first.size(), {input_size}u, common::errors::Unavailable(
       "Size of spmd_info.first for op[{op_name}]is unexpected."));
   for(auto& arg_dist : spmd_info.first) {{
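
For context (my reading of the template change, not text from the commit): the new first {} slot is filled with merge_input_meshes, so for CheckFiniteAndUnscale_Op and UpdateLossScaling_Op the generated infermeta replaces the mesh found by HasDistInput with a global mesh built from all inputs before converting them. A hedged sketch of roughly what the generated C++ branch expands to for CheckFiniteAndUnscale_Op; local names come from the template above and the real generated function contains more code:

// Rough expansion sketch only; the actual code is emitted by
// op_infermeta_func_gen.py and continues with spmd inference and
// dist output typing.
ProcessMeshAttribute op_mesh;
if (HasDistInput(input_values, &op_mesh)) {
  op_mesh = CreateGlobalMesh(input_values);  // new merge_input_meshes slot
  CopyLeafOpToMesh(scale_, op_mesh);         // existing extra_call slot
  CvtAllInputsToDist(input_values, op_mesh);
  auto ctx = pir::IrContext::Instance();
  // ... InferSpmd + dist output types follow ...
}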

paddle/fluid/pir/dialect/op_generator/ops_api_gen.py

Lines changed: 2 additions & 2 deletions
@@ -93,6 +93,8 @@
     'c_allreduce_avg_',
     'c_reduce_avg',
     'c_reduce_avg_',
+    'c_allreduce_avg',
+    'c_allreduce_max',
     'c_reducescatter',
     'c_allreduce_min_',
     'c_allreduce_prod_',
@@ -161,8 +163,6 @@
     'assign_pos',
     'batch_fc',
     'barrier',
-    'c_allreduce_avg',
-    'c_allreduce_max',
     'c_allreduce_min',
     'c_allreduce_prod',
     'c_embedding',

paddle/fluid/pir/dialect/operator/utils/utils.cc

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@ const std::unordered_set<std::string> LegacyOpList = {
     CReduceSumOp::name(),
     CReduceSum_Op::name(),
     CAllreduceMax_Op::name(),
+    CAllreduceMaxOp::name(),
     CAllreduceMin_Op::name(),
     CAllgatherOp::name(),
     CSoftmaxWithCrossEntropyOp::name(),

paddle/phi/api/lib/api_gen_utils.cc

Lines changed: 10 additions & 3 deletions
@@ -677,12 +677,19 @@ std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
   // TODO(GhostScreaming): Inplace outputs are initialized, just set their
   // dist_attr.
   if (out->size() == out_size) {
-    VLOG(3) << "Outputs are inplace vector Tensors, just set their dist_attrs "
-            << "according to InferSPMD output result.";
+    VLOG(3) << "Outputs are inplace vector Tensors, SKIP set dist_attr for out "
+            << "to avoid changing the inplaced input";
     for (size_t i = 0; i < out_size; ++i) {
       results[i] =
           static_cast<phi::distributed::DistTensor*>(out->at(i).impl().get());
-      results[i]->unsafe_set_dist_attr(dist_attrs[i]);
+      continue;
+      // auto t =
+      //     static_cast<phi::distributed::DistTensor*>(out->at(i).impl().get());
+      // auto dist_t = std::make_shared<phi::distributed::DistTensor>(
+      //     t->shared_value(), t->dims(), dist_attrs[i]);
+      // out->at(i) = Tensor();
+      // out->at(i).set_impl(dist_t);
+      // results[i] = dist_t.get();
     }
   } else {
     out->reserve(out_size);

paddle/phi/api/lib/data_transform.cc

Lines changed: 6 additions & 2 deletions
@@ -747,6 +747,9 @@ ReshardApiInputToKernelInput(phi::DeviceContext* dev_ctx,
     if (tensor_in) {
       phi::distributed::DistTensor* dist_tensor =
           static_cast<phi::distributed::DistTensor*>(tensor_in.get());
+      VLOG(4) << "ReshardIsNeededWithPartial"
+              << ReshardIsNeededWithPartial(dist_tensor->dist_attr(),
+                                            dist_attr);
       if (ReshardIsNeededWithPartial(dist_tensor->dist_attr(), dist_attr)) {
         auto argument_name =
             (arg_name.empty() ? "tensor" : arg_name) + "_" + std::to_string(i);
@@ -806,7 +809,7 @@ void SetInplaceOutputCorrectDistAttr(
     phi::distributed::DistTensor* dist_tensor =
         static_cast<phi::distributed::DistTensor*>(tensor_in.get());
     if (dist_tensor->initialized()) {
-      if (ReshardIsNeeded(dist_tensor->dist_attr(), dist_attr)) {
+      if (ReshardIsNeededWithPartial(dist_tensor->dist_attr(), dist_attr)) {
        if (use_general_spmd_rule) {
          VLOG(6) << "SetInplaceOutputCorrectDistAttr Reshard inplace output"
                  << " to origin dist_attr "
@@ -856,7 +859,8 @@ void SetInplaceOutputCorrectDistAttr(
     phi::distributed::DistTensor* dist_tensor =
         static_cast<phi::distributed::DistTensor*>(tensor_in.get());
     if (dist_tensor->initialized()) {
-      if (ReshardIsNeeded(dist_tensor->dist_attr(), dist_attr[i])) {
+      if (ReshardIsNeededWithPartial(dist_tensor->dist_attr(),
+                                     dist_attr[i])) {
        if (use_general_spmd_rule) {
          VLOG(6) << "SetInplaceOutputCorrectDistAttr Reshard inplace output"
                  << " to origin dist_attr "

paddle/phi/core/distributed/auto_parallel/dist_tensor.h

Lines changed: 4 additions & 0 deletions
@@ -130,6 +130,10 @@ class DistTensor final
   /// \return The DenseTensor value's const reference
   const DenseTensor& value() const { return *value_; }
 
+  /// \brief Returns the shared_ptr of dense tensor value's in dist tensor.
+  /// \return The shared_ptr of dense tensor value
+  std::shared_ptr<DenseTensor> shared_value() { return value_; }
+
   /// \brief Returns the mutable dense tensor value in dist tensor.
   /// \note If DenseTensor value is modified externally, the corresponding
   /// relationship between it and the current tensor's global dims and
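
Side note (mine, not from the commit): shared_value() exposes the underlying DenseTensor as a shared_ptr, which is what the commented-out alternative in api_gen_utils.cc above relies on to rebuild a DistTensor that aliases the same storage while carrying a different dist_attr. A minimal sketch, assuming the DistTensor(shared_ptr<DenseTensor>, dims, dist_attr) constructor implied by that commented-out code; the helper name and include set are assumptions:

#include <memory>

#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"

// Hypothetical helper, not part of the commit: re-wrap an existing
// DistTensor's storage under a new TensorDistAttr without copying data.
std::shared_ptr<phi::distributed::DistTensor> RewrapWithDistAttr(
    phi::distributed::DistTensor* t,
    const phi::distributed::TensorDistAttr& new_attr) {
  // shared_value() shares the DenseTensor; only the dist metadata differs.
  return std::make_shared<phi::distributed::DistTensor>(
      t->shared_value(), t->dims(), new_attr);
}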

paddle/phi/infermeta/spmd_rules/amp_ops.cc

Lines changed: 10 additions & 1 deletion
@@ -17,6 +17,7 @@
 #include <vector>
 #include "glog/logging.h"
 
+#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h"
 #include "paddle/phi/infermeta/spmd_rules/utils.h"
 
 namespace phi {
@@ -26,14 +27,21 @@ SpmdInfo CheckFiniteAndUnscaleSpmd(const std::vector<DistMetaTensor>& xs,
                                    const DistMetaTensor& scale) {
   std::vector<TensorDistAttr> xs_attrs;
   paddle::flat_hash_map<int64_t, ReduceType> partial_on_dims;
+  auto scale_mesh = scale.dist_attr().process_mesh();
+  auto offset = 0;
   for (auto& x : xs) {
     auto dist_attr = x.dist_attr();
     dist_attr.clean_partial_status();
     xs_attrs.emplace_back(dist_attr);
     auto dims_mapping = dist_attr.dims_mapping();
+    auto mesh = dist_attr.process_mesh();
+    if (scale_mesh.ndim() > 1 && IsSubMesh(scale_mesh, mesh)) {
+      partial_on_dims[0] = ReduceType::kRedMax;
+      offset = 1;
+    }
     for (auto& m : dims_mapping) {
       if (m != -1 && partial_on_dims.count(m) == 0) {
-        partial_on_dims[m] = ReduceType::kRedMax;
+        partial_on_dims[m + offset] = ReduceType::kRedMax;
       }
     }
   }
@@ -62,6 +70,7 @@ SpmdInfo UpdateLossScalingSpmd(const std::vector<DistMetaTensor>& xs,
   }
   TensorDistAttr found_infinite_attr =
       CopyTensorDistAttrForOutput(found_infinite.dist_attr());
+  found_infinite_attr.set_dims_mapping({-1});
   return {{xs_attrs,
            found_infinite_attr,
            prev_loss_scaling.dist_attr(),
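
A short walk-through (my interpretation, not text from the commit): when scale has been moved onto a global mesh built by CreateGlobalMesh, the leading stacked-sub-mesh axis is marked partial with kRedMax, and every partial dim derived from an input's dims_mapping is shifted by one so it refers to the global mesh rather than the input's own sub-mesh. A standalone sketch of the index arithmetic only, in plain C++ with made-up values and no Paddle dependencies:

#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

int main() {
  // Dim on the mesh that `scale` lives on -> reduce type ('M' = kRedMax).
  std::map<int64_t, char> partial_on_dims;

  // Assume scale sits on a 2-D global mesh and the input's mesh is one of
  // its sub-meshes, i.e. scale_mesh.ndim() > 1 && IsSubMesh(scale_mesh, mesh).
  const bool scale_on_global_mesh = true;
  int64_t offset = 0;
  if (scale_on_global_mesh) {
    partial_on_dims[0] = 'M';  // reduce across the stacked sub-mesh axis
    offset = 1;
  }

  // An input sharded on dim 1 of its own sub-mesh: dims_mapping = {-1, 1}.
  const std::vector<int64_t> dims_mapping = {-1, 1};
  for (int64_t m : dims_mapping) {
    if (m != -1 && partial_on_dims.count(m) == 0) {
      partial_on_dims[m + offset] = 'M';  // registered as global mesh dim 2
    }
  }

  for (const auto& kv : partial_on_dims) {
    std::cout << "partial(kRedMax) on mesh dim " << kv.first << "\n";  // 0, 2
  }
  return 0;
}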
