Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/torch-mlir/Dialect/Torch/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ namespace Torch {
int64_t toPositiveDim(int64_t dim, int64_t inputRank);
bool isValidDim(int64_t dim, int64_t inputRank);
bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems);
/// Returns the index indicated by `v` for a list of given `length`.
/// If the index is negative, it is adjusted to `length` + `v`.
/// `None` is returned the index is not an integer in the range [0,`length).
llvm::Optional<int64_t> matchLegalConstantIndexIntoListOfSize(Value v,
int64_t length);
torch_upstream::ScalarType getScalarTypeForType(Type type);
// Helper to convert a tensor to a specific scalar type.
Value convertTensorToDtype(PatternRewriter &rewriter, Location loc, Value input,
Expand Down
24 changes: 9 additions & 15 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -725,18 +725,14 @@ OpFoldResult AtenSizeIntOp::fold(ArrayRef<Attribute> operands) {
if (!type || !type.hasSizes())
return nullptr;

int64_t inputRank = type.getSizes().size();
int64_t dim;
if (!matchPattern(this->dim(), m_TorchConstantInt(&dim)))
llvm::Optional<int64_t> dimOpt = matchLegalConstantIndexIntoListOfSize(
this->dim(), type.getSizes().size());
if (!dimOpt)
return nullptr;
dim = toPositiveDim(dim, inputRank);
if (!isValidDim(dim, inputRank))
return nullptr;

if (type.getSizes()[dim] == kUnknownSize)
if (type.getSizes()[*dimOpt] == kUnknownSize)
return nullptr;
return IntegerAttr::get(IntegerType::get(getContext(), 64),
type.getSizes()[dim]);
type.getSizes()[*dimOpt]);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1227,14 +1223,12 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns(
return failure();

// Get the index, but be careful because it might be statically invalid.
int64_t index;
if (!matchPattern(op.getOperand(1), m_TorchConstantInt(&index)))
return failure();
int64_t positiveDim = toPositiveDim(index, listConstruct.getNumOperands());
if (!isValidDim(positiveDim, listConstruct.getNumOperands()))
llvm::Optional<int64_t> indexOpt = matchLegalConstantIndexIntoListOfSize(
op.getOperand(1), listConstruct.getNumOperands());
if (!indexOpt)
return rewriter.notifyMatchFailure(op, "statically invalid index");

rewriter.replaceOp(op, {listConstruct.getOperand(positiveDim)});
rewriter.replaceOp(op, {listConstruct.getOperand(*indexOpt)});
return success();
});
patterns.add(+[](Aten__Getitem__TOp op, PatternRewriter &rewriter) {
Expand Down
21 changes: 10 additions & 11 deletions lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,14 +176,14 @@ class AbstractlyInterpretListOpsWithinABlock
if (auto setItem = dyn_cast<Aten_SetItemTOp>(user)) {
if (!setItem.use_empty())
return failure();
int64_t index;
if (!matchPattern(setItem.idx(), m_TorchConstantInt(&index)))
return failure();
llvm::Optional<int64_t> indexOpt =
matchLegalConstantIndexIntoListOfSize(setItem.idx(),
runningList.size());
// The index might be statically out of bounds.
if (index < 0 || index >= static_cast<int64_t>(runningList.size()))
if (!indexOpt)
return failure();
if (setItem.l() == op) {
runningList[index] = setItem.el();
runningList[*indexOpt] = setItem.el();
generatedNewLiteral = true;
}
listLiterals.push_back(runningList);
Expand Down Expand Up @@ -293,15 +293,14 @@ static void refineShapeCalculateResult(ShapeCalculateOp op, int resultNum,
// change the size of the list. It might clobber some elements, which then
// become dimensions with unknown size.
if (auto setItem = dyn_cast<Aten_SetItemTOp>(user)) {
int64_t index;
// If the index is statically known, we can clobber only a single index.
// Otherwise, we conservatively clobber all of them.
if (matchPattern(setItem.idx(), m_TorchConstantInt(&index)) &&
isValidDim(index, listConstruct->getNumOperands())) {
clobberedElements.set(index);
} else {
llvm::Optional<int64_t> indexOpt = matchLegalConstantIndexIntoListOfSize(
setItem.idx(), listConstruct->getNumOperands());
if (indexOpt)
clobberedElements.set(*indexOpt);
else
clobberedElements.set();
}
continue;
}
// An unhandled op! We can't make any assumptions about the shape.
Expand Down
11 changes: 11 additions & 0 deletions lib/Dialect/Torch/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,17 @@ bool Torch::isValidDim(int64_t dim, int64_t inputRank) {
return dim >= 0 && dim < inputRank;
}

llvm::Optional<int64_t>
Torch::matchLegalConstantIndexIntoListOfSize(Value v, int64_t length) {
int64_t dim;
if (!matchPattern(v, m_TorchConstantInt(&dim)))
return llvm::None;
dim = toPositiveDim(dim, length);
if (!isValidDim(dim, length))
return llvm::None;
return dim;
}

bool Torch::getListConstructElements(Value v, SmallVectorImpl<Value> &elems) {
auto listConstruct = v.getDefiningOp<PrimListConstructOp>();
if (!listConstruct)
Expand Down
19 changes: 19 additions & 0 deletions test/Dialect/Torch/simplify-shape-calculations.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,25 @@ func @abstractly_interpret_list_ops$mutation_ops(%arg0: !torch.vtensor, %arg1: !
return %0 : !torch.vtensor
}

// Test negative indexes with set_item op.
// CHECK-LABEL: func @abstractly_interpret_list_ops$neg_index_set_item(
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %arg1, %arg2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int>
func @abstractly_interpret_list_ops$neg_index_set_item(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.vtensor {
%int1 = torch.constant.int 1
%int-1 = torch.constant.int -1
%int-2 = torch.constant.int -2
%0 = torch.shape.calculate {
torch.shape.calculate.yield %arg0 : !torch.vtensor
} shapes {
%1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.aten._set_item.t %1, %int-1, %arg2 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
%3 = torch.aten._set_item.t %1, %int-2, %arg1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %1 : !torch.list<int>
} : !torch.vtensor
return %0 : !torch.vtensor
}

// Test interspersed mutation and evaluation ops.
// CHECK-LABEL: func @abstractly_interpret_list_ops$mix_mutation_and_evaluation_ops(
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %int0, %int1, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
Expand Down