Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions e2e_testing/torchscript/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
"MobilenetV3Module_basic",
"ConvolutionModule3D_basic",
"ConvolutionModule1D_basic",
"MaxPool2dWith3dInputModule_basic",
"MaxPool2dWithIndicesWith3dInputModule_basic",
}
REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS

Expand Down
6 changes: 3 additions & 3 deletions include/torch-mlir/Conversion/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ void checkDimEqualHelper(OpBuilder &b, Location loc, Value lhsDim,
Value createInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
Type elemTy, Value initElem);

Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
Type elemTy);

Value castIntToIndex(OpBuilder &b, Location loc, Value v);

Value castIndexToInt(OpBuilder &b, Location loc, Value idx);
Expand All @@ -51,9 +54,6 @@ SmallVector<Value> getTensorSizes(OpBuilder &b, Location loc, Value tensor);

Value getTensorSize(OpBuilder &b, Location loc, Value tensor);

Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
Type elemTy);

// Creates a constant of type `elemType` with value `val`.
Value getConstant(OpBuilder &b, Location loc, int64_t val, Type elemType);

Expand Down
29 changes: 29 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2938,6 +2938,35 @@ def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
}];
}

def Torch_AtenMaxPool2dWithIndicesOp : Torch_Op<"aten.max_pool2d_with_indices", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
ListOfTorchIntType:$kernel_size,
ListOfTorchIntType:$stride,
ListOfTorchIntType:$padding,
ListOfTorchIntType:$dilation,
Torch_BoolType:$ceil_mode
);
let results = (outs
AnyTorchTensorType:$result0,
AnyTorchTensorType:$result1
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMaxPool2dWithIndicesOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 2);
}
void AtenMaxPool2dWithIndicesOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 2);
}
}];
}

def Torch_AtenMaxPool2dWithIndicesBackwardOp : Torch_Op<"aten.max_pool2d_with_indices_backward", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
4 changes: 2 additions & 2 deletions lib/Conversion/TorchToLinalg/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -513,8 +513,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
SmallVector<int64_t, 4> paddingIncludingNC = {0, 0};
paddingIncludingNC.insert(paddingIncludingNC.end(), paddingInts.begin(),
paddingInts.end());
Value paddedInput = torch_to_linalg::getPaddedTensor(op, rewriter, input,
paddingIncludingNC);
Value paddedInput = torch_to_linalg::getZeroPaddedTensor(
op, rewriter, input, paddingIncludingNC);

SmallVector<Value> paddingIntValues =
getAsConstantIntValues(rewriter, loc, paddingInts);
Expand Down
Loading