Address change requests
ptrifunovic98 committed Mar 12, 2024
1 parent 1a99ebc commit 6356695
Showing 4 changed files with 56 additions and 39 deletions.
69 changes: 32 additions & 37 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -4287,21 +4287,35 @@ LogicalResult AtenLinalgCrossOp::verify() {
   auto selfType = getSelf().getType().cast<BaseTensorType>();
   auto otherType = getOther().getType().cast<BaseTensorType>();
 
-  if (!selfType.hasDtype() || !otherType.hasDtype()) {
+  if (!selfType.hasDtype() || !otherType.hasDtype() || !selfType.hasSizes() ||
+      !otherType.hasSizes()) {
     return success();
   }
 
   Type selfDtype = selfType.getDtype();
   Type otherDtype = otherType.getDtype();
 
-  auto selfRankedType = getSelf().getType().cast<RankedTensorType>();
-  auto otherRankedType = getOther().getType().cast<RankedTensorType>();
-
-  auto selfShape = selfRankedType.getShape();
-  auto otherShape = otherRankedType.getShape();
-
-  int64_t selfRank = selfRankedType.getRank();
-  int64_t otherRank = otherRankedType.getRank();
+  // the operation succeeds only if both inputs have the same dtype
+  if (selfDtype != otherDtype) {
+    return emitOpError("input tensors must have the same dtype, but got ")
+           << selfDtype << " and " << otherDtype;
+  }
+
+  // Check if any of the input tensors has torch.bool dtype.
+  // The operation does not support this type.
+  // The docs state that only float, double, cfloat and cdouble dtypes are
+  // supported, but, when testing, it fails only for boolean dtype. Update to
+  // fit the docs if necessary.
+  // https://pytorch.org/docs/stable/generated/torch.linalg.cross.html
+  if (selfDtype.isSignlessInteger(1) || otherDtype.isSignlessInteger(1)) {
+    return emitOpError("input tensors must not have bool dtype");
+  }
+
+  ArrayRef<int64_t> selfShape = selfType.getSizes();
+  ArrayRef<int64_t> otherShape = otherType.getSizes();
+
+  int64_t selfRank = selfShape.size();
+  int64_t otherRank = otherShape.size();
 
   // check if both input tensors have the same number of dims
   if (selfRank != otherRank) {
@@ -4310,12 +4324,13 @@ LogicalResult AtenLinalgCrossOp::verify() {
            << selfRank << " and " << otherRank;
   }
 
+  // convert dim to an integer type
   int64_t dim;
   if (!matchPattern(getDim(), m_TorchConstantInt(&dim))) {
-    return emitOpError("dim must be a constant int");
+    return success();
   }
 
-  // check if is dim is in the correct range
+  // check if dim is in the correct range
   if (dim >= selfRank || dim < -selfRank) {
     return emitOpError("dim expected to be in rank of [")
            << -selfRank << ", " << selfRank - 1 << "], but got " << dim;
@@ -4328,46 +4343,26 @@ LogicalResult AtenLinalgCrossOp::verify() {
 
   // check if the size of the dimensions specified by 'dim' is equal to 3
   // (required by the operation)
-  if (selfShape[dim] != 3 || otherShape[dim] != 3) {
+  if ((selfShape[dim] != 3 && selfShape[dim] != kUnknownSize) ||
+      (otherShape[dim] != 3 && otherShape[dim] != kUnknownSize)) {
     return emitOpError("inputs dimension ")
            << dim << " must have length 3, but got " << selfShape[dim]
            << " and " << otherShape[dim];
   }
 
-  // Check if any of the input tensors has torch.bool dtype.
-  // The operation does not support this type.
-  // The docs state that only float, double, cfloat and cdouble dtypes are
-  // supported, but, when testing, it fails only for boolean dtype. Update to
-  // fit the docs if necessary.
-  if (selfDtype.isSignlessInteger(1) || otherDtype.isSignlessInteger(1)) {
-    return emitOpError("input tensors must not have bool dtype");
-  }
-
-  auto selfCurrent = selfShape.begin();
-  auto selfEnd = selfShape.end();
-  auto otherCurrent = otherShape.begin();
-  int32_t i = 0;
-
   // Check if there is a disparity between dimension sizes.
   // Dimensions at the same index must either have the same size,
   // or one of them must be equal to 1.
-  while (selfCurrent != selfEnd) {
-    if (*selfCurrent != *otherCurrent && *selfCurrent != 1 &&
-        *otherCurrent != 1) {
+  int32_t i = 0;
+  for (auto [selfCurrent, otherCurrent] :
+       llvm::zip_equal(selfShape, otherShape)) {
+    if (selfCurrent != otherCurrent && selfCurrent != 1 && otherCurrent != 1) {
       return emitOpError("the size of first tensor (")
-             << *selfCurrent << ") must match the size of second tensor ("
-             << *otherCurrent << ") at dimension " << i
+             << selfCurrent << ") must match the size of second tensor ("
+             << otherCurrent << ") at dimension " << i
              << " or one of them must be 1";
     }
     ++i;
-    ++selfCurrent;
-    ++otherCurrent;
   }
 
-  // the operation succeeds only if both inputs have the same dtype
-  if (selfDtype != otherDtype) {
-    return emitOpError("input tensors must have the same dtype, but got ")
-           << selfDtype << " and " << otherDtype;
-  }
-
   return success();
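For reference only (not part of the diff): a minimal eager-mode PyTorch sketch of the constraints the reworked verifier mirrors - matching, non-bool dtypes, equal rank, length 3 along dim, and every other dimension pair either equal or 1. The shapes below are the ones used by the existing AtenLinalgCrossNegativeDim test.

# Illustration only; this mimics the verifier's checks, it is not the verifier.
import torch

a = torch.rand(1, 4, 3, 2, 2)
b = torch.rand(5, 4, 3, 2, 1)
# dim=2 has length 3 in both inputs and the remaining dims are equal or 1,
# so the op is valid and the result broadcasts to shape (5, 4, 3, 2, 2).
print(torch.linalg.cross(a, b, dim=2).shape)

In the new verifier, dimensions reported as kUnknownSize simply skip the static length-3 and broadcast checks instead of being rejected.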
1 change: 0 additions & 1 deletion lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -1838,7 +1838,6 @@ class DecomposeAtenLinalgCrossOp : public OpRewritePattern<AtenLinalgCrossOp> {
 
     Type dtype = resType.getDtype();
     if (dtype.isa<mlir::ComplexType>()) {
-      printf("Is a complex type\n");
       return rewriter.notifyMatchFailure(
           op, "lowering of aten.linalg_cross for complex inputs dtype is "
               "currently unimplemented");
3 changes: 3 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -2107,6 +2107,9 @@
     "ReduceMinAlongDimUnsignedInt_basic",
     "TensorsStackNegativeDimModule_basic",
     "TensorsStackPromoteDTypeModule_basic",
+
+    # Failure - "RuntimeError: linalg.cross: inputs dimension 1 must have length 3. Got 1 and 1"
+    "AtenLinalgCrossDynamic_basic"
 }
 
 ONNX_CRASHING_SET = { }
22 changes: 21 additions & 1 deletion projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
@@ -379,4 +379,24 @@ def forward(self, a, b):
 
 @register_test_case(module_factory=lambda: AtenLinalgCrossNegativeDim())
 def AtenLinalgCrossNegativeDim_basic(module, tu: TestUtils):
-    module.forward(tu.rand(1, 4, 3, 2, 2), tu.rand(5, 4, 3, 2, 1))
+    module.forward(tu.rand(1, 4, 3, 2, 2), tu.rand(5, 4, 3, 2, 1))
+
+# ==============================================================================
+
+class AtenLinalgCrossDynamic(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, a, b):
+        return torch.ops.aten.linalg_cross(a, b, dim=1)
+
+
+@register_test_case(module_factory=lambda: AtenLinalgCrossDynamic())
+def AtenLinalgCrossDynamic_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4, 3, 1, 6), tu.rand(4, 3, 7, 1))

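A note on the new test's shapes (an observation, not part of the commit): the annotations mark every dimension as dynamic (-1), so the verifier sees only kUnknownSize and defers its length-3 and broadcast checks, while in stock eager PyTorch the inputs of shape (4, 3, 1, 6) and (4, 3, 7, 1) are valid for dim=1 and broadcast to a (4, 3, 7, 6) result.

# Sketch only: the eager-mode computation the test exercises.
import torch
out = torch.linalg.cross(torch.rand(4, 3, 1, 6), torch.rand(4, 3, 7, 1), dim=1)
print(out.shape)  # expected: torch.Size([4, 3, 7, 6])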