Skip to content

Commit

Permalink
Add unit test code
Browse files Browse the repository at this point in the history
  • Loading branch information
WintersMontagne10335 committed Oct 6, 2023
1 parent 9e9140f commit 6c2f23f
Show file tree
Hide file tree
Showing 3 changed files with 525 additions and 20 deletions.
5 changes: 1 addition & 4 deletions paddle/phi/infermeta/spmd_rules/unsqueeze.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/infermeta/spmd_rules/unsqueeze.h"
#include <algorithm>
#include <numeric>

#include "glog/logging.h"
Expand Down Expand Up @@ -59,9 +58,7 @@ SpmdInfo UnsqueezeInferSpmd(const DistMetaTensor& x,
}
}

std::sort(axis_copy.begin(), axis_copy.end(), UnsqueezeCmp);

for (int64_t i = static_cast<int64_t>(axis_copy.size()) - 1; i >= 0; i--) {
for (int64_t i = 0, n = static_cast<int64_t>(axis_copy.size()); i < n; i++) {
tgt_shape.emplace(tgt_shape.begin() + axis_copy[i], 1);
}

Expand Down
279 changes: 270 additions & 9 deletions test/auto_parallel/spmd_rules/test_squeeze_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,32 +27,293 @@ class TestSqueezeSPMDRule(unittest.TestCase):
def setUp(self):
self.rule = core.get_phi_spmd_rule("squeeze")

x_shape = [1, 4, 1, 16]
process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2], [3, 4, 5]])
x_shape = [1, 8, 1, 16]
process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])

x_tensor_dist_attr = TensorDistAttr()
x_tensor_dist_attr.dims_mapping = [-1, -1, -1, -1]
x_tensor_dist_attr.process_mesh = process_mesh
self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr)
self.attrs = OrderedDict()

def test_squeeze_infer_forward(self):
# shape: [1, 4, 1, 16] --> [1, 4, 16]
# dims_mapping: [0, 1, -1, -1] --> [0, 1, -1, -1] [0, 1, -1]
self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1, -1])
self.attrs = OrderedDict()
self.attrs['axis'] = [2]
# shape: [1, 8, 1, 16] --> [8, 16]
# dims_mapping: [-1, 0, -1, 1] --> [-1, 0, -1, 1] [0, 1]
self.x_dist_tensor_spec.set_dims_mapping([-1, 0, -1, 1])
self.attrs['axis'] = []
result_dist_attrs = self.rule.infer_forward(
self.x_dist_tensor_spec, self.attrs['axis']
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]

self.assertEqual(len(infered_input_dist_attrs), 1)
self.assertEqual(len(infered_output_dist_attrs), 1)
self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, 1]
)
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1])

# shape: [1, 8, 1, 16] --> [8, 16]
# dims_mapping: [-1, 0, -1, 1] --> [-1, 0, -1, 1] [0, 1]
self.x_dist_tensor_spec.set_dims_mapping([-1, 0, -1, 1])
self.attrs['axis'] = [0, 2]
result_dist_attrs = self.rule.infer_forward(
self.x_dist_tensor_spec, self.attrs['axis']
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]

self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, 1]
)
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1])

# shape: [1, 8, 1, 16] --> [1, 8, 16]
# dims_mapping: [-1, 0, -1, 1] --> [-1, 0, -1, 1] [-1, 0, 1]
self.x_dist_tensor_spec.set_dims_mapping([-1, 0, -1, 1])
self.attrs['axis'] = [1, 2]
result_dist_attrs = self.rule.infer_forward(
self.x_dist_tensor_spec, self.attrs['axis']
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]

self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, 1]
)
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0, 1])

# shape: [1, 8, 1, 16] --> [8, 1, 16]
# dims_mapping: [-1, 0, -1, 1] --> [-1, 0, -1, 1] [0, -1, 1]
self.x_dist_tensor_spec.set_dims_mapping([-1, 0, -1, 1])
self.attrs['axis'] = [-4]
result_dist_attrs = self.rule.infer_forward(
self.x_dist_tensor_spec, self.attrs['axis']
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]

self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, 1]
)
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, 1])

# shape: [1, 8, 1, 16] --> [8, 16]
# dims_mapping: [-1, 1, -1, 0] --> [-1, 1, -1, 0] [1, 0]
self.x_dist_tensor_spec.set_dims_mapping([-1, 1, -1, 0])
self.attrs['axis'] = []
result_dist_attrs = self.rule.infer_forward(
self.x_dist_tensor_spec, self.attrs['axis']
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]

self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1, 0]
)
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, 0])

# shape: [1, 8, 1, 16] --> [8, 16]
# dims_mapping: [-1, 1, -1, 0] --> [-1, 1, -1, 0] [1, 0]
self.x_dist_tensor_spec.set_dims_mapping([-1, 1, -1, 0])
self.attrs['axis'] = [0, 2]
result_dist_attrs = self.rule.infer_forward(
self.x_dist_tensor_spec, self.attrs['axis']
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]

self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1, 0]
)
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, 0])

# shape: [1, 8, 1, 16] --> [1, 8, 16]
# dims_mapping: [-1, 1, -1, 0] --> [-1, 1, -1, 0] [-1, 1, 0]
self.x_dist_tensor_spec.set_dims_mapping([-1, 1, -1, 0])
self.attrs['axis'] = [1, 2]
result_dist_attrs = self.rule.infer_forward(
self.x_dist_tensor_spec, self.attrs['axis']
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]

self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1, 0]
)
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 1, 0])

# shape: [1, 8, 1, 16] --> [8, 1, 16]
# dims_mapping: [-1, 1, -1, 0] --> [-1, 1, -1, 0] [1, -1, 0]
self.x_dist_tensor_spec.set_dims_mapping([-1, 1, -1, 0])
self.attrs['axis'] = [-4]
result_dist_attrs = self.rule.infer_forward(
self.x_dist_tensor_spec, self.attrs['axis']
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]

self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1, 0]
)
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1, 0])

def test_squeeze_infer_backward(self):
process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])

output_tensor_dist_attr = TensorDistAttr()
output_tensor_dist_attr.dims_mapping = [-1, -1]
output_tensor_dist_attr.process_mesh = process_mesh
self.output_dist_tensor_spec = DistTensorSpec(
[8, 16], output_tensor_dist_attr
)

# shape: [1, 8, 1, 16] --> [8, 16] (input --> output)
# dims_mapping: [0, 1] --> [-1, 0, -1, 1], [0, 1] (output --> input, output)
self.output_dist_tensor_spec.shape = [8, 16]
self.output_dist_tensor_spec.set_dims_mapping([0, 1])
self.attrs['axis'] = []
result_dist_attrs = self.rule.infer_backward(
self.x_dist_tensor_spec,
self.output_dist_tensor_spec,
self.attrs['axis'],
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]

self.assertEqual(len(infered_input_dist_attrs), 1)
self.assertEqual(len(infered_output_dist_attrs), 1)
self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [0, 1, -1, -1]
infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, 1]
)
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1])

# shape: [1, 8, 1, 16] --> [8, 16] (input --> output)
# dims_mapping: [0, 1] --> [-1, 0, -1, 1], [0, 1] (output --> input, output)
self.output_dist_tensor_spec.shape = [8, 16]
self.output_dist_tensor_spec.set_dims_mapping([0, 1])
self.attrs['axis'] = [0, 2]
result_dist_attrs = self.rule.infer_backward(
self.x_dist_tensor_spec,
self.output_dist_tensor_spec,
self.attrs['axis'],
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]

self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, 1]
)
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1])

# shape: [1, 8, 1, 16] --> [1, 8, 16] (input --> output)
# dims_mapping: [-1, 0, 1] --> [-1, 0, -1, 1], [-1, 0, 1] (output --> input, output)
self.output_dist_tensor_spec.shape = [1, 8, 16]
self.output_dist_tensor_spec.set_dims_mapping([-1, 0, 1])
self.attrs['axis'] = [1, 2]
result_dist_attrs = self.rule.infer_backward(
self.x_dist_tensor_spec,
self.output_dist_tensor_spec,
self.attrs['axis'],
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]

self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, 1]
)
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0, 1])

# shape: [1, 8, 1, 16] --> [8, 1, 16] (input --> output)
# dims_mapping: [0, -1, 1] --> [-1, 0, -1, 1], [0, -1, 1] (output --> input, output)
self.output_dist_tensor_spec.shape = [8, 1, 16]
self.output_dist_tensor_spec.set_dims_mapping([0, -1, 1])
self.attrs['axis'] = [-4]
result_dist_attrs = self.rule.infer_backward(
self.x_dist_tensor_spec,
self.output_dist_tensor_spec,
self.attrs['axis'],
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]

self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, 1]
)
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, 1])

# shape: [1, 8, 1, 16] --> [8, 16] (input --> output)
# dims_mapping: [1, 0] --> [-1, 1, -1, 0], [1, 0] (output --> input, output)
self.output_dist_tensor_spec.shape = [8, 16]
self.output_dist_tensor_spec.set_dims_mapping([1, 0])
self.attrs['axis'] = []
result_dist_attrs = self.rule.infer_backward(
self.x_dist_tensor_spec,
self.output_dist_tensor_spec,
self.attrs['axis'],
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]

self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1, 0]
)
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, 0])

# shape: [1, 8, 1, 16] --> [8, 16] (input --> output)
# dims_mapping: [1, 0] --> [-1, 1, -1, 0], [1, 0] (output --> input, output)
self.output_dist_tensor_spec.shape = [8, 16]
self.output_dist_tensor_spec.set_dims_mapping([1, 0])
self.attrs['axis'] = [0, 2]
result_dist_attrs = self.rule.infer_backward(
self.x_dist_tensor_spec,
self.output_dist_tensor_spec,
self.attrs['axis'],
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]

self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1, 0]
)
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, 0])

# shape: [1, 8, 1, 16] --> [1, 8, 16] (input --> output)
# dims_mapping: [-1, 1, 0] --> [-1, 1, -1, 0], [-1, 1, 0] (output --> input, output)
self.output_dist_tensor_spec.shape = [1, 8, 16]
self.output_dist_tensor_spec.set_dims_mapping([-1, 1, 0])
self.attrs['axis'] = [1, 2]
result_dist_attrs = self.rule.infer_backward(
self.x_dist_tensor_spec,
self.output_dist_tensor_spec,
self.attrs['axis'],
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]

self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1, 0]
)
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 1, 0])

# shape: [1, 8, 1, 16] --> [8, 1, 16] (input --> output)
# dims_mapping: [1, -1, 0] --> [-1, 1, -1, 0], [1, -1, 0] (output --> input, output)
self.output_dist_tensor_spec.shape = [8, 1, 16]
self.output_dist_tensor_spec.set_dims_mapping([1, -1, 0])
self.attrs['axis'] = [-4]
result_dist_attrs = self.rule.infer_backward(
self.x_dist_tensor_spec,
self.output_dist_tensor_spec,
self.attrs['axis'],
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]

self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1, 0]
)
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1, 0])


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 6c2f23f

Please sign in to comment.