Skip to content

Commit

Permalink
fix spmd rule of flatten_grad, add xshape dist_attr in result (#64985)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeff41404 authored Jun 7, 2024
1 parent 60fe41b commit ed8168d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
8 changes: 7 additions & 1 deletion paddle/phi/infermeta/spmd_rules/flatten.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,13 @@ SpmdInfo FlattenInferSpmdReverse(const DistMetaTensor& x,

SpmdInfo FlattenGradInferSpmd(const DistMetaTensor& xshape,
                              const DistMetaTensor& out_grad) {
  // Infers the SPMD (distributed) attributes for flatten's backward pass:
  // given xshape (recorded forward-input shape) and out_grad, produce the
  // dist_attrs for the inputs {xshape, out_grad} and the output {x_grad}.
  //
  // TODO(jeff41404): when ReshapeInferSpmd and ReshapeGradInferSpmd can deliver
  // distributed attribute of xshape, we will use ReshapeGradInferSpmd directly
  // in future: return ReshapeGradInferSpmd(xshape, out_grad);
  //
  // xshape stores the forward input's shape prefixed with a leading 0
  // (e.g. {0, 2, 1024, ...} in the unit test), so drop the first element
  // to recover the original shape of x.
  auto shape = phi::vectorize(xshape.dims());
  shape = std::vector<int64_t>(shape.begin() + 1, shape.end());
  // Reshaping out_grad back to x's original shape gives x_grad's dist_attr.
  const auto& spmd = ReshapeInferSpmd(out_grad, shape);
  // first:  dist_attrs for inputs {xshape, out_grad}
  // second: dist_attr for output {x_grad}
  return {{xshape.dist_attr(), spmd.first[0]}, {spmd.second[0]}};
}

} // namespace distributed
Expand Down
11 changes: 6 additions & 5 deletions test/cpp/auto_parallel/spmd_rule_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1908,13 +1908,14 @@ TEST(Flatten, Ctor) {
check_dim_mapping(spmd4.second[0], {-1, -1});
check_dim_mapping(spmd4.second[1], {-1, -1, -1, -1, -1, -1}); // x_shape

auto out_grad = build_input({2, 1024, 1024}, {-1, -1, -1});
auto xshape = build_input({0, 2, 1024, 4, 1024 / 4}, {-1, 0, 1, -1, -1});
auto out_grad = build_input({2, 1024, 1024}, {0, -1, 1});
auto xshape = build_input({0, 2, 1024, 4, 1024 / 4}, {-1, 0, -1, 1, -1});
auto spmd_grad = FlattenGradInferSpmd(xshape, out_grad);
EXPECT_EQ(spmd_grad.first.size(), static_cast<size_t>(1));
EXPECT_EQ(spmd_grad.first.size(), static_cast<size_t>(2));
EXPECT_EQ(spmd_grad.second.size(), static_cast<size_t>(1));
check_dim_mapping(spmd_grad.first[0], {0, 1, -1});
check_dim_mapping(spmd_grad.second[0], {0, 1, -1, -1});
check_dim_mapping(spmd_grad.first[0], {-1, 0, -1, 1, -1});
check_dim_mapping(spmd_grad.first[1], {0, -1, 1});
check_dim_mapping(spmd_grad.second[0], {0, -1, 1, -1});
}

} // namespace auto_parallel
Expand Down

0 comments on commit ed8168d

Please sign in to comment.