Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1942,7 +1942,7 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
);

let results = (outs
TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Plus_F64, Tosa_Int4]>]>:$output
TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Plus_F64]>]>:$output
);

let hasFolder = 1;
Expand Down
27 changes: 6 additions & 21 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -38,29 +38,17 @@ class Tosa_QuantizedType<string n, list<int> params, bit signed>
// Used to express accumulator results or compare results.
//===----------------------------------------------------------------------===//

def Tosa_UInt8 : UI<8>;
def Tosa_UInt16 : UI<16>;

def Tosa_Int4 : I<4>;
def Tosa_Int8 : I<8>;
def Tosa_Int16 : I<16>;
def Tosa_Int32 : I<32>;
def Tosa_Int48 : I<48>;
def Tosa_Int64 : I<64>;

def Tosa_SignedInt : AnyTypeOf<[Tosa_Int8,
Tosa_Int16,
Tosa_Int32,
Tosa_Int48,
Tosa_Int64]>;

def Tosa_Bool : I<1>;

// No unsigned unquantized int types.
def Tosa_Int : AnyTypeOf<[Tosa_Bool,
Tosa_UInt8,
Tosa_UInt16,
Tosa_SignedInt]>;
// The TOSA dialect allows more types than the TOSA standard to allow for
// experimentation. For historical reasons, signless is used in the place of
// signed.
// The TosaValidation pass can be used to check for standard conformance.
def Tosa_Int : AnyTypeOf<[AnyUnsignedInteger,
AnySignlessInteger]>;

def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32,
Tosa_Int64]>;
Expand Down Expand Up @@ -172,9 +160,6 @@ class Tosa_TypeLike<list<Type> types, string description = ""> : TypeConstraint<

def Tosa_IntLike : Tosa_TypeLike<[Tosa_Int], "signless-integer-like">;
def Tosa_Int8Like : Tosa_TypeLike<[Tosa_Int8], "signless-integer-8-bit-like">;
def Tosa_Int16Like : Tosa_TypeLike<[Tosa_Int16], "signless-integer-16-bit-like">;
def Tosa_Int32Like : Tosa_TypeLike<[Tosa_Int32], "signless-integer-32-bit-like">;
def Tosa_Int64Like : Tosa_TypeLike<[Tosa_Int64], "signless-integer-64-bit-like">;

//===----------------------------------------------------------------------===//
// Attribute predicates and classes.
Expand Down
51 changes: 48 additions & 3 deletions mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
bool CheckVariable(Operation *op);
bool CheckVariableReadOrWrite(Operation *op);

bool isValidElementType(Type type);

SmallVector<std::function<LogicalResult(Operation *)>> constCheckers;
TosaLevel tosaLevel;
DenseMap<StringAttr, mlir::Type> variablesMap;
Expand Down Expand Up @@ -503,15 +505,58 @@ LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
return success();
}

bool TosaValidation::isValidElementType(Type type) {
if ((profile == TosaProfileEnum::BaseInference) && isa<FloatType>(type)) {
return false;
}
if (type.isF64()) {
return false;
}
if (auto intTy = dyn_cast<IntegerType>(type)) {
if (intTy.isUnsigned()) {
switch (intTy.getWidth()) {
case 8:
case 16:
return true;
default:
return false;
}
} else {
// Signless - treated as signed.
switch (intTy.getWidth()) {
case 1:
case 4:
case 8:
case 16:
case 32:
case 48:
case 64:
return true;
default:
return false;
}
}
return false;
}
return true;
}

void TosaValidation::runOnOperation() {
configLevelAndProfile();
getOperation().walk([&](Operation *op) {
for (Value operand : op->getOperands()) {
if ((profile == TosaProfileEnum::BaseInference) &&
isa<FloatType>(getElementTypeOrSelf(operand))) {
auto elementTy = getElementTypeOrSelf(operand);
if (!isValidElementType(elementTy)) {
op->emitOpError() << "failed level check: element type " << elementTy
<< " is not legal";
return signalPassFailure();
}
if (getElementTypeOrSelf(operand).isF64()) {
}
for (Type resultTy : op->getResultTypes()) {
auto elementTy = getElementTypeOrSelf(resultTy);
if (!isValidElementType(elementTy)) {
op->emitOpError() << "failed level check: element type " << elementTy
<< " is not legal";
return signalPassFailure();
}
}
Expand Down
16 changes: 16 additions & 0 deletions mlir/test/Dialect/Tosa/level_check.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,22 @@ func.func @test_const(%arg0 : tensor<1x1xi32>) -> tensor<1x1x1x1x1x1x1xi32> {

// -----

func.func @test_const_i2(%arg0 : tensor<1xi2>) {
// expected-error@+1 {{'tosa.const' op failed level check: element type 'i2' is not legal}}
%0 = "tosa.const"() {value = dense<0> : tensor<1xi2>} : () -> tensor<1xi2>
return
}

// -----

func.func @test_const_ui32(%arg0 : tensor<1xui32>) {
// expected-error@+1 {{'tosa.const' op failed level check: element type 'ui32' is not legal}}
%0 = "tosa.const"() {value = dense<0> : tensor<1xui32>} : () -> tensor<1xui32>
return
}

// -----

func.func @test_avgpool2d_kernel_y(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
// expected-error@+1 {{'tosa.avg_pool2d' op failed level check: kernel <= MAX_KERNEL}}
%0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 8193, 1>, pad = array<i64: 4, 4, 4, 4>, stride = array<i64: 1, 1>, acc_type = f32} :
Expand Down