Skip to content

Conversation

@lhutton1
Copy link
Contributor

@lhutton1 lhutton1 commented Jun 19, 2025

Similarly to tosa.cond_if, this patch checks that the cond/body regions of tosa.while_loop are isolated from above. This is required since the specification requires all values used in the cond/body regions are explicitly declared within the regions.

Note: this change is dependent on #143772

Copy link
Contributor

@FranklandJack FranklandJack left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM % a few pedantic comments :D

Similarly to `tosa.cond_if`, this patch checks that the cond/body
regions of `tosa.while_loop` are isolated from above. This is required
since the specification requires all values used in the cond/body
regions are explicitly declared within the regions.

Change-Id: Ia7396b9811db54805ec33befd24ab97d1b605905
@lhutton1 lhutton1 force-pushed the fix-while-loop-isolated branch from 7bd13d2 to 516c47a Compare July 21, 2025 11:56
@lhutton1 lhutton1 marked this pull request as ready for review July 21, 2025 11:59
@llvmbot
Copy link
Member

llvmbot commented Jul 21, 2025

@llvm/pr-subscribers-mlir

Author: Luke Hutton (lhutton1)

Changes

Similarly to tosa.cond_if, this patch checks that the cond/body regions of tosa.while_loop are isolated from above. This is required since the specification requires all values used in the cond/body regions are explicitly declared within the regions.

Note: this change is dependent on #143772


Full diff: https://github.com/llvm/llvm-project/pull/144865.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+32-25)
  • (modified) mlir/test/Dialect/Tosa/error_if_check.mlir (+57)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 48ec28acfaaaa..32b5fb63a6ece 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1200,6 +1200,28 @@ static bool isOpIsolatedWithinRegion(Operation *op, Region *region) {
   });
 }
 
+static bool isRegionIsolatedFromAbove(Region &regionToCheck) {
+  bool noLiveInValue = true;
+  regionToCheck.walk([&noLiveInValue, &regionToCheck](Operation *op) {
+    if (!isOpIsolatedWithinRegion(op, &regionToCheck)) {
+      noLiveInValue = false;
+      return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  });
+  return noLiveInValue;
+}
+
+LogicalResult checkIsolatedRegion(Operation *op, Region &regionToCheck,
+                                  StringRef regionName) {
+  if (isRegionIsolatedFromAbove(regionToCheck))
+    return success();
+  op->emitOpError()
+      << "is not conformant to the TOSA specification. It requires the '"
+      << regionName << "' region is isolated from above.\n";
+  return failure();
+}
+
 bool checkErrorIfCondIf(Operation *op) {
   auto ifOp = dyn_cast<tosa::IfOp>(op);
   if (!ifOp)
@@ -1236,32 +1258,17 @@ bool checkErrorIfCondIf(Operation *op) {
   // used in then/else regions (see 'simplified' example above), so it
   // must be rewritten to use the generic syntax in order to be conformant
   // to the specification.
+  return failed(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) ||
+         failed(checkIsolatedRegion(op, ifOp.getElseGraph(), "else"));
+}
 
-  // Returns true if the region uses no external input operands.
-  auto isIsolatedRegion = [](Region &regionToCheck) -> bool {
-    bool noLiveInValue = true;
-    regionToCheck.walk([&noLiveInValue, &regionToCheck](Operation *opInRegion) {
-      if (!isOpIsolatedWithinRegion(opInRegion, &regionToCheck)) {
-        noLiveInValue = false;
-        return WalkResult::interrupt();
-      }
-      return WalkResult::advance();
-    });
-    return noLiveInValue;
-  };
-
-  auto checkIsolatedRegion = [&](Region &regionToCheck,
-                                 StringRef regionName) -> LogicalResult {
-    if (isIsolatedRegion(regionToCheck))
-      return success();
-    op->emitOpError()
-        << "is not conformant to the TOSA specification. It requires the '"
-        << regionName << "' region is isolated from above.\n";
-    return failure();
-  };
+bool checkErrorIfWhileLoop(Operation *op) {
+  auto whileOp = dyn_cast<tosa::WhileOp>(op);
+  if (!whileOp)
+    return true;
 
-  return failed(checkIsolatedRegion(ifOp.getThenGraph(), "then")) ||
-         failed(checkIsolatedRegion(ifOp.getElseGraph(), "else"));
+  return failed(checkIsolatedRegion(op, whileOp.getCondGraph(), "cond")) ||
+         failed(checkIsolatedRegion(op, whileOp.getBodyGraph(), "body"));
 }
 
 bool checkErrorIfScatter(Operation *op) {
@@ -1293,7 +1300,7 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
   if (!checkErrorIfResize(op) || !checkErrorIfMul(op) ||
       !checkErrorIfTable(op) || !checkErrorIfRescale(op) ||
       !checkErrorIfPad(op) || !checkErrorIfCondIf(op) ||
-      !checkErrorIfScatter(op))
+      !checkErrorIfWhileLoop(op) || !checkErrorIfScatter(op))
     return failure();
   return success();
 }
diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index 8924dd9885827..eb25011ff3a9d 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -280,3 +280,60 @@ func.func @test_cond_if_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f3
     }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
   return %0 : tensor<f32>
 }
+
+// -----
+
+func.func @test_while_loop_cond_not_isolated_from_above(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<f32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op is not conformant to the TOSA specification. It requires the 'cond' region is isolated from above.}}
+  %1 = "tosa.while_loop"(%0) ({
+  ^bb0(%arg3: tensor<i32>):
+    %2 = "tosa.greater_equal"(%arg3, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %3 = "tosa.logical_not"(%2) : (tensor<i1>) -> tensor<i1>
+    tosa.yield %3 : tensor<i1>
+  },  {
+  ^bb0(%arg3: tensor<i32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = "tosa.add"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %3 : tensor<i32>
+  }) : (tensor<i32>) -> (tensor<i32>)
+  return
+}
+
+// -----
+
+func.func @test_while_loop_body_not_isolated_from_above(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<f32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op is not conformant to the TOSA specification. It requires the 'body' region is isolated from above.}}
+  %1 = "tosa.while_loop"(%0) ({
+  ^bb0(%arg3: tensor<i32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = "tosa.greater_equal"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %4 = "tosa.logical_not"(%3) : (tensor<i1>) -> tensor<i1>
+    tosa.yield %4 : tensor<i1>
+  },  {
+  ^bb0(%arg3: tensor<i32>):
+    %3 = "tosa.add"(%arg3, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %3 : tensor<i32>
+  }) : (tensor<i32>) -> (tensor<i32>)
+  return
+}
+
+// -----
+
+// Check isolated while_loops are valid
+func.func @test_while_loop_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  %1:3 = "tosa.while_loop"(%0, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<f32>, %arg5: tensor<i32>):
+    %2 = "tosa.greater_equal"(%arg3, %arg5) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %3 = "tosa.logical_not"(%2) : (tensor<i1>) -> tensor<i1>
+    "tosa.yield"(%3) : (tensor<i1>) -> ()
+  },  {
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<f32>, %arg5: tensor<i32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = "tosa.add"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    "tosa.yield"(%3, %arg4, %arg5) : (tensor<i32>, tensor<f32>, tensor<i32>) -> ()
+  }) : (tensor<i32>, tensor<f32>, tensor<i32>) -> (tensor<i32>, tensor<f32>, tensor<i32>)
+  return
+}

@llvmbot
Copy link
Member

llvmbot commented Jul 21, 2025

@llvm/pr-subscribers-mlir-tosa

Author: Luke Hutton (lhutton1)

Changes

Similarly to tosa.cond_if, this patch checks that the cond/body regions of tosa.while_loop are isolated from above. This is required since the specification requires all values used in the cond/body regions are explicitly declared within the regions.

Note: this change is dependent on #143772


Full diff: https://github.com/llvm/llvm-project/pull/144865.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+32-25)
  • (modified) mlir/test/Dialect/Tosa/error_if_check.mlir (+57)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 48ec28acfaaaa..32b5fb63a6ece 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1200,6 +1200,28 @@ static bool isOpIsolatedWithinRegion(Operation *op, Region *region) {
   });
 }
 
+static bool isRegionIsolatedFromAbove(Region &regionToCheck) {
+  bool noLiveInValue = true;
+  regionToCheck.walk([&noLiveInValue, &regionToCheck](Operation *op) {
+    if (!isOpIsolatedWithinRegion(op, &regionToCheck)) {
+      noLiveInValue = false;
+      return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  });
+  return noLiveInValue;
+}
+
+LogicalResult checkIsolatedRegion(Operation *op, Region &regionToCheck,
+                                  StringRef regionName) {
+  if (isRegionIsolatedFromAbove(regionToCheck))
+    return success();
+  op->emitOpError()
+      << "is not conformant to the TOSA specification. It requires the '"
+      << regionName << "' region is isolated from above.\n";
+  return failure();
+}
+
 bool checkErrorIfCondIf(Operation *op) {
   auto ifOp = dyn_cast<tosa::IfOp>(op);
   if (!ifOp)
@@ -1236,32 +1258,17 @@ bool checkErrorIfCondIf(Operation *op) {
   // used in then/else regions (see 'simplified' example above), so it
   // must be rewritten to use the generic syntax in order to be conformant
   // to the specification.
+  return failed(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) ||
+         failed(checkIsolatedRegion(op, ifOp.getElseGraph(), "else"));
+}
 
-  // Returns true if the region uses no external input operands.
-  auto isIsolatedRegion = [](Region &regionToCheck) -> bool {
-    bool noLiveInValue = true;
-    regionToCheck.walk([&noLiveInValue, &regionToCheck](Operation *opInRegion) {
-      if (!isOpIsolatedWithinRegion(opInRegion, &regionToCheck)) {
-        noLiveInValue = false;
-        return WalkResult::interrupt();
-      }
-      return WalkResult::advance();
-    });
-    return noLiveInValue;
-  };
-
-  auto checkIsolatedRegion = [&](Region &regionToCheck,
-                                 StringRef regionName) -> LogicalResult {
-    if (isIsolatedRegion(regionToCheck))
-      return success();
-    op->emitOpError()
-        << "is not conformant to the TOSA specification. It requires the '"
-        << regionName << "' region is isolated from above.\n";
-    return failure();
-  };
+bool checkErrorIfWhileLoop(Operation *op) {
+  auto whileOp = dyn_cast<tosa::WhileOp>(op);
+  if (!whileOp)
+    return true;
 
-  return failed(checkIsolatedRegion(ifOp.getThenGraph(), "then")) ||
-         failed(checkIsolatedRegion(ifOp.getElseGraph(), "else"));
+  return failed(checkIsolatedRegion(op, whileOp.getCondGraph(), "cond")) ||
+         failed(checkIsolatedRegion(op, whileOp.getBodyGraph(), "body"));
 }
 
 bool checkErrorIfScatter(Operation *op) {
@@ -1293,7 +1300,7 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
   if (!checkErrorIfResize(op) || !checkErrorIfMul(op) ||
       !checkErrorIfTable(op) || !checkErrorIfRescale(op) ||
       !checkErrorIfPad(op) || !checkErrorIfCondIf(op) ||
-      !checkErrorIfScatter(op))
+      !checkErrorIfWhileLoop(op) || !checkErrorIfScatter(op))
     return failure();
   return success();
 }
diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index 8924dd9885827..eb25011ff3a9d 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -280,3 +280,60 @@ func.func @test_cond_if_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f3
     }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
   return %0 : tensor<f32>
 }
+
+// -----
+
+func.func @test_while_loop_cond_not_isolated_from_above(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<f32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op is not conformant to the TOSA specification. It requires the 'cond' region is isolated from above.}}
+  %1 = "tosa.while_loop"(%0) ({
+  ^bb0(%arg3: tensor<i32>):
+    %2 = "tosa.greater_equal"(%arg3, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %3 = "tosa.logical_not"(%2) : (tensor<i1>) -> tensor<i1>
+    tosa.yield %3 : tensor<i1>
+  },  {
+  ^bb0(%arg3: tensor<i32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = "tosa.add"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %3 : tensor<i32>
+  }) : (tensor<i32>) -> (tensor<i32>)
+  return
+}
+
+// -----
+
+func.func @test_while_loop_body_not_isolated_from_above(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<f32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op is not conformant to the TOSA specification. It requires the 'body' region is isolated from above.}}
+  %1 = "tosa.while_loop"(%0) ({
+  ^bb0(%arg3: tensor<i32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = "tosa.greater_equal"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %4 = "tosa.logical_not"(%3) : (tensor<i1>) -> tensor<i1>
+    tosa.yield %4 : tensor<i1>
+  },  {
+  ^bb0(%arg3: tensor<i32>):
+    %3 = "tosa.add"(%arg3, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %3 : tensor<i32>
+  }) : (tensor<i32>) -> (tensor<i32>)
+  return
+}
+
+// -----
+
+// Check isolated while_loops are valid
+func.func @test_while_loop_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  %1:3 = "tosa.while_loop"(%0, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<f32>, %arg5: tensor<i32>):
+    %2 = "tosa.greater_equal"(%arg3, %arg5) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %3 = "tosa.logical_not"(%2) : (tensor<i1>) -> tensor<i1>
+    "tosa.yield"(%3) : (tensor<i1>) -> ()
+  },  {
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<f32>, %arg5: tensor<i32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = "tosa.add"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    "tosa.yield"(%3, %arg4, %arg5) : (tensor<i32>, tensor<f32>, tensor<i32>) -> ()
+  }) : (tensor<i32>, tensor<f32>, tensor<i32>) -> (tensor<i32>, tensor<f32>, tensor<i32>)
+  return
+}

@lhutton1
Copy link
Contributor Author

Rebased after the cond_if change was merged. PTAL when you have some time @udaya-ranga, @FranklandJack

Copy link
Contributor

@FranklandJack FranklandJack left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks for addressing my feedback!

@lhutton1 lhutton1 merged commit d8857d5 into llvm:main Jul 22, 2025
13 checks passed
@lhutton1 lhutton1 deleted the fix-while-loop-isolated branch July 22, 2025 13:45
mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Jul 28, 2025
…4865)

Similarly to `tosa.cond_if`, this patch checks that the cond/body
regions of `tosa.while_loop` are isolated from above. This is required
since the specification requires all values used in the cond/body
regions are explicitly declared within the regions.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants