
Commit e64c63b

Merge pull request llvm#41 from clang-ykt/infer-conv
Infer shape for ConvNoBias operation.
2 parents 51b0f4c + 7dda698 · commit e64c63b


5 files changed (+321, -5 lines)


src/dialect/onnx/onnx.td

Lines changed: 3 additions & 1 deletion
@@ -104,14 +104,16 @@ def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias",
 }
 
 def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Conv operation with no Bias operand.";
   let description = [{
     "The convolution operator consumes an input tensor and a filter, and"
     "computes the output."
   }];
   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W);
   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
+
+  let verifier = [{ return ::verify(*this); }];
 }
 
 def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
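Note: the DeclareOpInterfaceMethods<ShapeInferenceOpInterface> trait is what lets the shape-inference pass treat this op generically; the generated op class gains an inferShapes() hook that the pass calls without knowing the concrete op type. A conceptual model of that dispatch in plain C++ (a sketch only, not MLIR's generated interface machinery):

#include <iostream>
#include <memory>
#include <vector>

struct ShapeInferenceOpInterface {
  virtual ~ShapeInferenceOpInterface() = default;
  virtual void inferShapes() = 0; // hook declared via the ODS trait
};

struct ConvNoBiasOp : ShapeInferenceOpInterface {
  // The real definition lives in src/dialect/onnx/onnx_ops.cpp below.
  void inferShapes() override { std::cout << "infer ConvNoBias shape\n"; }
};

int main() {
  std::vector<std::unique_ptr<ShapeInferenceOpInterface>> ops;
  ops.push_back(std::make_unique<ConvNoBiasOp>());
  for (auto &op : ops)
    op->inferShapes(); // what the shape-inference pass effectively does
  return 0;
}

The verifier hook works the same way: `::verify(*this)` dispatches to the free function verify(ONNXConvNoBiasOp) added in onnx_ops.cpp.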

src/dialect/onnx/onnx_ops.cpp

Lines changed: 176 additions & 1 deletion
@@ -412,7 +412,7 @@ void ONNXReshapeOp::inferShapes() {
 void ONNXTransposeOp::inferShapes() {
   // Cannot infer shape if no shape exists.
   if (!getOperand().getType().isa<RankedTensorType>())
-    emitError("Shape tensor not ranked.");
+    return;
 
   // Naive transposition which handles the default case of
   // reversing the shape of the tensor (similar to numpy.transpose).
@@ -448,6 +448,181 @@ LogicalResult verify(ONNXTransposeOp op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Conv
+
+// For this operation, we define the attributes once in the original Conv
+// operation class. There is no need to redefine the attribute names for the
+// other classes based on Conv.
+void ONNXConvNoBiasOp::inferShapes() {
+  // Generic shape for data input X and weight tensor W:
+  // X: (N x C x D1 x D2 ... x Dn)
+  // W: (M x C/group x k1 x k2 x ... x kn)
+
+  // Cannot infer shape if no shape exists.
+  if (!getOperand(0).getType().isa<RankedTensorType>() ||
+      !getOperand(1).getType().isa<RankedTensorType>())
+    return;
+
+  auto dataTy = getOperand(0).getType().cast<RankedTensorType>();
+  auto weightTy = getOperand(1).getType().cast<RankedTensorType>();
+  auto dataShape = dataTy.getShape();
+  auto weightShape = weightTy.getShape();
+
+  // Check that the shapes of the data and weight tensors have the same rank.
+  if (dataShape.size() != weightShape.size())
+    emitError("Weight size not compatible with data size.");
+
+  // Required attribute auto_pad defaults to NOTSET.
+  auto autoPad = getAttrOfType<StringAttr>(
+      ONNXConvOp::getAutoPadAttrName()).getValue();
+  // Group is a required attribute and should have a default value of 1.
+  int64_t group = getAttrOfType<IntegerAttr>(
+      ONNXConvOp::getGroupAttrName()).getInt();
+  // Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
+  if (dataShape[1] != (weightShape[1] * group))
+    emitError("Channel dimension mismatch.");
+
+  // Note: the value of the group attribute only impacts the way the
+  // computation is carried out and not the actual output size.
+
+  // First two output dimensions consist of the number of batches and the
+  // number of kernels being applied.
+  SmallVector<int64_t, 2> dims;
+  // Insert batch size.
+  dims.emplace_back(dataShape[0]);
+  // Insert number of filters being applied (number of output channels).
+  dims.emplace_back(weightShape[0]);
+
+  // Spatial dimensions of the output are computed using the formula:
+  //
+  //   dim = (inputDim - kernelDim + startPadding + endPadding) / stride + 1
+  //
+  SmallVector<int64_t, 2> outSpatialDims;
+  // Number of spatial dimensions.
+  int32_t nDims = dataShape.size() - 2;
+
+  // Initialize dimensions based on the input spatial dimensions.
+  for (int i = 2; i < dataShape.size(); ++i)
+    outSpatialDims.emplace_back(dataShape[i]);
+
+  // Use the kernel_shape attribute if present, otherwise use the size from
+  // the weight argument.
+  SmallVector<int64_t, 2> kernelDims;
+  if (auto kernelShape = getAttrOfType<ArrayAttr>(
+          ONNXConvOp::getKernelShapeAttrName())) {
+    if (kernelShape.getValue().size() != nDims)
+      emitError("kernel_shape length incompatible with spatial dimensions.");
+    for (int i = 0; i < nDims; ++i)
+      kernelDims.emplace_back(
+          (kernelShape.getValue()[i]).cast<IntegerAttr>().getInt());
+  } else {
+    for (int i = 0; i < nDims; ++i)
+      kernelDims.emplace_back(weightShape[i + 2]);
+  }
+
+  // Check if the dilations attribute is present.
+  // If it is, then compute the new kernel size that includes the receptive
+  // field. In this calculation we assume that the receptive field pixels
+  // must all be within the bounds of the image. In this case the new kernel
+  // size is given by:
+  //
+  //   (K + 1) * d - 1
+  //
+  // where K is a kernel dimension and d is the dilation along that axis.
+  //
+  // From a dimensionality perspective the kernel size becomes the dilated
+  // kernel size.
+  if (auto dilations = getAttrOfType<ArrayAttr>(
+          ONNXConvOp::getDilationsAttrName())) {
+    if (dilations.getValue().size() != nDims)
+      emitError("dilations length incompatible with spatial dimensions.");
+    for (int i = 0; i < nDims; ++i)
+      kernelDims[i] = (kernelDims[i] + 1) *
+          (dilations.getValue()[i]).cast<IntegerAttr>().getInt() - 1;
+  }
+
+  // Subtract kernel dimensions from input data dimensions.
+  for (int i = 0; i < nDims; ++i)
+    outSpatialDims[i] -= kernelDims[i];
+
+  // Add padding information.
+  if (autoPad == "NOTSET") {
+    // Use pads to determine the padding. If the attribute is not
+    // present then pads is considered to be all zeros (no padding).
+    if (auto pads = getAttrOfType<ArrayAttr>(
+            ONNXConvOp::getPadsAttrName())) {
+      // pads consists of two entries for each spatial axis.
+      if (pads.getValue().size() != 2 * nDims)
+        emitError("pads size is not twice the spatial size.");
+
+      for (int i = 0; i < nDims; ++i) {
+        // Padding for beginning of axis.
+        int32_t p = (pads.getValue()[i]).cast<IntegerAttr>().getInt();
+        outSpatialDims[i] += p;
+        // Padding for end of axis.
+        p = (pads.getValue()[i + nDims]).cast<IntegerAttr>().getInt();
+        outSpatialDims[i] += p;
+      }
+    }
+  } else if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
+    // Pad input so that output size matches input size.
+    // Each spatial dimension needs to be padded by a total of:
+    //
+    //   K - 1
+    //
+    // where K is a kernel spatial dimension.
+    // Pad as if stride is 1.
+    for (int i = 0; i < nDims; ++i)
+      outSpatialDims[i] += kernelDims[i] - 1;
+  } else if (autoPad == "VALID") {
+    // No padding.
+  } else {
+    emitError("Unexpected attribute value for auto_pad.");
+  }
+
+  // Strides
+  if (auto strides = getAttrOfType<ArrayAttr>(
+          ONNXConvOp::getStridesAttrName())) {
+    if (strides.getValue().size() != nDims)
+      emitError("strides length incompatible with spatial dimensions.");
+    for (int i = 0; i < nDims; ++i) {
+      int64_t stride =
+          (strides.getValue()[i]).cast<IntegerAttr>().getInt();
+      outSpatialDims[i] = floor(outSpatialDims[i] / stride);
+    }
+  }
+
+  for (int i = 0; i < nDims; ++i)
+    outSpatialDims[i] += 1;
+
+  dims.append(outSpatialDims.begin(), outSpatialDims.end());
+  getResult().setType(RankedTensorType::get(dims, dataTy.getElementType()));
+}
+
+LogicalResult verify(ONNXConvNoBiasOp op) {
+  auto module = op.getParentOfType<ModuleOp>();
+  if (!module)
+    op.emitError("expected to belong to a module");
+
+  auto autoPadAttr = op.getAttrOfType<StringAttr>(
+      ONNXConvOp::getAutoPadAttrName());
+  if (!autoPadAttr)
+    op.emitError("auto_pad attribute not specified.");
+  if (autoPadAttr.getValue() != "NOTSET")
+    if (auto pads = op.getAttrOfType<ArrayAttr>(
+            ONNXConvOp::getPadsAttrName()))
+      op.emitError("auto_pad and pads are both set.");
+
+  auto groupAttr =
+      op.getAttrOfType<IntegerAttr>(ONNXConvOp::getGroupAttrName());
+  if (!groupAttr)
+    op.emitError("group attribute not specified.");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
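As a sanity check, the output-size arithmetic above can be replayed outside MLIR. The sketch below is a standalone program (the helper name convOutDim is ours, not part of this commit) that mirrors the order of operations in inferShapes() -- dilate the kernel, subtract it from the input, add padding, floor-divide by the stride, add one -- and checks it against three of the new test cases:

#include <cassert>
#include <cstdint>

// Hypothetical helper, for illustration only.
int64_t convOutDim(int64_t inDim, int64_t kernelDim, int64_t dilation,
                   int64_t padBegin, int64_t padEnd, int64_t stride,
                   bool samePad) {
  // Dilated kernel size, as in the comment above: (K + 1) * d - 1.
  int64_t k = (kernelDim + 1) * dilation - 1;
  int64_t dim = inDim - k;
  if (samePad)
    dim += k - 1;             // SAME_UPPER / SAME_LOWER: pad a total of K - 1
  else
    dim += padBegin + padEnd; // NOTSET: explicit pads, defaulting to zero
  return dim / stride + 1;    // integer (floor) division, then + 1
}

int main() {
  // test_conv_no_bias_1: 32x64 data, 6x7 kernel -> 27x58.
  assert(convOutDim(32, 6, 1, 0, 0, 1, false) == 27);
  assert(convOutDim(64, 7, 1, 0, 0, 1, false) == 58);
  // test_conv_no_bias_8: SAME_UPPER with strides [2, 3] -> 16x22.
  assert(convOutDim(32, 6, 1, 0, 0, 2, true) == 16);
  assert(convOutDim(64, 7, 1, 0, 0, 3, true) == 22);
  // test_conv_no_bias_9: dilations [2, 3] -> 20x42.
  assert(convOutDim(32, 6, 2, 0, 0, 1, false) == 20);
  assert(convOutDim(64, 7, 3, 0, 0, 1, false) == 42);
  return 0;
}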

src/dialect/onnx/onnxop.inc

Lines changed: 9 additions & 0 deletions
@@ -324,6 +324,15 @@ def ONNXConvOp:ONNX_Op<"Conv",
   }];
   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W, AnyTypeOf<[AnyMemRef, AnyTensor]>:$B);
   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
+
+  let extraClassDeclaration = [{
+    static StringRef getAutoPadAttrName() { return "auto_pad"; }
+    static StringRef getDilationsAttrName() { return "dilations"; }
+    static StringRef getGroupAttrName() { return "group"; }
+    static StringRef getKernelShapeAttrName() { return "kernel_shape"; }
+    static StringRef getPadsAttrName() { return "pads"; }
+    static StringRef getStridesAttrName() { return "strides"; }
+  }];
 }
 
 def ONNXConvIntegerOp:ONNX_Op<"ConvInteger",
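Hoisting the attribute names into extraClassDeclaration on ONNXConvOp gives every Conv variant one authoritative spelling of each attribute name instead of repeated string literals. A minimal standalone illustration of the pattern (plain std::string standing in for llvm::StringRef):

#include <iostream>
#include <string>

struct ConvOp {
  // Single source of truth for the attribute names.
  static std::string getAutoPadAttrName() { return "auto_pad"; }
  static std::string getGroupAttrName() { return "group"; }
};

struct ConvNoBiasOp {
  void report() {
    // No literal "auto_pad" here -- a rename in ConvOp propagates everywhere.
    std::cout << "reads attribute " << ConvOp::getAutoPadAttrName() << "\n";
  }
};

int main() {
  ConvNoBiasOp().report();
  return 0;
}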

src/pass/shape_inference_pass.cpp

Lines changed: 2 additions & 1 deletion
@@ -117,7 +117,8 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
         op->getName().getStringRef() != "onnx.GemmNoBias" &&
         op->getName().getStringRef() != "onnx.Reshape" &&
         op->getName().getStringRef() != "onnx.Transpose" &&
-        op->getName().getStringRef() != "onnx.Softmax")
+        op->getName().getStringRef() != "onnx.Softmax" &&
+        op->getName().getStringRef() != "onnx.ConvNoBias")
       return false;
     return llvm::any_of(op->getResultTypes(), [](Type result_type) {
       return !result_type.isa<RankedTensorType>();
test/mlir/onnx/onnx_shape_inference.mlir

Lines changed: 131 additions & 2 deletions
@@ -1,7 +1,10 @@
 // RUN: onnf-opt --shape-inference %s -split-input-file | FileCheck %s
 
+//===----------------------------------------------------------------------===//
 /// Test the default behavior of transpose when no information for the
-/// permutation of the axes is provided.
+/// permutation of the axes is provided and when a permutation is provided.
+//===----------------------------------------------------------------------===//
+
 func @test_default_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
   %0 = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
@@ -12,11 +15,137 @@ func @test_default_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
 // CHECK: return [[RES]] : tensor<32x1x5x5xf32>
 
 /// Test shape inference for transposition when perm attribute is specified.
+
 func @test_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
   %0 = "onnx.Transpose"(%arg0) {perm = [2, 0, 3, 1]} : (tensor<5x5x1x32xf32>) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
 }
 
 // CHECK-LABEL: test_transpose
 // CHECK: [[RES_ATTR:%.+]] = "onnx.Transpose"(%arg0) {perm = [2, 0, 3, 1]} : (tensor<5x5x1x32xf32>) -> tensor<1x5x32x5xf32>
-// CHECK: return [[RES_ATTR]] : tensor<1x5x32x5xf32>
+// CHECK: return [[RES_ATTR]] : tensor<1x5x32x5xf32>
+
+//===----------------------------------------------------------------------===//
+/// Test shape inference for ConvNoBias operation and all its attributes.
+//===----------------------------------------------------------------------===//
+
+/// Default and required attributes.
+
+func @test_conv_no_bias_1(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
+  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+}
+
+// CHECK-LABEL: test_conv_no_bias_1
+// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x27x58xf32>
+// CHECK: return [[RES_ATTR]] : tensor<1x5x27x58xf32>
+
+/// kernel_shape attribute.
+
+func @test_conv_no_bias_2(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
+  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32, kernel_shape = [8, 9]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+}
+
+// CHECK-LABEL: test_conv_no_bias_2
+// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32, kernel_shape = [8, 9]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x25x56xf32>
+// CHECK: return [[RES_ATTR]] : tensor<1x5x25x56xf32>
+
+/// pads attribute.
+/// Use pads to make output size equal to input size by adding K - 1 to the result.
+
+func @test_conv_no_bias_3(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32, pads = [2, 4, 3, 5]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+}
+
+// CHECK-LABEL: test_conv_no_bias_3
+// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32, pads = [2, 4, 3, 5]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x32x64xf32>
+// CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32>
+
+/// auto_pad set to SAME_UPPER and SAME_LOWER.
+
+func @test_conv_no_bias_4(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+}
+
+// CHECK-LABEL: test_conv_no_bias_4
+// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x32x64xf32>
+// CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32>
+
+func @test_conv_no_bias_5(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_LOWER", group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+}
+
+// CHECK-LABEL: test_conv_no_bias_5
+// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_LOWER", group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x32x64xf32>
+// CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32>
+
+/// auto_pad set to VALID.
+
+func @test_conv_no_bias_6(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "VALID", group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+}
+
+// CHECK-LABEL: test_conv_no_bias_6
+// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "VALID", group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x27x55xf32>
+// CHECK: return [[RES_ATTR]] : tensor<1x5x27x55xf32>
+
+/// With strides attribute.
+
+func @test_conv_no_bias_7(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
+  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32, strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+}
+
+// CHECK-LABEL: test_conv_no_bias_7
+// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32, strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x14x20xf32>
+// CHECK: return [[RES_ATTR]] : tensor<1x5x14x20xf32>
+
+/// auto_pad set to SAME_UPPER with strides attribute.
+/// The auto_pad will pad as if the stride is equal to 1.
+
+func @test_conv_no_bias_8(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
+  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i32, strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+}
+
+// CHECK-LABEL: test_conv_no_bias_8
+// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i32, strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x16x22xf32>
+// CHECK: return [[RES_ATTR]] : tensor<1x5x16x22xf32>
+
+/// dilations attribute.
+
+func @test_conv_no_bias_9(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
+  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32, dilations = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+}
+
+// CHECK-LABEL: test_conv_no_bias_9
+// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x20x42xf32>
+// CHECK: return [[RES_ATTR]] : tensor<1x5x20x42xf32>
+
+/// dilations attribute with stride.
+
+func @test_conv_no_bias_10(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
+  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32, dilations = [2, 3], strides = [2, 2]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+}
+
+// CHECK-LABEL: test_conv_no_bias_10
+// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i32, strides = [2, 2]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x10x21xf32>
+// CHECK: return [[RES_ATTR]] : tensor<1x5x10x21xf32>
+
+/// dilations attribute with auto_pad set to SAME_UPPER.
+
+func @test_conv_no_bias_11(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
+  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i32, dilations = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+}
+
+// CHECK-LABEL: test_conv_no_bias_11
+// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", dilations = [2, 3], group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x32x64xf32>
+// CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32>

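The expected shapes in these tests follow directly from the formula in onnx_ops.cpp. For instance, test_conv_no_bias_3 pads a 32x64 input with pads = [2, 4, 3, 5] (begin pads [2, 4], end pads [3, 5]) around a 6x10 kernel, which restores the input size; a worked check:

#include <cassert>

int main() {
  // dim = inputDim - kernelDim + padBegin + padEnd + 1  (stride 1)
  assert(32 - 6 + 2 + 3 + 1 == 32);  // height preserved
  assert(64 - 10 + 4 + 5 + 1 == 64); // width preserved
  return 0;
}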