Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1125,6 +1125,42 @@ bool MaxOpInferSymbolicShape(pir::Operation *op,

return details::ReduceInferDim(op, infer_context, axis, keepdim, reduce_all);
}
bool ModeOpInferSymbolicShape(pir::Operation *op,
pir::InferSymbolicShapeContext *infer_context) {
const auto &x_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
const auto &x_shape = x_shape_or_data.shape();

int axis = op->attribute<pir::Int32Attribute>("axis").data();
bool keepdim = op->attribute<pir::BoolAttribute>("keepdim").data();

int dim_size = x_shape.size();

if (axis < 0) {
axis += dim_size;
}

std::vector<symbol::DimExpr> out_dims;
for (int i = 0; i < axis; i++) {
out_dims.emplace_back(x_shape[i]);
}
if (keepdim && dim_size > 0) {
out_dims.emplace_back(symbol::DimExpr(1));
}
for (int i = axis + 1; i < dim_size; i++) {
out_dims.emplace_back(x_shape[i]);
}

symbol::TensorShapeOrDataDimExprs out_shape(out_dims);

infer_context->SetShapeOrDataForValue(op->result(0),
symbol::ShapeOrDataDimExprs{out_shape});

infer_context->SetShapeOrDataForValue(op->result(1),
symbol::ShapeOrDataDimExprs{out_shape});

return true;
}

bool MaxoutOpInferSymbolicShape(pir::Operation *op,
pir::InferSymbolicShapeContext *infer_context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(L1Norm_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(LpPool2d)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logcumsumexp)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logsumexp)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Mode)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Max)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Maxout)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Min)
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3231,6 +3231,7 @@
kernel :
func : mode
backward : mode_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : momentum_
args : (Tensor param, Tensor grad, Tensor velocity, Tensor learning_rate, Tensor master_param, float mu, bool use_nesterov = false, str regularization_method = "", float regularization_coeff = 0.0f, bool multi_precision = false, float rescale_grad = 1.0f)
Expand Down