Skip to content

Commit 4d6e515

Browse files
authored
【BUAA】【Infer Symbolic Shape】Add depthwise_conv for CINN (#67814)
* init * rename
1 parent 8c5e841 commit 4d6e515

File tree

3 files changed

+6
-8
lines changed

3 files changed

+6
-8
lines changed

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -487,12 +487,10 @@ bool CrossOpInferSymbolicShape(pir::Operation *op,
487487
// return true;
488488
// }
489489

490-
// bool DepthwiseConvOpInferSymbolicShape(pir::Operation *op,
491-
// pir::InferSymbolicShapeContext
492-
// *infer_context) {
493-
// // pass
494-
// return true;
495-
// }
490+
bool DepthwiseConv2dOpInferSymbolicShape(
491+
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
492+
return Conv2dOpInferSymbolicShape(op, infer_context);
493+
}
496494

497495
bool DotOpInferSymbolicShape(pir::Operation *op,
498496
pir::InferSymbolicShapeContext *infer_context) {

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Conv3d)
3636
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(ConvTranspose)
3737
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cross)
3838
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Correlation)
39-
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(DepthwiseConv)
39+
OP_DECLARE_INFER_SYMBOLIC_SHAPE(DepthwiseConv2d)
4040
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Dot)
4141
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Dropout)
4242
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Embedding)

paddle/phi/ops/yaml/ops.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1283,7 +1283,7 @@
12831283
func : depthwise_conv2d
12841284
data_type : input
12851285
backward : depthwise_conv2d_grad
1286-
# interfaces : paddle::dialect::InferSymbolicShapeInterface
1286+
interfaces : paddle::dialect::InferSymbolicShapeInterface
12871287

12881288
- op : depthwise_conv2d_transpose
12891289
args : (Tensor x, Tensor filter, int[] strides={1, 1}, int[] paddings={0, 0}, int[] output_padding={}, IntArray output_size={}, str padding_algorithm="EXPLICIT", int groups=1, int[] dilations={1, 1}, str data_format="NCHW")

0 commit comments

Comments
 (0)