Skip to content

Commit 5a8cf9e

Browse files
authored
fix moe_dispatch infermeta (#71140)
1 parent d2a6b77 commit 5a8cf9e

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

paddle/phi/infermeta/multiary.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6030,20 +6030,26 @@ void MoeDispatchInferMeta(const MetaTensor& X,
60306030
MetaTensor* permute_indices_per_token,
60316031
MetaTensor* expert_scales_float,
60326032
MetaTensor* top_k_indices) {
6033-
int token_rows = 0;
6033+
int token_rows = -1;
60346034
auto input_dims = X.dims();
6035+
auto gating_dims = gating_output.dims();
60356036
if (input_dims.size() == 3) {
60366037
token_rows = input_dims[0] * input_dims[1];
60376038
} else {
60386039
token_rows = input_dims[0];
60396040
}
6041+
const int expert_num = gating_dims[gating_dims.size() - 1];
60406042
const int num_rows = token_rows;
60416043
const int hidden_size = X.dims()[input_dims.size() - 1];
60426044

60436045
permute_input->set_dims({moe_topk * num_rows, hidden_size});
60446046
permute_input->set_dtype(X.dtype());
60456047
permute_input->set_layout(X.layout());
60466048

6049+
permute_indices_per_token->set_dims({expert_num});
6050+
token_nums_per_expert->set_dtype(DataType::INT64);
6051+
token_nums_per_expert->set_layout(X.layout());
6052+
60476053
permute_indices_per_token->set_dims({moe_topk, num_rows});
60486054
permute_indices_per_token->set_dtype(DataType::INT32);
60496055
permute_indices_per_token->set_layout(X.layout());

0 commit comments

Comments
 (0)