[MLIR][TORCH] Add support for conversion to int8 dtype
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
vivekkhandelwal1 committed Oct 2, 2023
1 parent 71ac62f commit c434736
Showing 6 changed files with 76 additions and 10 deletions.
3 changes: 3 additions & 0 deletions e2e_testing/xfail_sets.py
@@ -288,6 +288,9 @@
 
     # AssertionError: Unregistered operation: torch.aten._embedding_bag_forward_only
     "AtenEmbeddingBagStaticModule_basic",
+
+    # Lowering not present for this case
+    "ElementwiseToDtypeI64ToUI8Module_basic",
 }
 
 if torch_version_for_comparison() < version.parse("2.1.0.dev"):
3 changes: 2 additions & 1 deletion include/torch-mlir/Conversion/Utils/Utils.h
@@ -87,7 +87,8 @@ mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef<int64_t> shape,
 // from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype
 // should be converted builtin types.
 Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
-                           std::optional<Type> srcOriginalDtype = std::nullopt);
+                           std::optional<Type> srcOriginalDtype = std::nullopt,
+                           std::optional<Type> dstOriginalDtype = std::nullopt);
 
 Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
                          Value torchOptionalInt, Value builtinInt,
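Note: the new dstOriginalDtype parameter exists because this conversion path works on signless builtin integer types, so an i8 result type alone cannot distinguish torch.int8 (Char) from torch.uint8 (Byte). A minimal PyTorch-level sketch of that ambiguity (not part of this commit):

import torch

# The same 8-bit payload decodes to different values depending on the
# signedness recorded in the dtype, which a signless MLIR i8 cannot carry.
byte = torch.tensor([0x9C], dtype=torch.uint8)
print(byte)                   # tensor([156], dtype=torch.uint8)
print(byte.view(torch.int8))  # tensor([-100], dtype=torch.int8), same bits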
18 changes: 17 additions & 1 deletion lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -988,7 +988,23 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     Type dtype = converter->convertType(atenToDtype.getType())
                      .cast<RankedTensorType>()
                      .getElementType();
-    Value result = convertScalarToDtype(b, loc, input, dtype);
+    Type resultElementType;
+    int64_t dtypeInt;
+    if (!matchPattern(atenToDtype.getDtype(), m_TorchConstantInt(&dtypeInt))) {
+      atenToDtype.emitError("unimplemented: dtype must be a constant integer");
+      return nullptr;
+    }
+    FailureOr<Type> maybeResultElementType = getTypeForScalarType(
+        atenToDtype->getContext(), (torch_upstream::ScalarType)dtypeInt,
+        IntegerType::Signless);
+    if (failed(maybeResultElementType)) {
+      atenToDtype.emitError("unable to convert `dtypeInt` to builtin type");
+      return nullptr;
+    }
+    resultElementType = *maybeResultElementType;
+    Value result = convertScalarToDtype(b, loc, input, dtype,
+                                        /*srcOriginalDtype=*/std::nullopt,
+                                        /*dstOriginalDtype=*/resultElementType);
     return result;
   }
   if (auto divScalar = dyn_cast<AtenDivScalarOp>(op)) {
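Note: the lowering above matches the dtype operand with m_TorchConstantInt because the torch dialect, like TorchScript, encodes dtypes as integer ScalarType codes (torch.int8 is Char, code 1), which getTypeForScalarType then maps to a builtin element type. A small sketch showing where that constant comes from, assuming the standard ScalarType encoding (not part of this commit):

import torch

@torch.jit.script
def cast(x: torch.Tensor) -> torch.Tensor:
    return x.to(torch.int8)

# The printed graph carries the dtype as a plain integer constant (1 for
# torch.int8), the same value the lowering matches with m_TorchConstantInt.
print(cast.graph)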
23 changes: 15 additions & 8 deletions lib/Conversion/Utils/Utils.cpp
@@ -249,7 +249,8 @@ mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef<int64_t> shape,
 // from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype
 // should be converted builtin types.
 Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
-                           std::optional<Type> srcOriginalDtype) {
+                           std::optional<Type> srcOriginalDtype,
+                           std::optional<Type> dstOriginalDtype) {
   Type scalarType = scalar.getType();
   if (scalarType == dtype)
     return scalar;
@@ -261,14 +262,20 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
     return false;
   };
 
-  // We only support conversion from Byte or Char scalarType not to Byte or Char
-  // dtype.
+  // We don't support conversion to Byte dtype.
   if (isByteOrChar(dtype)) {
-    mlir::emitError(loc) << "unsupported: conversion to byte or char type for "
-                            "convertScalarToDtype "
-                         << scalarType << "(scalar type) -> " << dtype
-                         << "(dtype)";
-    return nullptr;
+    if (!dstOriginalDtype.has_value()) {
+      mlir::emitError(loc)
+          << "unimplemented: for conversion to byte or char type "
+             "dstOriginalDtype has to be passed to convertScalarToDtype";
+      return nullptr;
+    }
+    if (dstOriginalDtype->isUnsignedInteger()) {
+      mlir::emitError(loc)
+          << "unsupported: conversion to byte type for convertScalarToDtype "
+          << scalarType << "(scalar type) -> " << dtype << "(dtype)";
+      return nullptr;
+    }
   }
 
   // If the dtype is i1, i.e., a boolean type.
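Note: the guard now accepts Char (torch.int8), whose destination signedness dstOriginalDtype makes explicit, while Byte (torch.uint8) remains rejected. A PyTorch-level illustration of the two behaviors the utility has to tell apart (not part of this commit):

import torch

x = torch.tensor([-100, 100, 300], dtype=torch.int64)
# Char: the low 8 bits read as signed, so values wrap into [-128, 127].
print(x.to(torch.int8))   # tensor([-100,  100,   44], dtype=torch.int8)
# Byte: the same low 8 bits read as unsigned, wrapping into [0, 255].
print(x.to(torch.uint8))  # tensor([156, 100,  44], dtype=torch.uint8)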
1 change: 1 addition & 0 deletions python/torch_mlir_e2e_test/test_suite/__init__.py
@@ -14,6 +14,7 @@
     "NativeGroupNormBackwardModule_basic",
     "QuantizedMLP_basic",
     "ReduceMaxAlongDimUnsignedInt_basic",
+    "ElementwiseToDtypeI64ToUI8Module_basic",
 }
 
 # TODO: Delete once torch 2.1.0 is released
38 changes: 38 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -1642,6 +1642,44 @@ def ElementwiseToDtypeIdentityModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseToDtypeI64ToI8Module(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([-1, -1], torch.int64, True)])
+    def forward(self, x):
+        return x.to(torch.int8)
+
+
+@register_test_case(module_factory=lambda: ElementwiseToDtypeI64ToI8Module())
+def ElementwiseToDtypeI64ToI8Module_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, low=-100, high=100))
+
+
+# ==============================================================================
+
+
+class ElementwiseToDtypeI64ToUI8Module(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([-1, -1], torch.int64, True)])
+    def forward(self, x):
+        return x.to(torch.uint8)
+
+
+@register_test_case(module_factory=lambda: ElementwiseToDtypeI64ToUI8Module())
+def ElementwiseToDtypeI64ToUI8Module_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, low=-100, high=100))
+
+
+# ==============================================================================
+
+
 class ElementwiseLog2Module(torch.nn.Module):
 
     def __init__(self):
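Note: both tests draw inputs from [-100, 100), which fits inside the int8 range, so the signed cast is value-preserving and easy to compare against eager PyTorch; the uint8 variant is registered but expected to fail (see the xfail_sets.py entry above). A rough sketch of the property the I64ToI8 test exercises (not part of this commit):

import torch

# Every value in [-100, 100) fits in int8, so the cast loses nothing.
x = torch.randint(-100, 100, (3, 4), dtype=torch.int64)
y = x.to(torch.int8)
assert y.dtype == torch.int8
assert torch.equal(y.to(torch.int64), x)  # round-trips exactly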
