
Commit

add flatten_grad spmd rule and modify flatten spmd rule to call spmd rule of reshape (#64723)

* add flatten_grad spmd rule and modify flatten spmd rule to call spmd rule of reshape

* clean up the code and remove unnecessary annotation

* add the unittest case of shard dimension in the middle range of flatten

* update unittest case of test_flatten_rule
jeff41404 authored May 31, 2024
1 parent f693970 commit adad52a
Showing 6 changed files with 130 additions and 38 deletions.
52 changes: 24 additions & 28 deletions paddle/phi/infermeta/spmd_rules/flatten.cc
@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
#include "paddle/phi/infermeta/spmd_rules/dim_trans.h"
#include "paddle/phi/infermeta/spmd_rules/reshape.h"
#include "paddle/phi/infermeta/spmd_rules/utils.h"

namespace phi {
@@ -105,41 +106,31 @@ SpmdInfo FlattenInferSpmd(const DistMetaTensor& x,
x_ndim,
x_dims_mapping.size()));

// Step1: Build the transformation from
// the original shape to the target shape

// obtain target shape and use ReshapeInferSpmdDynamic to infer
start_axis = PreprocessAxis(start_axis, x_ndim);
stop_axis = PreprocessAxis(stop_axis, x_ndim);
std::vector<std::shared_ptr<DimTrans>> trans =
MakeFlattenDimTrans(src_shape, start_axis, stop_axis);

// Step2: Infer the dims mapping of input (if reshard is
// needed) and output from the dimension transformation.
std::vector<std::vector<int64_t>> dims_mapping_vec =
InferFromDimTrans(x, trans);

// Step3: Update the dist attributes of input
// and output with the inferred dims mapping.
TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
x_dist_attr_dst.set_dims_mapping(dims_mapping_vec[0]);
TensorDistAttr out_dist_attr = CopyTensorDistAttrForOutput(x_dist_attr_src);
out_dist_attr.set_dims_mapping(dims_mapping_vec[1]);
std::vector<int64_t> dst_shape;
int64_t flatten_size = 1;
for (int64_t i = 0; i < x_ndim; i++) {
if (i < start_axis || i > stop_axis) {
dst_shape.emplace_back(src_shape[i]);
} else {
flatten_size *= src_shape[i];
if (i == stop_axis) {
dst_shape.emplace_back(flatten_size);
}
}
}

VLOG(4) << "FlattenInferSpmd: X shape: [" << str_join(src_shape) << "]";
VLOG(4) << "Start_axis: " << start_axis;
VLOG(4) << "Stop_axis: " << start_axis;
VLOG(4) << "Transformation from input to output:";
for (int64_t i = 0, n = static_cast<int64_t>(trans.size()); i < n; i++) {
std::shared_ptr<DimTrans> t = trans[i];
VLOG(4) << "\tOut axis[" << i << "]: " << t->to_string();
}
VLOG(4) << "X dims_mapping_src: [" << str_join(x_dims_mapping)
<< "] dims_mapping_dst: [" << str_join(dims_mapping_vec[0]) << "]";
VLOG(4) << "Out dims_mapping: [" << str_join(dims_mapping_vec[1]) << "]\n\n";

return {{x_dist_attr_dst}, {out_dist_attr}};
VLOG(4) << "Stop_axis: " << stop_axis;
VLOG(4) << "FlattenInferSpmd: output shape: [" << str_join(dst_shape) << "]";
VLOG(4) << "use ReshapeInferSpmdDynamic to infer distributed attribute";
return ReshapeInferSpmdDynamic(x, dst_shape);
}

// TODO(jeff41404): consider xshape and use ReshapeInferSpmdReverse in future
SpmdInfo FlattenInferSpmdReverse(const DistMetaTensor& x,
const DistMetaTensor& out,
int start_axis,
@@ -198,5 +189,10 @@ SpmdInfo FlattenInferSpmdReverse(const DistMetaTensor& x,
return {{x_dist_attr}, {out_dist_attr_dst}};
}

SpmdInfo FlattenGradInferSpmd(const DistMetaTensor& xshape,
const DistMetaTensor& out_grad) {
return ReshapeGradInferSpmd(xshape, out_grad);
}

} // namespace distributed
} // namespace phi
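
For reference, the updated FlattenInferSpmd above no longer builds its own dimension transformation: it computes the reshape target shape and delegates the distributed-attribute inference to ReshapeInferSpmdDynamic. Below is a minimal standalone Python sketch of that target-shape computation, mirroring the dst_shape loop in the diff; it is illustrative only, not Paddle API code.

def flatten_target_shape(src_shape, start_axis, stop_axis):
    # Axes outside [start_axis, stop_axis] are kept as-is; axes inside are
    # merged into one dimension whose size is their product.
    ndim = len(src_shape)
    start_axis = start_axis + ndim if start_axis < 0 else start_axis  # like PreprocessAxis
    stop_axis = stop_axis + ndim if stop_axis < 0 else stop_axis
    dst_shape, flatten_size = [], 1
    for i, d in enumerate(src_shape):
        if i < start_axis or i > stop_axis:
            dst_shape.append(d)
        else:
            flatten_size *= d
            if i == stop_axis:
                dst_shape.append(flatten_size)
    return dst_shape

# flatten_target_shape([8, 16, 8, 24], 1, 2) -> [8, 128, 24]; the SPMD attributes
# for that shape are then inferred by ReshapeInferSpmdDynamic(x, dst_shape).
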
3 changes: 3 additions & 0 deletions paddle/phi/infermeta/spmd_rules/flatten.h
@@ -30,5 +30,8 @@ SpmdInfo FlattenInferSpmdReverse(const DistMetaTensor& x,
const DistMetaTensor& out,
int start_axis,
int stop_axis);

SpmdInfo FlattenGradInferSpmd(const DistMetaTensor& xshape,
const DistMetaTensor& out_grad);
} // namespace distributed
} // namespace phi
1 change: 1 addition & 0 deletions paddle/phi/ops/yaml/backward.yaml
@@ -1067,6 +1067,7 @@
infer_meta :
func : KernelWithXShapeInferMeta
param : [xshape, out_grad]
spmd_rule : FlattenGradInferSpmd
kernel :
func : flatten_grad
data_type : out_grad
1 change: 1 addition & 0 deletions paddle/phi/ops/yaml/ops.yaml
@@ -1672,6 +1672,7 @@
output : Tensor(out), Tensor(xshape)
infer_meta :
func : FlattenWithXShapeInferMeta
spmd_rule : FlattenInferSpmd
kernel :
func : flatten
data_type : x
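
With the spmd_rule entries registered in ops.yaml and backward.yaml above, the rules take effect when flatten runs under semi-automatic parallelism. A hedged usage sketch follows; it assumes the paddle.distributed semi-auto API of recent Paddle releases and an illustrative 2x2 mesh, and would be launched with paddle.distributed.launch across 4 ranks.

import paddle
import paddle.distributed as dist

# 2 x 2 process mesh; "x" acts as the data-parallel axis in this illustration
mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])

# shard dim 0 of the input along mesh axis "x", replicate along "y"
a = paddle.rand([8, 16, 8, 24])
a = dist.shard_tensor(a, mesh, [dist.Shard(0), dist.Replicate()])

# flatten dims 1..2: [8, 16, 8, 24] -> [8, 128, 24]; FlattenInferSpmd should keep
# the dim-0 shard on the output, and FlattenGradInferSpmd propagates it back
# through flatten_grad during the backward pass
out = paddle.flatten(a, start_axis=1, stop_axis=2)
out.sum().backward()
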
47 changes: 37 additions & 10 deletions test/auto_parallel/spmd_rules/test_flatten_rule.py
@@ -38,7 +38,7 @@ def setUp(self):

def test_flatten_infer_forward(self):
# shape: [8, 16, 8, 24] --> [8, 16 * 8, 24]
# dims_mapping: [0, -1, -1, 1] --> [0, -1, -1, 1] [ 0, -1, 1]
# dims_mapping: [0, -1, -1, 1] --> [0, -1, -1, 1], ([0, -1, 1], [-1, 0, -1, -1, 1] // xshape)
self.x_dist_tensor_spec.set_dims_mapping([0, -1, -1, 1])
self.attrs['start_axis'] = 1
self.attrs['stop_axis'] = 2
@@ -51,14 +51,17 @@ def test_flatten_infer_forward(self):
infered_output_dist_attrs = result_dist_attrs[1]

self.assertEqual(len(infered_input_dist_attrs), 1)
self.assertEqual(len(infered_output_dist_attrs), 1)
self.assertEqual(len(infered_output_dist_attrs), 2)
self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [0, -1, -1, 1]
)
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, 1])
self.assertEqual(
infered_output_dist_attrs[1].dims_mapping, [-1, 0, -1, -1, 1]
)

# shape: [8, 16, 8, 24] --> [8, 16 * 8, 24]
# dims_mapping: [-1, 0, -1, 1] --> [-1, 0, -1, 1] [ -1, 0, 1]
# dims_mapping: [-1, 0, -1, 1] --> [-1, 0, -1, 1] ([ -1, 0, 1], [-1, -1, 0, -1, 1] // xshape)
self.x_dist_tensor_spec.set_dims_mapping([-1, 0, -1, 1])
self.attrs['start_axis'] = 1
self.attrs['stop_axis'] = 2
@@ -74,9 +77,12 @@ def test_flatten_infer_forward(self):
infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, 1]
)
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0, 1])
self.assertEqual(
infered_output_dist_attrs[1].dims_mapping, [-1, -1, 0, -1, 1]
)

# shape: [8, 16, 8, 24] --> [8, 16 * 8, 24]
# dims_mapping: [-1, -1, 1, 0] --> [-1, -1, -1, 0] [ -1, -1, 0]
# dims_mapping: [-1, -1, 1, 0] --> [-1, -1, -1, 0] ([ -1, -1, 0], [-1, -1, -1, -1, 0] // xshape)
self.x_dist_tensor_spec.set_dims_mapping([-1, -1, 1, 0])
self.attrs['start_axis'] = 1
self.attrs['stop_axis'] = 2
@@ -92,9 +98,12 @@ def test_flatten_infer_forward(self):
infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, 0]
)
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1, 0])
self.assertEqual(
infered_output_dist_attrs[1].dims_mapping, [-1, -1, -1, -1, 0]
)

# shape: [8, 16, 8, 24] --> [8 * 16 * 8 * 24]
# dims_mapping: [-1, 0, 1, -1] --> [-1, -1, -1, -1] [ -1]
# dims_mapping: [-1, 0, 1, -1] --> [-1, -1, -1, -1] ([ -1], [-1, -1, -1, -1, -1] // xshape)
self.x_dist_tensor_spec.set_dims_mapping([-1, 0, 1, -1])
self.attrs['start_axis'] = 0
self.attrs['stop_axis'] = -1
@@ -110,9 +119,12 @@ def test_flatten_infer_forward(self):
infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, -1]
)
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1])
self.assertEqual(
infered_output_dist_attrs[1].dims_mapping, [-1, -1, -1, -1, -1]
)

# shape: [8, 16, 8, 24] --> [8 * 16 * 8 * 24]
# dims_mapping: [0, -1, -1, 1] --> [0, -1, -1, -1] [ 0]
# dims_mapping: [0, -1, -1, 1] --> [0, -1, -1, -1] ([ 0], [-1, 0, -1, -1, -1] // xshape)
self.x_dist_tensor_spec.set_dims_mapping([0, -1, -1, 1])
self.attrs['start_axis'] = 0
self.attrs['stop_axis'] = -1
@@ -128,9 +140,12 @@ def test_flatten_infer_forward(self):
infered_input_dist_attrs[0].dims_mapping, [0, -1, -1, -1]
)
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0])
self.assertEqual(
infered_output_dist_attrs[1].dims_mapping, [-1, 0, -1, -1, -1]
)

# shape: [8, 16, 8, 24] --> [8 * 16 * 8 * 24]
# dims_mapping: [1, 0, -1, -1] --> [1, -1, -1, -1] [ 1]
# dims_mapping: [1, 0, -1, -1] --> [1, -1, -1, -1] ([ 1], [-1, 1, -1, -1, -1] // xshape)
self.x_dist_tensor_spec.set_dims_mapping([1, 0, -1, -1])
self.attrs['start_axis'] = 0
self.attrs['stop_axis'] = -1
@@ -146,9 +161,12 @@ def test_flatten_infer_forward(self):
infered_input_dist_attrs[0].dims_mapping, [1, -1, -1, -1]
)
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1])
self.assertEqual(
infered_output_dist_attrs[1].dims_mapping, [-1, 1, -1, -1, -1]
)

# shape: [8, 16, 8, 24] --> [8, 16 * 8 * 24]
# dims_mapping: [-1, -1, 0, 1] --> [-1, -1, -1, -1] [-1, -1]
# dims_mapping: [-1, -1, 0, 1] --> [-1, -1, -1, -1] ([-1, -1], [-1, -1, -1, -1, -1] // xshape)
self.x_dist_tensor_spec.set_dims_mapping([-1, -1, 0, 1])
self.attrs['start_axis'] = 1
self.attrs['stop_axis'] = -1
@@ -164,9 +182,12 @@ def test_flatten_infer_forward(self):
infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, -1]
)
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1])
self.assertEqual(
infered_output_dist_attrs[1].dims_mapping, [-1, -1, -1, -1, -1]
)

# shape: [8, 16, 8, 24] --> [8, 16 * 8 * 24]
# dims_mapping: [-1, 0, -1, 1] --> [-1, 0, -1, -1] [-1, 0]
# dims_mapping: [-1, 0, -1, 1] --> [-1, 0, -1, -1] ([-1, 0], [-1, -1, 0, -1, -1] // xshape)
self.x_dist_tensor_spec.set_dims_mapping([-1, 0, -1, 1])
self.attrs['start_axis'] = 1
self.attrs['stop_axis'] = -1
@@ -182,9 +203,12 @@ def test_flatten_infer_forward(self):
infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, -1]
)
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0])
self.assertEqual(
infered_output_dist_attrs[1].dims_mapping, [-1, -1, 0, -1, -1]
)

# shape: [8, 16, 8, 24] --> [8, 16 * 8 * 24]
# dims_mapping: [0, 1, -1, -1] --> [0, 1, -1, -1] [0, 1]
# dims_mapping: [0, 1, -1, -1] --> [0, 1, -1, -1] ([0, 1], [-1, 0, 1, -1, -1] // xshape)
self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1, -1])
self.attrs['start_axis'] = 1
self.attrs['stop_axis'] = -1
@@ -200,6 +224,9 @@ def test_flatten_infer_forward(self):
infered_input_dist_attrs[0].dims_mapping, [0, 1, -1, -1]
)
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1])
self.assertEqual(
infered_output_dist_attrs[1].dims_mapping, [-1, 0, 1, -1, -1]
)

def test_flatten_infer_backward(self):
process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
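
The dims_mapping expectations asserted above (and in the C++ cases below) follow one pattern: within the flattened range only the leading axis may keep its shard, every other merged axis is reset to -1, and the xshape output prepends a -1 to the input's destination mapping. The standalone sketch below illustrates that pattern under this assumption; it is not a reimplementation of the reshape dim-trans logic, which is more general.

def expected_flatten_dims_mapping(x_dims_mapping, start_axis, stop_axis):
    # Returns (x_dst, out, xshape) mappings for the cases exercised above.
    ndim = len(x_dims_mapping)
    start_axis = start_axis + ndim if start_axis < 0 else start_axis
    stop_axis = stop_axis + ndim if stop_axis < 0 else stop_axis
    x_dst, out = [], []
    for i, m in enumerate(x_dims_mapping):
        if i < start_axis or i > stop_axis:
            x_dst.append(m)      # axes outside the range are untouched
            out.append(m)
        elif i == start_axis:
            x_dst.append(m)      # leading merged axis keeps its shard
            out.append(m)
        else:
            x_dst.append(-1)     # shards on the other merged axes are dropped
    xshape = [-1] + x_dst        # xshape: leading -1 plus the input dst mapping
    return x_dst, out, xshape

# expected_flatten_dims_mapping([0, -1, -1, 1], 1, 2)
#   -> ([0, -1, -1, 1], [0, -1, 1], [-1, 0, -1, -1, 1]), matching the first case above.
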
64 changes: 64 additions & 0 deletions test/cpp/auto_parallel/spmd_rule_test.cc
@@ -1853,6 +1853,70 @@ TEST(CumSumGradInferSpmd, Ctor) {
std::vector<int64_t>({-1, -1, -1}));
}

TEST(Flatten, Ctor) {
std::vector<int64_t> mesh_shape = {2, 2};
std::vector<int64_t> process_ids = {0, 1, 2, 3};
std::vector<std::string> dim_names = {"x", "y"};
ProcessMesh process_mesh(mesh_shape, process_ids, dim_names);

auto build_input = [&](const std::vector<int64_t>& shape,
const std::vector<int64_t>& dim_mapping) {
auto t_dist_attr = TensorDistAttr();
t_dist_attr.set_process_mesh(process_mesh);
t_dist_attr.set_dims_mapping(dim_mapping);
t_dist_attr.set_dynamic_dims(std::vector<bool>(shape.size(), false));
auto input =
phi::distributed::DistMetaTensor(common::make_ddim(shape), t_dist_attr);
return input;
};

// [b, h/ph, w/pw, c, ph, pw]; dp
auto input1 = build_input({4, 16, 16, 4, 2, 2}, {0, -1, -1, -1, -1, -1});
// [b, h/ph, w/pw, c, ph, pw] => [b, h/ph, w/pw, hidden_size]
auto spmd1 = FlattenInferSpmd(input1, -3, -1);
EXPECT_EQ(spmd1.first.size(), static_cast<size_t>(1));
EXPECT_EQ(spmd1.second.size(), static_cast<size_t>(2));
check_dim_mapping(spmd1.first[0], {0, -1, -1, -1, -1, -1});
check_dim_mapping(spmd1.second[0], {0, -1, -1, -1});
check_dim_mapping(spmd1.second[1], {-1, 0, -1, -1, -1, -1, -1}); // x_shape

// [b, h/ph, w/pw, c, ph, pw]; dp, mp
auto input2 = build_input({4, 16, 16, 4, 2, 2}, {-1, 0, -1, 1, -1, -1});
auto spmd2 = FlattenInferSpmd(input2, 1, 4);
EXPECT_EQ(spmd2.first.size(), static_cast<size_t>(1));
EXPECT_EQ(spmd2.second.size(), static_cast<size_t>(2));
check_dim_mapping(spmd2.first[0], {-1, 0, -1, -1, -1, -1});
check_dim_mapping(spmd2.second[0], {-1, 0, -1});
check_dim_mapping(spmd2.second[1], {-1, -1, 0, -1, -1, -1, -1}); // x_shape

// [b, s, nh, h/nh]; dp , mp
auto input3 = build_input({2, 1024, 32, 32}, {0, -1, 1, -1});
// [b, s, nh, h/nh] => [b, s, h]
auto spmd3 = FlattenInferSpmd(input3, 2, 3);
EXPECT_EQ(spmd3.first.size(), static_cast<size_t>(1));
EXPECT_EQ(spmd3.second.size(), static_cast<size_t>(2));
check_dim_mapping(spmd3.first[0], {0, -1, 1, -1});
check_dim_mapping(spmd3.second[0], {0, -1, 1});
check_dim_mapping(spmd3.second[1], {-1, 0, -1, 1, -1}); // x_shape

// [b, c, d, h, w]; dp, mp
auto input4 = build_input({4, 16, 16, 4, 16}, {-1, -1, 0, 1, -1});
auto spmd4 = FlattenInferSpmd(input4, 1, 4);
EXPECT_EQ(spmd4.first.size(), static_cast<size_t>(1));
EXPECT_EQ(spmd4.second.size(), static_cast<size_t>(2));
check_dim_mapping(spmd4.first[0], {-1, -1, -1, -1, -1});
check_dim_mapping(spmd4.second[0], {-1, -1});
check_dim_mapping(spmd4.second[1], {-1, -1, -1, -1, -1, -1}); // x_shape

auto out_grad = build_input({2, 1024, 1024}, {-1, -1, -1});
auto xshape = build_input({0, 2, 1024, 4, 1024 / 4}, {-1, 0, 1, -1, -1});
auto spmd_grad = FlattenGradInferSpmd(xshape, out_grad);
EXPECT_EQ(spmd_grad.first.size(), static_cast<size_t>(1));
EXPECT_EQ(spmd_grad.second.size(), static_cast<size_t>(1));
check_dim_mapping(spmd_grad.first[0], {0, 1, -1});
check_dim_mapping(spmd_grad.second[0], {0, 1, -1, -1});
}

} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
