File tree Expand file tree Collapse file tree 1 file changed +7
-1
lines changed Expand file tree Collapse file tree 1 file changed +7
-1
lines changed Original file line number Diff line number Diff line change @@ -6030,20 +6030,26 @@ void MoeDispatchInferMeta(const MetaTensor& X,
6030
6030
MetaTensor* permute_indices_per_token,
6031
6031
MetaTensor* expert_scales_float,
6032
6032
MetaTensor* top_k_indices) {
6033
- int token_rows = 0 ;
6033
+ int token_rows = - 1 ;
6034
6034
auto input_dims = X.dims ();
6035
+ auto gating_dims = gating_output.dims ();
6035
6036
if (input_dims.size () == 3 ) {
6036
6037
token_rows = input_dims[0 ] * input_dims[1 ];
6037
6038
} else {
6038
6039
token_rows = input_dims[0 ];
6039
6040
}
6041
+ const int expert_num = gating_dims[gating_dims.size () - 1 ];
6040
6042
const int num_rows = token_rows;
6041
6043
const int hidden_size = X.dims ()[input_dims.size () - 1 ];
6042
6044
6043
6045
permute_input->set_dims ({moe_topk * num_rows, hidden_size});
6044
6046
permute_input->set_dtype (X.dtype ());
6045
6047
permute_input->set_layout (X.layout ());
6046
6048
6049
+ permute_indices_per_token->set_dims ({expert_num});
6050
+ token_nums_per_expert->set_dtype (DataType::INT64);
6051
+ token_nums_per_expert->set_layout (X.layout ());
6052
+
6047
6053
permute_indices_per_token->set_dims ({moe_topk, num_rows});
6048
6054
permute_indices_per_token->set_dtype (DataType::INT32);
6049
6055
permute_indices_per_token->set_layout (X.layout ());
You can’t perform that action at this time.
0 commit comments