69 changes: 53 additions & 16 deletions paddle/phi/infermeta/spmd_rules/reshape.cc
@@ -50,7 +50,7 @@ std::vector<int64_t> InferTargetShape(const std::vector<int64_t>& shape,
PADDLE_ENFORCE_EQ(
product,
len,
phi::errors::InvalidArgument("The total size are not matched"));
phi::errors::InvalidArgument("The total size are not matched."));
return std::vector<int64_t>(shape);
} else {
std::vector<int64_t> new_shape(shape);
@@ -59,7 +59,7 @@ std::vector<int64_t> InferTargetShape(const std::vector<int64_t>& shape,
PADDLE_ENFORCE_EQ(len % infer_size,
0,
phi::errors::InvalidArgument(
"The total is not diviable by infer_size"));
"The total is not diviable by infer_size."));
new_shape[infer_idx] = infer_size;
return new_shape;
}
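Both messages above come from InferTargetShape, which resolves the optional '-1' entry in a target shape against the tensor's total element count. For orientation, here is a minimal Python sketch of that logic; the name mirrors the C++ helper, but this is a hedged paraphrase rather than the shipped implementation:

```python
import math

def infer_target_shape(shape, total):
    """Resolve at most one '-1' entry in `shape` so that prod(shape) == total."""
    if -1 not in shape:
        if math.prod(shape) != total:
            raise ValueError("The total size does not match.")
        return list(shape)
    infer_idx = shape.index(-1)  # reshape allows a single '-1'
    known = math.prod(d for d in shape if d != -1)
    if total % known != 0:
        raise ValueError("The total size is not divisible by the known dims.")
    resolved = list(shape)
    resolved[infer_idx] = total // known
    return resolved

# 24 elements reshaped to [3, -1, 4] infers the middle dim as 2
assert infer_target_shape([3, -1, 4], 24) == [3, 2, 4]
```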
@@ -143,8 +143,11 @@ std::vector<DimTrans*> MakeReshapeDimTrans(
SpmdInfo ReshapeInferSpmd(const DistMetaTensor& x,
const std::vector<int64_t>& shape) {
// Step0: Verify input args based on reshape logic
- auto src_shape = phi::vectorize(x.dims());
- int x_ndim = src_shape.size();
+ VLOG(2) << "Debug Info for reshape";
+ VLOG(2) << "shape: " << str_join(shape);
+ auto x_shape = phi::vectorize(x.dims());
+ int x_ndim = x_shape.size();
+ int out_ndim = shape.size();
auto x_dist_attr_src = x.dist_attr();
std::vector<int64_t> x_dims_mapping = x_dist_attr_src.dims_mapping();
PADDLE_ENFORCE_EQ(
@@ -154,20 +157,31 @@ SpmdInfo ReshapeInferSpmd(const DistMetaTensor& x,
"dims_mapping size [%d] are not matched.",
x_ndim,
x_dims_mapping.size()));
VLOG(4) << "ReshapeInferSpmd: X shape: [" << str_join(x_shape) << "]";
VLOG(4) << "Out shape: [" << str_join(shape) << "]";

// Step1: Build the transformation from
// the original shape to the target shape

+ // handle the case of a dynamic source shape, like [-1, -1, ...] --> [0, 0, ...].
+ // This occurs in inference, but reshape allows only one '-1' in the
+ // target shape, so replace each '-1' source dim with the placeholder value '256'.
+ for (int i = 0; i < x_ndim; i++) {
+   if (x_shape[i] == -1) {
+     x_shape[i] = 256;
+   }
+ }

// handle the '0' values in the target shape: a '0' means that
// dimension keeps the size of the corresponding source dimension
std::vector<int64_t> tgt_shape(shape);
- for (int64_t i = 0, n = static_cast<int64_t>(tgt_shape.size()); i < n; i++) {
+ for (int64_t i = 0; i < out_ndim; i++) {
if (tgt_shape[i] == 0) {
-     tgt_shape[i] = src_shape[i];
+     tgt_shape[i] = x_shape[i];
}
}

std::vector<DimTrans*> trans = MakeReshapeDimTrans(src_shape, tgt_shape);
std::vector<DimTrans*> trans = MakeReshapeDimTrans(x_shape, tgt_shape);

// Step2: Infer the dims mapping of input (if reshard is
// needed) and output from the dimension transformation.
@@ -181,17 +195,14 @@ SpmdInfo ReshapeInferSpmd(const DistMetaTensor& x,
TensorDistAttr out_dist_attr(x_dist_attr_src);
out_dist_attr.set_dims_mapping(dims_mapping_vec[1]);

VLOG(4) << "ReshapeInferSpmd: X shape: [" << str_join(src_shape)
<< "] Out shape: [" << str_join(tgt_shape) << "]";
VLOG(4) << "Transformation from input to output:";
for (int64_t i = 0, n = static_cast<int64_t>(trans.size()); i < n; i++) {
DimTrans* t = trans[i];
VLOG(4) << "\tOut axis[" << i << "]: " << t->to_string();
}
VLOG(4) << "X dims_mapping_src: [" << str_join(x_dims_mapping)
<< "] dims_mapping_dst: [" << str_join(dims_mapping_vec[0])
<< "]\n Out dims_mapping: [" << str_join(dims_mapping_vec[1])
<< "]\n\n";
<< "] dims_mapping_dst: [" << str_join(dims_mapping_vec[0]) << "]";
VLOG(4) << "Out dims_mapping: [" << str_join(dims_mapping_vec[1]) << "]\n\n";

CleanUp();
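The normalization added in this function runs before MakeReshapeDimTrans builds the axis transformation, and it handles two reshape conventions: a '-1' in the source shape marks a dynamic dim and is given the placeholder size 256, while a '0' in the target shape means "inherit the corresponding source dim". A small Python sketch of just that step, under the same assumptions (the placeholder constant comes from the diff; everything else is illustrative):

```python
PLACEHOLDER = 256  # stand-in size for dynamic ('-1') source dims, as in the diff

def normalize_shapes(x_shape, tgt_shape):
    """Concretize '-1' source dims and expand '0' ('same as source') target dims."""
    x_shape = [PLACEHOLDER if d == -1 else d for d in x_shape]
    tgt_shape = [x_shape[i] if d == 0 else d for i, d in enumerate(tgt_shape)]
    return x_shape, tgt_shape

# a dynamic [-1, -1, 3072] reshaped with target [0, 0, -1, 192]
x, tgt = normalize_shapes([-1, -1, 3072], [0, 0, -1, 192])
assert x == [256, 256, 3072] and tgt == [256, 256, -1, 192]
```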

@@ -201,9 +212,12 @@ SpmdInfo ReshapeInferSpmd(const DistMetaTensor& x,
SpmdInfo ReshapeInferSpmdReverse(const DistMetaTensor& x,
const DistMetaTensor& out,
const std::vector<int64_t>& shape) {
VLOG(2) << "Debug Info for reshape_reverse";
VLOG(2) << "shape: " << str_join(shape);
// Step0: Verify input args based on reshape logic
auto x_shape = phi::vectorize(x.dims());
auto out_shape = phi::vectorize(out.dims());
+ int x_ndim = x_shape.size();
int out_ndim = out_shape.size();
auto out_dist_attr_src = out.dist_attr();
std::vector<int64_t> out_dims_mapping = out_dist_attr_src.dims_mapping();
@@ -214,14 +228,39 @@ SpmdInfo ReshapeInferSpmdReverse(const DistMetaTensor& x,
"dims_mapping size [%d] are not matched.",
out_ndim,
out_dims_mapping.size()));
VLOG(4) << "ReshapeInferSpmdReverse: Out shape: [" << str_join(out_shape)
<< "], X shape: [" << str_join(x_shape) << "]";

// Step1: Build the transformation from the output shape
// to original shape. This function infers the dims mapping
// from output to input, we first get the transformation
// from output to input so that we can infer the dims mapping
// with the map from output axes to input axes.
- // Shapes in InferSpmdReverse don't contain -1 or 0, so they will
- // not be modified and we can directly use them.

+ // handle the case of a dynamic source shape, like [-1, -1, ...] --> [0, 0, ...].
+ // This occurs in inference, but reshape allows only one '-1' in the
+ // target shape, so replace each '-1' source dim with the placeholder value '256'.
+ for (int i = 0; i < x_ndim; i++) {
+   if (x_shape[i] == -1) {
+     x_shape[i] = 256;
+   }
+ }
+
+ // handle the '0' values in the target shape: a '0' means that
+ // dimension keeps the size of the corresponding source dimension
+ for (int64_t i = 0; i < out_ndim; i++) {
+   if (shape[i] == 0) {
+     out_shape[i] = x_shape[i];
+   }
+ }

+ // The out_shape may contain '-1', which would cause an error
+ // when inferring the transformation from out_shape to x_shape,
+ // so resolve the '-1' value before inferring the DimTrans
+ int64_t nelm = std::accumulate(x_shape.begin(),
+                                x_shape.end(),
+                                static_cast<int64_t>(1),
+                                std::multiplies<int64_t>());
+ out_shape = InferTargetShape(out_shape, nelm);
std::vector<DimTrans*> trans = MakeReshapeDimTrans(out_shape, x_shape);

// Step2: Infer the dims mapping of input with
@@ -236,8 +275,6 @@ SpmdInfo ReshapeInferSpmdReverse(const DistMetaTensor& x,
TensorDistAttr x_dist_attr(x.dist_attr());
x_dist_attr.set_dims_mapping(dims_mapping_vec[1]);

VLOG(4) << "ReshapeInferSpmdReverse: Out shape: [" << str_join(out_shape)
<< "] X shape: [" << str_join(x_shape) << "]";
VLOG(4) << "Transformation from output to input:";
for (int64_t i = 0, n = trans.size(); i < n; i++) {
DimTrans* t = trans[i];
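The reverse rule applies the same normalization in the out-to-x direction and must additionally make out_shape fully concrete, resolving any remaining '-1' against x's element count, before MakeReshapeDimTrans can build the out-to-x transformation. A self-contained, hedged sketch of that ordering (same placeholder assumption as above; not the Paddle API):

```python
import math

def reverse_preprocess(x_shape, out_shape, shape_attr, placeholder=256):
    """Mimic the reverse rule's shape normalization before building DimTrans."""
    # dynamic source dims get the placeholder size
    x_shape = [placeholder if d == -1 else d for d in x_shape]
    # '0' in the shape attribute means: inherit the corresponding source dim
    out_shape = [x_shape[i] if shape_attr[i] == 0 else d
                 for i, d in enumerate(out_shape)]
    # resolve a remaining '-1' against the total element count of x
    nelem = math.prod(x_shape)
    known = math.prod(d for d in out_shape if d != -1)
    out_shape = [nelem // known if d == -1 else d for d in out_shape]
    return x_shape, out_shape

x, out = reverse_preprocess([-1, -1, 3072], [0, 0, -1, 192], [0, 0, -1, 192])
assert x == [256, 256, 3072] and out == [256, 256, 16, 192]
```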
105 changes: 105 additions & 0 deletions test/auto_parallel/spmd_rules/test_reshape_rule.py
@@ -243,6 +243,54 @@ def test_reshape_infer_forward(self):
infered_output_dist_attrs[0].dims_mapping, [0, 1, -1, -1]
)

+ # shape: [-1, -1, 3072] --> [0, 0, -1, 192]
+ # dims_mapping: [0, 1, -1] --> [0, 1, -1], [0, 1, -1, -1]
+ self.x_dist_tensor_spec.shape = [-1, -1, 3072]
+ self.attrs["shape"] = [0, 0, -1, 192]
+ self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1])
+ result_dist_attrs = self.rule.infer_forward(
+     self.x_dist_tensor_spec, self.attrs['shape']
+ )
+ infered_input_dist_attrs = result_dist_attrs[0]
+ infered_output_dist_attrs = result_dist_attrs[1]
+
+ self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1, -1])
+ self.assertEqual(
+     infered_output_dist_attrs[0].dims_mapping, [0, 1, -1, -1]
+ )
+
+ # shape: [-1, -1, 3072] --> [0, 0, -1, 192]
+ # dims_mapping: [0, -1, 1] --> [0, -1, -1], [0, -1, -1, -1]
+ self.x_dist_tensor_spec.shape = [-1, -1, 3072]
+ self.attrs["shape"] = [0, 0, -1, 192]
+ self.x_dist_tensor_spec.set_dims_mapping([0, -1, 1])
+ result_dist_attrs = self.rule.infer_forward(
+     self.x_dist_tensor_spec, self.attrs['shape']
+ )
+ infered_input_dist_attrs = result_dist_attrs[0]
+ infered_output_dist_attrs = result_dist_attrs[1]
+
+ self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1, -1])
+ self.assertEqual(
+     infered_output_dist_attrs[0].dims_mapping, [0, -1, -1, -1]
+ )
+
+ # shape: [-1, -1, 3072] --> [0, 0, -1, 192]
+ # dims_mapping: [1, -1, 0] --> [1, -1, 0], [1, -1, 0, -1]
+ self.x_dist_tensor_spec.shape = [-1, -1, 3072]
+ self.attrs["shape"] = [0, 0, -1, 192]
+ self.x_dist_tensor_spec.set_dims_mapping([1, -1, 0])
+ result_dist_attrs = self.rule.infer_forward(
+     self.x_dist_tensor_spec, self.attrs['shape']
+ )
+ infered_input_dist_attrs = result_dist_attrs[0]
+ infered_output_dist_attrs = result_dist_attrs[1]
+
+ self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1, 0])
+ self.assertEqual(
+     infered_output_dist_attrs[0].dims_mapping, [1, -1, 0, -1]
+ )
+
# shape: [6, 12, 48, 24] --> [3, 24, 6, -1, -1]
# raise error
self.attrs["shape"] = [3, 24, 6, -1, -1]
@@ -454,6 +502,63 @@ def test_reshape_infer_backward(self):
infered_output_dist_attrs[0].dims_mapping, [-1, -1, -1, -1, 0]
)

+ # shape: [8, 1024, 3072] --> [0, 0, -1, 192] (input --> output)
+ # dims_mapping: [0, 1, -1, -1] --> [0, 1, -1], [0, 1, -1, -1] (output --> input, output)
+ self.x_dist_tensor_spec.shape = [8, 1024, 3072]
+ self.output_dist_tensor_spec.shape = [0, 0, -1, 192]
+ self.attrs["shape"] = [0, 0, -1, 192]
+ self.output_dist_tensor_spec.set_dims_mapping([0, 1, -1, -1])
+ result_dist_attrs = self.rule.infer_backward(
+     self.x_dist_tensor_spec,
+     self.output_dist_tensor_spec,
+     self.attrs['shape'],
+ )
+ infered_input_dist_attrs = result_dist_attrs[0]
+ infered_output_dist_attrs = result_dist_attrs[1]
+
+ self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1, -1])
+ self.assertEqual(
+     infered_output_dist_attrs[0].dims_mapping, [0, 1, -1, -1]
+ )
+
+ # shape: [-1, -1, 3072] --> [0, 0, -1, 192] (input --> output)
+ # dims_mapping: [0, 1, -1, -1] --> [0, 1, -1], [0, 1, -1, -1] (output --> input, output)
+ self.x_dist_tensor_spec.shape = [-1, -1, 3072]
+ self.output_dist_tensor_spec.shape = [0, 0, -1, 192]
+ self.attrs["shape"] = [0, 0, -1, 192]
+ self.output_dist_tensor_spec.set_dims_mapping([0, 1, -1, -1])
+ result_dist_attrs = self.rule.infer_backward(
+     self.x_dist_tensor_spec,
+     self.output_dist_tensor_spec,
+     self.attrs['shape'],
+ )
+ infered_input_dist_attrs = result_dist_attrs[0]
+ infered_output_dist_attrs = result_dist_attrs[1]
+
+ self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1, -1])
+ self.assertEqual(
+     infered_output_dist_attrs[0].dims_mapping, [0, 1, -1, -1]
+ )
+
+ # shape: [-1, -1, 3072] --> [0, 0, -1, 192] (input --> output)
+ # dims_mapping: [0, -1, 1, -1] --> [0, -1, 1], [0, -1, 1, -1] (output --> input, output)
+ self.x_dist_tensor_spec.shape = [-1, -1, 3072]
+ self.output_dist_tensor_spec.shape = [0, 0, -1, 192]
+ self.attrs["shape"] = [0, 0, -1, 192]
+ self.output_dist_tensor_spec.set_dims_mapping([0, -1, 1, -1])
+ result_dist_attrs = self.rule.infer_backward(
+     self.x_dist_tensor_spec,
+     self.output_dist_tensor_spec,
+     self.attrs['shape'],
+ )
+ infered_input_dist_attrs = result_dist_attrs[0]
+ infered_output_dist_attrs = result_dist_attrs[1]
+
+ self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1, 1])
+ self.assertEqual(
+     infered_output_dist_attrs[0].dims_mapping, [0, -1, 1, -1]
+ )
+

if __name__ == "__main__":
unittest.main()
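One pattern worth calling out in the forward expectations: with source [-1, -1, 3072] and target [0, 0, -1, 192], the last source axis is split into (16, 192), and the tests keep the shard on that axis when it is bound to mesh dim 0 but drop it when bound to mesh dim 1. A plausible reading, assuming the tests' process mesh has shape [2, 3], is that a shard survives a forward split only on the leading factor, and only if the mesh dimension divides it evenly. A toy model of that reading (not Paddle code; the mesh shape is an assumption):

```python
def propagate_split(mesh_dim, factors, mesh_shape):
    """Toy model: a shard on an axis split into `factors` stays on the leading
    factor iff the mesh dimension evenly divides it; otherwise it is dropped."""
    out = [-1] * len(factors)
    if mesh_dim != -1 and factors[0] % mesh_shape[mesh_dim] == 0:
        out[0] = mesh_dim
    return out

# axis of size 3072 split into (16, 192); mesh shape assumed to be [2, 3]
assert propagate_split(1, (16, 192), [2, 3]) == [-1, -1]  # 16 % 3 != 0 -> dropped
assert propagate_split(0, (16, 192), [2, 3]) == [0, -1]   # 16 % 2 == 0 -> kept
```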