diff --git a/paddle/phi/infermeta/spmd_rules/unsqueeze.cc b/paddle/phi/infermeta/spmd_rules/unsqueeze.cc index 463c121cdbb760..e75b09797732c8 100644 --- a/paddle/phi/infermeta/spmd_rules/unsqueeze.cc +++ b/paddle/phi/infermeta/spmd_rules/unsqueeze.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/phi/infermeta/spmd_rules/unsqueeze.h" -#include #include #include "glog/logging.h" @@ -59,9 +58,7 @@ SpmdInfo UnsqueezeInferSpmd(const DistMetaTensor& x, } } - std::sort(axis_copy.begin(), axis_copy.end(), UnsqueezeCmp); - - for (int64_t i = static_cast(axis_copy.size()) - 1; i >= 0; i--) { + for (int64_t i = 0, n = static_cast(axis_copy.size()); i < n; i++) { tgt_shape.emplace(tgt_shape.begin() + axis_copy[i], 1); } diff --git a/test/auto_parallel/spmd_rules/test_squeeze_rule.py b/test/auto_parallel/spmd_rules/test_squeeze_rule.py index 8e74316d220b6d..d19e6a086a0ee8 100644 --- a/test/auto_parallel/spmd_rules/test_squeeze_rule.py +++ b/test/auto_parallel/spmd_rules/test_squeeze_rule.py @@ -27,32 +27,293 @@ class TestSqueezeSPMDRule(unittest.TestCase): def setUp(self): self.rule = core.get_phi_spmd_rule("squeeze") - x_shape = [1, 4, 1, 16] - process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2], [3, 4, 5]]) + x_shape = [1, 8, 1, 16] + process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) x_tensor_dist_attr = TensorDistAttr() x_tensor_dist_attr.dims_mapping = [-1, -1, -1, -1] x_tensor_dist_attr.process_mesh = process_mesh self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr) + self.attrs = OrderedDict() def test_squeeze_infer_forward(self): - # shape: [1, 4, 1, 16] --> [1, 4, 16] - # dims_mapping: [0, 1, -1, -1] --> [0, 1, -1, -1] [0, 1, -1] - self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1, -1]) - self.attrs = OrderedDict() - self.attrs['axis'] = [2] + # shape: [1, 8, 1, 16] --> [8, 16] + # dims_mapping: [-1, 0, -1, 1] --> [-1, 0, -1, 1] [0, 1] + self.x_dist_tensor_spec.set_dims_mapping([-1, 0, -1, 1]) + self.attrs['axis'] = [] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['axis'] + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(len(infered_input_dist_attrs), 1) + self.assertEqual(len(infered_output_dist_attrs), 1) + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, 1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1]) + + # shape: [1, 8, 1, 16] --> [8, 16] + # dims_mapping: [-1, 0, -1, 1] --> [-1, 0, -1, 1] [0, 1] + self.x_dist_tensor_spec.set_dims_mapping([-1, 0, -1, 1]) + self.attrs['axis'] = [0, 2] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['axis'] + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, 1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1]) + + # shape: [1, 8, 1, 16] --> [1, 8, 16] + # dims_mapping: [-1, 0, -1, 1] --> [-1, 0, -1, 1] [-1, 0, 1] + self.x_dist_tensor_spec.set_dims_mapping([-1, 0, -1, 1]) + self.attrs['axis'] = [1, 2] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['axis'] + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, 1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0, 1]) + + # shape: [1, 8, 1, 16] --> [8, 1, 16] + # dims_mapping: [-1, 0, -1, 1] --> [-1, 0, -1, 1] [0, -1, 1] + self.x_dist_tensor_spec.set_dims_mapping([-1, 0, -1, 1]) + self.attrs['axis'] = [-4] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['axis'] + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, 1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, 1]) + + # shape: [1, 8, 1, 16] --> [8, 16] + # dims_mapping: [-1, 1, -1, 0] --> [-1, 1, -1, 0] [1, 0] + self.x_dist_tensor_spec.set_dims_mapping([-1, 1, -1, 0]) + self.attrs['axis'] = [] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['axis'] + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1, 0] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, 0]) + + # shape: [1, 8, 1, 16] --> [8, 16] + # dims_mapping: [-1, 1, -1, 0] --> [-1, 1, -1, 0] [1, 0] + self.x_dist_tensor_spec.set_dims_mapping([-1, 1, -1, 0]) + self.attrs['axis'] = [0, 2] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['axis'] + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1, 0] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, 0]) + + # shape: [1, 8, 1, 16] --> [1, 8, 16] + # dims_mapping: [-1, 1, -1, 0] --> [-1, 1, -1, 0] [-1, 1, 0] + self.x_dist_tensor_spec.set_dims_mapping([-1, 1, -1, 0]) + self.attrs['axis'] = [1, 2] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['axis'] + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1, 0] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 1, 0]) + + # shape: [1, 8, 1, 16] --> [8, 1, 16] + # dims_mapping: [-1, 1, -1, 0] --> [-1, 1, -1, 0] [1, -1, 0] + self.x_dist_tensor_spec.set_dims_mapping([-1, 1, -1, 0]) + self.attrs['axis'] = [-4] result_dist_attrs = self.rule.infer_forward( self.x_dist_tensor_spec, self.attrs['axis'] ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1, 0] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1, 0]) + + def test_squeeze_infer_backward(self): + process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) + + output_tensor_dist_attr = TensorDistAttr() + output_tensor_dist_attr.dims_mapping = [-1, -1] + output_tensor_dist_attr.process_mesh = process_mesh + self.output_dist_tensor_spec = DistTensorSpec( + [8, 16], output_tensor_dist_attr + ) + + # shape: [1, 8, 1, 16] --> [8, 16] (input --> output) + # dims_mapping: [0, 1] --> [-1, 0, -1, 1], [0, 1] (output --> input, output) + self.output_dist_tensor_spec.shape = [8, 16] + self.output_dist_tensor_spec.set_dims_mapping([0, 1]) + self.attrs['axis'] = [] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + self.assertEqual(len(infered_input_dist_attrs), 1) self.assertEqual(len(infered_output_dist_attrs), 1) self.assertEqual( - infered_input_dist_attrs[0].dims_mapping, [0, 1, -1, -1] + infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, 1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1]) + + # shape: [1, 8, 1, 16] --> [8, 16] (input --> output) + # dims_mapping: [0, 1] --> [-1, 0, -1, 1], [0, 1] (output --> input, output) + self.output_dist_tensor_spec.shape = [8, 16] + self.output_dist_tensor_spec.set_dims_mapping([0, 1]) + self.attrs['axis'] = [0, 2] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, 1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1]) + + # shape: [1, 8, 1, 16] --> [1, 8, 16] (input --> output) + # dims_mapping: [-1, 0, 1] --> [-1, 0, -1, 1], [-1, 0, 1] (output --> input, output) + self.output_dist_tensor_spec.shape = [1, 8, 16] + self.output_dist_tensor_spec.set_dims_mapping([-1, 0, 1]) + self.attrs['axis'] = [1, 2] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, 1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0, 1]) + + # shape: [1, 8, 1, 16] --> [8, 1, 16] (input --> output) + # dims_mapping: [0, -1, 1] --> [-1, 0, -1, 1], [0, -1, 1] (output --> input, output) + self.output_dist_tensor_spec.shape = [8, 1, 16] + self.output_dist_tensor_spec.set_dims_mapping([0, -1, 1]) + self.attrs['axis'] = [-4] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, 1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, 1]) + + # shape: [1, 8, 1, 16] --> [8, 16] (input --> output) + # dims_mapping: [1, 0] --> [-1, 1, -1, 0], [1, 0] (output --> input, output) + self.output_dist_tensor_spec.shape = [8, 16] + self.output_dist_tensor_spec.set_dims_mapping([1, 0]) + self.attrs['axis'] = [] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1, 0] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, 0]) + + # shape: [1, 8, 1, 16] --> [8, 16] (input --> output) + # dims_mapping: [1, 0] --> [-1, 1, -1, 0], [1, 0] (output --> input, output) + self.output_dist_tensor_spec.shape = [8, 16] + self.output_dist_tensor_spec.set_dims_mapping([1, 0]) + self.attrs['axis'] = [0, 2] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1, 0] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, 0]) + + # shape: [1, 8, 1, 16] --> [1, 8, 16] (input --> output) + # dims_mapping: [-1, 1, 0] --> [-1, 1, -1, 0], [-1, 1, 0] (output --> input, output) + self.output_dist_tensor_spec.shape = [1, 8, 16] + self.output_dist_tensor_spec.set_dims_mapping([-1, 1, 0]) + self.attrs['axis'] = [1, 2] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1, 0] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 1, 0]) + + # shape: [1, 8, 1, 16] --> [8, 1, 16] (input --> output) + # dims_mapping: [1, -1, 0] --> [-1, 1, -1, 0], [1, -1, 0] (output --> input, output) + self.output_dist_tensor_spec.shape = [8, 1, 16] + self.output_dist_tensor_spec.set_dims_mapping([1, -1, 0]) + self.attrs['axis'] = [-4] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1, 0] ) - self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1, -1]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1, 0]) if __name__ == "__main__": diff --git a/test/auto_parallel/spmd_rules/test_unsqueeze_rule.py b/test/auto_parallel/spmd_rules/test_unsqueeze_rule.py index 075538423c21a2..e9b50f90135735 100644 --- a/test/auto_parallel/spmd_rules/test_unsqueeze_rule.py +++ b/test/auto_parallel/spmd_rules/test_unsqueeze_rule.py @@ -27,31 +27,278 @@ class TestUnsqueezeSPMDRule(unittest.TestCase): def setUp(self): self.rule = core.get_phi_spmd_rule("unsqueeze") - x_shape = [4, 16] - process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2], [3, 4, 5]]) + x_shape = [8, 16] + process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) x_tensor_dist_attr = TensorDistAttr() x_tensor_dist_attr.dims_mapping = [-1, -1] x_tensor_dist_attr.process_mesh = process_mesh self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr) + self.attrs = OrderedDict() def test_unsqueeze_infer_forward(self): - # shape: [4, 16] --> [1, 4, 1, 16] - # dims_mapping: [0, 1] --> [0, 1] [-1, 0, -1, 1] + # shape: [8, 16] --> [1, 8, 16] + # dims_mapping: [0, 1] --> [0, 1] [-1, 0, 1] self.x_dist_tensor_spec.set_dims_mapping([0, 1]) - self.attrs = OrderedDict() - self.attrs['axis'] = [0, 1] + self.attrs['axis'] = [0] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['axis'] + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(len(infered_input_dist_attrs), 1) + self.assertEqual(len(infered_output_dist_attrs), 1) + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0, 1]) + + # shape: [8, 16] --> [8, 16, 1] + # dims_mapping: [0, 1] --> [0, 1] [0, 1, -1] + self.x_dist_tensor_spec.set_dims_mapping([0, 1]) + self.attrs['axis'] = [-1] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['axis'] + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1, -1]) + + # shape: [8, 16] --> [8, 1, 1, 16] + # dims_mapping: [0, 1] --> [0, 1] [0, -1, -1, 1] + self.x_dist_tensor_spec.set_dims_mapping([0, 1]) + self.attrs['axis'] = [1, 2] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['axis'] + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1]) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [0, -1, -1, 1] + ) + + # shape: [8, 16] --> [1, 1, 1, 8, 16] + # dims_mapping: [0, 1] --> [0, 1] [-1, -1, -1, 0, 1] + self.x_dist_tensor_spec.set_dims_mapping([0, 1]) + self.attrs['axis'] = [0, 1, 2] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['axis'] + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1]) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, -1, -1, 0, 1] + ) + + # shape: [8, 16] --> [1, 8, 16] + # dims_mapping: [1, 0] --> [1, 0] [-1, 1, 0] + self.x_dist_tensor_spec.set_dims_mapping([1, 0]) + self.attrs['axis'] = [0] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['axis'] + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, 0]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 1, 0]) + + # shape: [8, 16] --> [8, 16, 1] + # dims_mapping: [1, 0] --> [1, 0] [1, 0, -1] + self.x_dist_tensor_spec.set_dims_mapping([1, 0]) + self.attrs['axis'] = [-1] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['axis'] + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, 0]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, 0, -1]) + + # shape: [8, 16] --> [8, 1, 1, 16] + # dims_mapping: [1, 0] --> [1, 0] [1, -1, -1, 0] + self.x_dist_tensor_spec.set_dims_mapping([1, 0]) + self.attrs['axis'] = [1, 2] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['axis'] + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, 0]) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [1, -1, -1, 0] + ) + + # shape: [8, 16] --> [1, 1, 1, 8, 16] + # dims_mapping: [1, 0] --> [1, 0] [-1, -1, -1, 1, 0] + self.x_dist_tensor_spec.set_dims_mapping([1, 0]) + self.attrs['axis'] = [0, 1, 2] result_dist_attrs = self.rule.infer_forward( self.x_dist_tensor_spec, self.attrs['axis'] ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, 0]) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, -1, -1, 1, 0] + ) + + def test_unsqueeze_infer_backward(self): + process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) + + output_tensor_dist_attr = TensorDistAttr() + output_tensor_dist_attr.dims_mapping = [-1, -1] + output_tensor_dist_attr.process_mesh = process_mesh + self.output_dist_tensor_spec = DistTensorSpec( + [8, 16], output_tensor_dist_attr + ) + + # shape: [8, 16] --> [1, 8, 16] (input --> output) + # dims_mapping: [-1, 0, 1] --> [0, 1], [-1, 0, 1] (output --> input, output) + self.output_dist_tensor_spec.shape = [8, 16] + self.output_dist_tensor_spec.set_dims_mapping([-1, 0, 1]) + self.attrs['axis'] = [0] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + self.assertEqual(len(infered_input_dist_attrs), 1) self.assertEqual(len(infered_output_dist_attrs), 1) self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0, 1]) + + # shape: [8, 16] --> [8, 16, 1] (input --> output) + # dims_mapping: [0, 1, -1] --> [0, 1], [0, 1, -1] (output --> input, output) + self.output_dist_tensor_spec.shape = [8, 16] + self.output_dist_tensor_spec.set_dims_mapping([0, 1, -1]) + self.attrs['axis'] = [-1] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1, -1]) + + # shape: [8, 16] --> [8, 1, 1, 16] (input --> output) + # dims_mapping: [0, -1, -1, 1] --> [0, 1], [0, -1, -1, 1] (output --> input, output) + self.output_dist_tensor_spec.shape = [8, 16] + self.output_dist_tensor_spec.set_dims_mapping([0, -1, -1, 1]) + self.attrs['axis'] = [1, 2] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1]) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [0, -1, -1, 1] + ) + + # shape: [8, 16] --> [1, 1, 1, 8, 16] (input --> output) + # dims_mapping: [-1, -1, -1, 0, 1] --> [0, 1], [-1, -1, -1, 0, 1] (output --> input, output) + self.output_dist_tensor_spec.shape = [8, 16] + self.output_dist_tensor_spec.set_dims_mapping([-1, -1, -1, 0, 1]) + self.attrs['axis'] = [0, 1, 2] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1]) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, -1, -1, 0, 1] + ) + + # shape: [8, 16] --> [1, 8, 16] (input --> output) + # dims_mapping: [-1, 1, 0] --> [1, 0], [-1, 1, 0] (output --> input, output) + self.output_dist_tensor_spec.shape = [8, 16] + self.output_dist_tensor_spec.set_dims_mapping([-1, 1, 0]) + self.attrs['axis'] = [0] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(len(infered_input_dist_attrs), 1) + self.assertEqual(len(infered_output_dist_attrs), 1) + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, 0]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 1, 0]) + + # shape: [8, 16] --> [8, 16, 1] (input --> output) + # dims_mapping: [1, 0, -1] --> [1, 0], [1, 0, -1] (output --> input, output) + self.output_dist_tensor_spec.shape = [8, 16] + self.output_dist_tensor_spec.set_dims_mapping([1, 0, -1]) + self.attrs['axis'] = [-1] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, 0]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, 0, -1]) + + # shape: [8, 16] --> [8, 1, 1, 16] (input --> output) + # dims_mapping: [1, -1, -1, 0] --> [1, 0], [1, -1, -1, 0] (output --> input, output) + self.output_dist_tensor_spec.shape = [8, 16] + self.output_dist_tensor_spec.set_dims_mapping([1, -1, -1, 0]) + self.attrs['axis'] = [1, 2] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, 0]) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [1, -1, -1, 0] + ) + + # shape: [8, 16] --> [1, 1, 1, 8, 16] (input --> output) + # dims_mapping: [-1, -1, -1, 1, 0] --> [1, 0], [-1, -1, -1, 1, 0] (output --> input, output) + self.output_dist_tensor_spec.shape = [8, 16] + self.output_dist_tensor_spec.set_dims_mapping([-1, -1, -1, 1, 0]) + self.attrs['axis'] = [0, 1, 2] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, 0]) self.assertEqual( - infered_output_dist_attrs[0].dims_mapping, [-1, 0, -1, 1] + infered_output_dist_attrs[0].dims_mapping, [-1, -1, -1, 1, 0] )