Skip to content

[tosa] Change VariableOp to align with spec #142240

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 3, 2025

Conversation

Tai78641
Copy link
Contributor

This fixes Tosa VariableOp to align with spec 1.0

  • add var_shape attribute to store shape of variable type
  • change type attribute to store element type of variable type
  • add a builder so previous construction calls still work
  • fix up level check of rank to be on variable type instead of initial value which is optional
  • add level check of size for variable type
  • add lit tests for variable op's without initial values
  • add lit test for variable op with fixed rank but unknown dimension
  • add invalid lit test for variable op with unranked type

@llvmbot
Copy link
Member

llvmbot commented May 30, 2025

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: Tai Ly (Tai78641)

Changes

This fixes Tosa VariableOp to align with spec 1.0

  • add var_shape attribute to store shape of variable type
  • change type attribute to store element type of variable type
  • add a builder so previous construction calls still work
  • fix up level check of rank to be on variable type instead of initial value which is optional
  • add level check of size for variable type
  • add lit tests for variable op's without initial values
  • add lit test for variable op with fixed rank but unknown dimension
  • add invalid lit test for variable op with unranked type

Patch is 24.84 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/142240.diff

11 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+10)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h (+11-4)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td (+6-1)
  • (modified) mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+106-27)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp (+2-9)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+48-22)
  • (modified) mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir (+15-1)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+17)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+5-4)
  • (modified) mlir/test/Dialect/Tosa/variables.mlir (+45)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 0aef4653b74ff..e048f8af7cc33 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -197,6 +197,16 @@ def Tosa_PadOpQuantInfoBuilder : OpBuilder<
                             input, paddings);
   }]>;
 
+// This builder is called on the TOSA variable operator with a variable type
+// and optional initial value. The builder will extract var_shape and element type
+// attributes from variable type.
+def Tosa_VariableOpBuilder : OpBuilder<
+  (ins "StringRef":$name, "Type":$variable_type, "Attribute":$initial_value),
+  [{
+    buildVariableOp($_builder, $_state, name, variable_type, initial_value);
+  }]>;
+
+
 // Wrapper over base I32EnumAttr to set common fields.
 class Tosa_I32Enum<string name, string description, list<I32EnumAttrCase> cases>
      : I32EnumAttr<name, description, cases> {
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index 6fa4aedc1f0b0..a15f073bc5fcb 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -44,10 +44,14 @@ class PatternRewriter;
 
 namespace tosa {
 
-ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
-                            Attribute &attr);
-void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
-                     Attribute attr);
+ParseResult parseVariableOpTypeOrInitialValue(OpAsmParser &parser,
+                                              DenseElementsAttr &varShapeAttr,
+                                              TypeAttr &typeAttr,
+                                              Attribute &initialValueAttr);
+void printVariableOpTypeOrInitialValue(OpAsmPrinter &p, Operation *op,
+                                       DenseElementsAttr varShapeAttr,
+                                       TypeAttr typeAttr,
+                                       Attribute initialValueAttr);
 
 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc"
 
@@ -172,6 +176,9 @@ std::optional<Value> createZeroPointTensor(OpBuilder &builder, Location loc,
 Value createPadConstTensor(OpBuilder &builder, Location loc, Value src,
                            int32_t val = 0);
 
+// returns type of variable op
+RankedTensorType getVariableType(VariableOp variableOp);
+
 } // namespace tosa
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
index 5f99162907949..c8f2907f8dd1b 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
@@ -92,6 +92,7 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
 
   let arguments = (ins
     SymbolNameAttr:$name,
+    IndexElementsAttr:$var_shape,
     TypeAttr:$type,
     OptionalAttr<AnyAttr>:$initial_value
   );
@@ -101,12 +102,16 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
     Extension<[Tosa_EXT_VARIABLE]>,
   ];
 
+  let hasCustomAssemblyFormat = 1;
+
   let assemblyFormat = [{
     $name
     attr-dict
-    custom<TypeOrAttr>($type, $initial_value)
+    custom<VariableOpTypeOrInitialValue>($var_shape, $type, $initial_value)
   }];
 
+  let builders = [Tosa_VariableOpBuilder];
+
   let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
index 310566e692202..7dbccd19a0518 100644
--- a/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
+++ b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
@@ -26,8 +26,9 @@ class VariableOpConverter : public OpRewritePattern<tosa::VariableOp> {
 
   LogicalResult matchAndRewrite(tosa::VariableOp op,
                                 PatternRewriter &rewriter) const final {
+    auto variableType = tosa::getVariableType(op);
     auto newVariable = rewriter.create<mlir::ml_program::GlobalOp>(
-        op.getLoc(), op.getName(), op.getType(), /*is_mutable=*/true,
+        op.getLoc(), op.getName(), variableType, /*is_mutable=*/true,
         op.getInitialValueAttr(), /*sym_visibility=*/nullptr);
     newVariable.setPrivate();
     rewriter.replaceOp(op, newVariable);
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 93a6a8be48df7..6a1639104846e 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -131,6 +131,24 @@ SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
   return {&getBodyGraph()};
 }
 
+//===----------------------------------------------------------------------===//
+// TOSA variable operator support.
+//===----------------------------------------------------------------------===//
+
+static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
+  return to_vector(llvm::map_range(shape, [](int64_t dim) {
+    return dim == -1 ? ShapedType::kDynamic : dim;
+  }));
+}
+
+// returns type of variable op
+RankedTensorType mlir::tosa::getVariableType(tosa::VariableOp variableOp) {
+  Type elementType = variableOp.getType();
+  DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
+  auto shape = convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
+  return RankedTensorType::get(shape, elementType);
+}
+
 //===----------------------------------------------------------------------===//
 // Tosa dialect initialization.
 //===----------------------------------------------------------------------===//
@@ -177,42 +195,81 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
 // Parsers and printers
 //===----------------------------------------------------------------------===//
 
-ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
-                                        Attribute &attr) {
+namespace {
+
+ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,
+                                   DenseElementsAttr &varShapeAttr,
+                                   TypeAttr &typeAttr) {
+  if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {
+    if (!shapedType.hasRank())
+      return parser.emitError(parser.getCurrentLocation())
+             << "expected ranked type";
+
+    auto elementType = shapedType.getElementType();
+    typeAttr = TypeAttr::get(elementType);
+    ArrayRef<int64_t> shape = shapedType.getShape();
+    Builder builder(parser.getContext());
+    varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
+    return success();
+  }
+  return parser.emitError(parser.getCurrentLocation())
+         << "expected shaped type";
+}
+
+} // namespace
+
+// parses the optional initial value or type for a tosa variable
+//  with initial value:
+//    tosa.variable @name = dense<0.0> : tensor<1x8xf32>
+//
+//  without initial value:
+//    tosa.variable @name : tensor<1x8xf32>
+ParseResult mlir::tosa::parseVariableOpTypeOrInitialValue(
+    OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr,
+    Attribute &initialValueAttr) {
   if (succeeded(parser.parseOptionalEqual())) {
-    if (failed(parser.parseAttribute(attr))) {
+    if (failed(parser.parseAttribute(initialValueAttr))) {
       return parser.emitError(parser.getCurrentLocation())
              << "expected attribute";
     }
-    if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
-      typeAttr = TypeAttr::get(typedAttr.getType());
+    if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
+      return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
+                                    typeAttr);
     }
-    return success();
+    return parser.emitError(parser.getCurrentLocation())
+           << "expected Typed attr";
   }
 
-  Type type;
-  if (failed(parser.parseColonType(type))) {
-    return parser.emitError(parser.getCurrentLocation()) << "expected type";
+  initialValueAttr = nullptr;
+  Type parsedType;
+  if (failed(parser.parseColonType(parsedType))) {
+    return parser.emitError(parser.getCurrentLocation())
+           << "expected type after colon";
   }
-  typeAttr = TypeAttr::get(type);
-
-  return success();
+  return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
 }
 
-void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
-                                 Attribute attr) {
+void mlir::tosa::printVariableOpTypeOrInitialValue(
+    OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr,
+    TypeAttr typeAttr, Attribute initialValueAttr) {
   bool needsSpace = false;
-  auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
-  if (!typedAttr || typedAttr.getType() != type.getValue()) {
+  auto typedAttr = dyn_cast_or_null<TypedAttr>(initialValueAttr);
+  if (!typedAttr) {
+    auto shape =
+        convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
+    Type elementType = typeAttr.getValue();
+    RankedTensorType tensorType =
+        RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
+    auto tensorTypeAttr = TypeAttr::get(tensorType);
     p << ": ";
-    p.printAttribute(type);
+    p.printAttribute(tensorTypeAttr);
     needsSpace = true; // subsequent attr value needs a space separator
   }
-  if (attr) {
+  if (initialValueAttr) {
     if (needsSpace)
       p << ' ';
     p << "= ";
-    p.printAttribute(attr);
+    p.printAttribute(initialValueAttr);
   }
 }
 
@@ -657,8 +714,9 @@ static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
            << symName << "' has not been declared by 'tosa.variable'";
 
   // Verify type and shape
-  Type varType = cast<tosa::VariableOp>(varOp.value()).getType();
-  if (errorIfTypeOrShapeMismatch(op, type, name, varType, "the input tensor")
+  auto variableType = getVariableType(varOp.value());
+  if (errorIfTypeOrShapeMismatch(op, type, name, variableType,
+                                 "the input tensor")
           .failed())
     return failure();
 
@@ -1103,6 +1161,33 @@ static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
   result.types.push_back(outputType);
 }
 
+static void buildVariableOp(OpBuilder &builder, OperationState &result,
+                            StringRef name, Type variableType,
+                            Attribute initialValue) {
+  const Location loc{result.location};
+  auto nameAttr = builder.getStringAttr(name);
+
+  auto shapedType = dyn_cast<ShapedType>(variableType);
+  if (!shapedType) {
+    (void)emitError(loc, "variable type must be a shaped type");
+    return;
+  }
+  if (!shapedType.hasRank()) {
+    (void)emitError(loc, "variable type must be a ranked type");
+    return;
+  }
+
+  auto elementType = shapedType.getElementType();
+  auto elementTypeAttr = TypeAttr::get(elementType);
+  ArrayRef<int64_t> shape = shapedType.getShape();
+  auto varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
+
+  result.addAttribute("name", nameAttr);
+  result.addAttribute("var_shape", varShapeAttr);
+  result.addAttribute("type", elementTypeAttr);
+  result.addAttribute("initial_value", initialValue);
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Return Type Inference.
 //===----------------------------------------------------------------------===//
@@ -1676,12 +1761,6 @@ LogicalResult tosa::PadOp::verify() {
   return success();
 }
 
-static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
-  return to_vector(llvm::map_range(shape, [](int64_t dim) {
-    return dim == -1 ? ShapedType::kDynamic : dim;
-  }));
-}
-
 LogicalResult tosa::SliceOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     SliceOp::Adaptor adaptor,
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 1a896c1464e1c..de08e7e9a4394 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -215,15 +215,8 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {
 
 template <>
 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) {
-  ::mlir::Attribute attr = op.getInitialValueAttr();
-  if (attr == nullptr)
-    return failure();
-
-  if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
-    addType(getElementTypeOrSelf(typedAttr));
-    return success();
-  }
-  return failure();
+  addType(op.getType());
+  return success();
 }
 
 template <>
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index f9db5dcb88b4c..ea862ecb49e4e 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -238,10 +238,10 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     return true;
   }
 
-  template <typename T>
-  bool levelCheckRank(Operation *op, const T &v,
+  // Perform the Level Rank check on the tensor type.
+  bool levelCheckRank(Operation *op, const Type typeToCheck,
                       const StringRef operandOrResult, int32_t highest_rank) {
-    if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
+    if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
       if (!type.hasRank()) {
         op->emitOpError() << "failed level check: unranked tensor";
         return false;
@@ -255,10 +255,22 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     return true;
   }
 
-  // Perform the Level tensor size check on the input tensor.
-  bool levelCheckSize(Operation *op, const Value &v,
+  // Perform the Level Rank check on the tensor value.
+  bool levelCheckRank(Operation *op, const Value &v,
+                      const StringRef operandOrResult, int32_t highest_rank) {
+    return levelCheckRank(op, v.getType(), operandOrResult, highest_rank);
+  }
+
+  // Perform the Level tensor size check on the tensor type.
+  bool levelCheckSize(Operation *op, const Type &typeToCheck,
                       const StringRef operandOrResult);
 
+  // Perform the Level tensor size check on the tensor value.
+  bool levelCheckSize(Operation *op, const Value &v,
+                      const StringRef operandOrResult) {
+    return levelCheckSize(op, v.getType(), operandOrResult);
+  }
+
   // Level check sizes of all operands and results of the operation.
   template <typename T>
   bool levelCheckSizes(T tosaOp) {
@@ -284,15 +296,6 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
         return false;
     }
 
-    if (!op->getAttrs().empty()) {
-      for (NamedAttribute attr : op->getAttrs()) {
-        if (auto elemAttr = dyn_cast<ElementsAttr>(attr.getValue())) {
-          if (!levelCheckRank(op, elemAttr, "attribute", tosaLevel.MAX_RANK))
-            return false;
-        }
-      }
-    }
-
     for (auto v : op->getResults()) {
       if (!levelCheckRank(op, v, "result", tosaLevel.MAX_RANK))
         return false;
@@ -596,6 +599,26 @@ bool TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
   return true;
 }
 
+template <>
+bool TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
+  auto op = tosaOp.getOperation();
+  auto variableType = getVariableType(tosaOp);
+  if (!levelCheckRank(op, variableType, "variable type", tosaLevel.MAX_RANK))
+    return false;
+
+  return true;
+}
+
+template <>
+bool TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
+  auto op = tosaOp.getOperation();
+  auto variableType = getVariableType(tosaOp);
+  if (!levelCheckSize(op, variableType, "variable type"))
+    return false;
+
+  return true;
+}
+
 bool TosaValidation::levelCheckRanksAndSizes(Operation *op) {
 #define CHECK_RANKS_AND_SIZES(tosaOp)                                          \
   if (isa<tosa::tosaOp##Op>(op)) {                                             \
@@ -714,10 +737,10 @@ bool TosaValidation::levelCheckRanksAndSizes(Operation *op) {
   return true;
 }
 
-// Perform the Level tensor size check
-bool TosaValidation::levelCheckSize(Operation *op, const Value &v,
+// Perform the Level tensor size check on the tensor type.
+bool TosaValidation::levelCheckSize(Operation *op, const Type &typeToCheck,
                                     const StringRef operandOrResult) {
-  if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
+  if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
     if (!type.hasRank()) {
       op->emitOpError() << "failed level check: unranked tensor";
       return false;
@@ -800,18 +823,21 @@ inline bool CompatibleTypes(const mlir::Type &type,
 }
 
 bool TosaValidation::CheckVariable(Operation *op) {
-  if (isa<mlir::tosa::VariableOp>(op)) {
-    mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
+  if (auto variableOp = dyn_cast<mlir::tosa::VariableOp>(op)) {
+    mlir::StringAttr nameAttr = variableOp.getNameAttr();
 
     if (variablesMap.count(nameAttr)) {
       op->emitOpError() << "name has already been declared";
       return false;
     }
 
-    auto typeAttr = cast<mlir::TypeAttr>(op->getAttr("type"));
-    mlir::Type type = typeAttr.getValue();
+    auto elementType = variableOp.getType();
+    DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
+    SmallVector<int64_t> shape = to_vector(varShapeAttr.getValues<int64_t>());
+    RankedTensorType variableType =
+        RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
 
-    variablesMap[nameAttr] = type;
+    variablesMap[nameAttr] = variableType;
   }
 
   return true;
diff --git a/mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir b/mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir
index 365b05ff084da..d2092753f1f58 100644
--- a/mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir
+++ b/mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --tosa-to-mlprogram %s -o -| FileCheck %s
+// RUN: mlir-opt --tosa-to-mlprogram %s -split-input-file -o -| FileCheck %s
 
 module {
   // CHECK: ml_program.global private mutable @var_x(dense<7.000000e+00> : tensor<1xf32>) : tensor<1xf32>
@@ -10,4 +10,18 @@ module {
     %0 = tosa.variable_read @var_x : tensor<1xf32>
     return %0 : tensor<1xf32>
   }
+}
+
+// -----
+
+module {
+  // CHECK: ml_program.global private mutable @var_x : tensor<f32>
+  tosa.variable @var_x : tensor<f32>
+  func.func @test_stateful_ops(%arg0: tensor<f32>) -> (tensor<f32>) {
+    // CHECK: ml_program.global_store @var_x = %arg0 : tensor<f32>
+    tosa.variable_write @var_x, %arg0 : tensor<f32>
+    // CHECK: %[[LOAD:.+]] = ml_program.global_load @var_x : tensor<f32>
+    %0 = tosa.variable_read @var_x : tensor<f32>
+    return %0 : tensor<f32>
+  }
 }
\ No newline at end of file
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index c41f079ec526c..05505c3671674 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -564,6 +564,23 @@ func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>, %arg1: ten
 
 // -----
 
+func.func @test_variable_unranked(%arg0: tensor<2x4x8xi8>) -> () {
+  tosa.variable @stored_var : tensor<*xi8>
+  // expected-error@+1 {{custom op 'tosa.variable' expected ranked type}}
+  return
+}
+
+// -----
+
+func.func @test_variable_unranked_initial_value(%arg0: tensor<2x4x8xi8>) -> () {
+  // expected-error@+1 {{elements literal type must have static shape}}
+  tosa.variable @stored_var = dense<0> : tensor<*xi8>
+  // expected-error@+1 {{custom op 'tosa.variable' expected attribute}}
+  return
+}
+
+// -----
+
 func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () {
   tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
   // expected-error@+1 {{'tosa.variable' op illegal to have multiple declaration of 'stored_var'}}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index e7d0a0e1fa4ea..223bf3b635e18 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -443,7 +443,7 @@ func.func @test_rescale_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi8>) -> tenso
 
 // -----
 func.func @...
[truncated]

@Jerry-Ge Jerry-Ge requested review from GeorgeARM and lhutton1 May 31, 2025 01:34
Copy link
Contributor

@lhutton1 lhutton1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Spotted a small nitpick, otherwise LGTM thanks!

This fixes Tosa VariableOp to align with spec 1.0
  - add var_shape attribute to store shape of variable type
  - change type attribute to store element type of variable type
  - add a builder so previous construction calls still work
  - fix up level check of rank to be on variable type instead of
    initial value which is optional
  - add level check of size for variable type
  - add lit tests for variable op's without initial values
  - add lit test for variable op with fixed rank but unknown dimension
  - add invalid lit test for variable op with unranked type

Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: Icbbd751666870a94d4902163f7e840395e2aea52
@Tai78641 Tai78641 force-pushed the pr_sync_variable_op branch from 786379b to 0c8523b Compare June 2, 2025 17:20
@lhutton1 lhutton1 merged commit 04b63ac into llvm:main Jun 3, 2025
11 checks passed
@Tai78641 Tai78641 deleted the pr_sync_variable_op branch June 3, 2025 16:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants