Fix empty tensor when select -1 (llvm#1787)
li-plus authored Jan 17, 2023
1 parent 19bb8ae commit e269843
Showing 7 changed files with 122 additions and 2 deletions.
3 changes: 3 additions & 0 deletions e2e_testing/xfail_sets.py
@@ -234,6 +234,7 @@
    "ReduceSumDtypeFloatModule_basic",
    "ReduceSumDtypeIntModule_basic",
    "SelectIntModule_basic",
    "SelectIntNegativeDimAndIndexStaticModule_basic",
    "SliceSingleIdxModule_basic",
    "SqueezeDimModule_dynamic",
    "SqueezeDimModule_negDim",
@@ -454,6 +455,7 @@
    "BoolTensorReturnMixedModule_basic",
    "BoolTensorHandleSignless_basic",
    "ElementwiseRsqrtModule_basic",
    "SelectIntNegativeDimAndIndexStaticModule_basic",
    "SqueezeModule_static",
    "SqueezeModule_noUnitDim",
    "SqueezeModule_allUnitDim",
@@ -662,6 +664,7 @@
    "AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
    "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
    "AddIntModule_basic",
    "AtenIntBoolOpModule_basic",
    "BernoulliFloatModule_basic",
    "BernoulliTensorModule_basic",
    "BincountMinlengthModule_basic",
24 changes: 24 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -9111,6 +9111,30 @@ def Torch_AtenIntScalarOp : Torch_Op<"aten.Int.Scalar", [
  let hasFolder = 1;
}

def Torch_AtenIntBoolOp : Torch_Op<"aten.Int.bool", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::Int.bool : (bool) -> (int)`";
  let arguments = (ins
    Torch_BoolType:$a
  );
  let results = (outs
    Torch_IntType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenIntBoolOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 1, 1);
    }
    void AtenIntBoolOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 1, 1);
    }
  }];
  let hasFolder = 1;
}
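
The new op models Python's ordinary bool-to-int coercion. A quick eager-mode check of the semantics (a minimal sketch assuming only a stock PyTorch install; not part of the commit):

import torch

# `aten::Int.bool` coerces a boolean to an integer: True -> 1, False -> 0.
assert torch.ops.aten.Int(True) == 1
assert torch.ops.aten.Int(False) == 0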

def Torch_Aten__RangeLengthOp : Torch_Op<"aten.__range_length", [
    AllowsTypeRefinement,
    HasValueSemantics,
12 changes: 12 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -1347,6 +1347,18 @@ OpFoldResult AtenIntScalarOp::fold(ArrayRef<Attribute> operands) {
  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenIntBoolOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenIntBoolOp::fold(ArrayRef<Attribute> operands) {
  bool b;
  if (matchPattern(getOperand(), m_TorchConstantBool(&b))) {
    return getI64IntegerAttr(getContext(), static_cast<long>(b));
  }
  return nullptr;
}
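
The folder fires only when the operand matches a compile-time constant (`m_TorchConstantBool`); for anything runtime-valued it returns nullptr and the op is left for later lowering. A Python sketch of that contract (illustrative only; `fold_int_bool` is a hypothetical stand-in for the C++ above):

def fold_int_bool(operand):
    # Mirrors AtenIntBoolOp::fold: fold only a known-constant bool;
    # returning None means "no fold, keep the op in the IR".
    if isinstance(operand, bool):
        return int(operand)  # True -> 1, False -> 0
    return None

assert fold_int_bool(True) == 1
assert fold_int_bool("%arg0") is None  # runtime value: the op survives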

//===----------------------------------------------------------------------===//
// AtenSortIntOp
//===----------------------------------------------------------------------===//
9 changes: 9 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -258,6 +258,15 @@ class DecomposeAtenSelectIntOp : public OpRewritePattern<AtenSelectIntOp> {
    Value dim = op.getDim();
    Value self = op.getSelf();

    // Convert `start` to a non-negative index: start += int(start < 0) * dimSize
    Value zero =
        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
    Value isNegative = rewriter.create<AtenLtIntOp>(loc, start, zero);
    isNegative = rewriter.create<AtenIntBoolOp>(loc, isNegative);
    Value dimSize = rewriter.create<AtenSizeIntOp>(loc, self, dim);
    Value indexOffset = rewriter.create<AtenMulIntOp>(loc, isNegative, dimSize);
    start = rewriter.create<AtenAddIntOp>(loc, start, indexOffset);

    Value one =
        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
    Value startPlusOne =
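
The pattern goes on to slice `[start, start + 1)` (see `startPlusOne` above) and squeeze the selected dimension. A worked Python sketch of the normalization and why it matters (hypothetical helper name; example values only):

def normalize_start(start: int, dim_size: int) -> int:
    # Same arithmetic the pattern emits: start += int(start < 0) * dim_size
    return start + int(start < 0) * dim_size

# select(x, dim=-1, index=-1) on a 5x5 tensor: dim_size = 5
start = normalize_start(-1, 5)  # -1 becomes 4
# Without the normalization, the emitted slice would be [-1, 0), an empty
# slice: exactly the "empty tensor when select -1" of the commit title.
assert (start, start + 1) == (4, 5)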
1 change: 1 addition & 0 deletions python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -581,6 +581,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::Float.str : (str) -> (float)")
emit("aten::Int.float : (float) -> (int)")
emit("aten::Int.Scalar : (Scalar) -> (int)", has_folder=True)
emit("aten::Int.bool : (bool) -> (int)", has_folder=True)

# Primitive ops
emit("aten::__range_length : (int, int, int) -> (int)", has_folder=True)
53 changes: 53 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/scalar.py
@@ -339,6 +339,59 @@ def BoolIntConstantModule_basic(module, tu: TestUtils):

# ==============================================================================

class AtenIntBoolOpModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([], torch.bool, True),
    ])
    def forward(self, x):
        return int(torch.ops.aten.Int(x))


@register_test_case(module_factory=lambda: AtenIntBoolOpModule())
def AtenIntBoolOpModule_basic(module, tu: TestUtils):
    module.forward(tu.randint(low=0, high=2).bool())


class AtenIntBoolOpConstTrueModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        return int(torch.ops.aten.Int(True))


@register_test_case(module_factory=lambda: AtenIntBoolOpConstTrueModule())
def AtenIntBoolOpConstTrueModule_basic(module, tu: TestUtils):
    module.forward()


class AtenIntBoolOpConstFalseModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        return int(torch.ops.aten.Int(False))


@register_test_case(module_factory=lambda: AtenIntBoolOpConstFalseModule())
def AtenIntBoolOpConstFalseModule_basic(module, tu: TestUtils):
    module.forward()
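
Taken together, the three modules cover both paths the commit adds: the ConstTrue/ConstFalse variants hand `aten::Int.bool` an operand known at compile time, which the new folder should eliminate, while the first module feeds a value only known at run time.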

# ==============================================================================

class AtenIntTensorByteDtypeModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
22 changes: 20 additions & 2 deletions python/torch_mlir_e2e_test/test_suite/slice_like.py
@@ -243,12 +243,30 @@ def __init__(self):
        ([-1, -1], torch.int64, True),
    ])
    def forward(self, x):
-        return x.select(0,0)
+        return torch.select(x, dim=0, index=0)


@register_test_case(module_factory=lambda: SelectIntModule())
def SelectIntModule_basic(module, tu: TestUtils):
-    module.forward(tu.randint(5,5, high=10))
+    module.forward(tu.randint(5, 5, high=10))


class SelectIntNegativeDimAndIndexStaticModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([5, 5], torch.int64, True),
    ])
    def forward(self, x):
        return torch.select(x, dim=-1, index=-1)


@register_test_case(module_factory=lambda: SelectIntNegativeDimAndIndexStaticModule())
def SelectIntNegativeDimAndIndexStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.randint(5, 5, high=10))
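
For intuition, here is what the new test computes in eager PyTorch (example tensor; not part of the test suite):

import torch

x = torch.arange(25).reshape(5, 5)
# dim=-1 is the last dimension and index=-1 the last entry along it,
# so this selects the last column, shape (5,), never an empty tensor.
assert torch.equal(torch.select(x, dim=-1, index=-1), x[:, -1])
assert torch.select(x, dim=-1, index=-1).shape == (5,)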

# ==============================================================================
