Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
WintersMontagne10335 committed Oct 5, 2023
1 parent 4aff5eb commit 9e9140f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
6 changes: 6 additions & 0 deletions paddle/phi/infermeta/spmd_rules/rules.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ PD_REGISTER_SPMD_RULE(
PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd),
PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse));

// default_data_parallel rule
PD_REGISTER_SPMD_RULE(
default_data_parallel,
PD_INFER_SPMD(phi::distributed::DefaultDataParallelInferSpmd),
PD_INFER_SPMD(phi::distributed::DefaultDataParallelInferSpmdReverse));

// unsqueeze rule
PD_REGISTER_SPMD_RULE(
unsqueeze,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class TestDefaultDataParallelSPMDRule(unittest.TestCase):
def setUp(self):
# After replaced all spmd rules by phi impl, we can recover the
# api name to `get_spmd_rule`
self.rule = core.get_phi_spmd_rule("unsqueeze")
self.rule = core.get_phi_spmd_rule("default_data_parallel")

x_shape = [10, 10, 32, 48]
y_shape = [32, 48]
Expand Down

0 comments on commit 9e9140f

Please sign in to comment.