-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[mlir][tosa] Check for isolated regions in tosa.while_loop
#144865
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM % a few pedantic comments :D
Similarly to `tosa.cond_if`, this patch checks that the cond/body regions of `tosa.while_loop` are isolated from above. This is required since the specification requires all values used in the cond/body regions are explicitly declared within the regions. Change-Id: Ia7396b9811db54805ec33befd24ab97d1b605905
7bd13d2 to
516c47a
Compare
|
@llvm/pr-subscribers-mlir Author: Luke Hutton (lhutton1) ChangesSimilarly to
Full diff: https://github.com/llvm/llvm-project/pull/144865.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 48ec28acfaaaa..32b5fb63a6ece 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1200,6 +1200,28 @@ static bool isOpIsolatedWithinRegion(Operation *op, Region *region) {
});
}
+static bool isRegionIsolatedFromAbove(Region ®ionToCheck) {
+ bool noLiveInValue = true;
+ regionToCheck.walk([&noLiveInValue, ®ionToCheck](Operation *op) {
+ if (!isOpIsolatedWithinRegion(op, ®ionToCheck)) {
+ noLiveInValue = false;
+ return WalkResult::interrupt();
+ }
+ return WalkResult::advance();
+ });
+ return noLiveInValue;
+}
+
+LogicalResult checkIsolatedRegion(Operation *op, Region ®ionToCheck,
+ StringRef regionName) {
+ if (isRegionIsolatedFromAbove(regionToCheck))
+ return success();
+ op->emitOpError()
+ << "is not conformant to the TOSA specification. It requires the '"
+ << regionName << "' region is isolated from above.\n";
+ return failure();
+}
+
bool checkErrorIfCondIf(Operation *op) {
auto ifOp = dyn_cast<tosa::IfOp>(op);
if (!ifOp)
@@ -1236,32 +1258,17 @@ bool checkErrorIfCondIf(Operation *op) {
// used in then/else regions (see 'simplified' example above), so it
// must be rewritten to use the generic syntax in order to be conformant
// to the specification.
+ return failed(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) ||
+ failed(checkIsolatedRegion(op, ifOp.getElseGraph(), "else"));
+}
- // Returns true if the region uses no external input operands.
- auto isIsolatedRegion = [](Region ®ionToCheck) -> bool {
- bool noLiveInValue = true;
- regionToCheck.walk([&noLiveInValue, ®ionToCheck](Operation *opInRegion) {
- if (!isOpIsolatedWithinRegion(opInRegion, ®ionToCheck)) {
- noLiveInValue = false;
- return WalkResult::interrupt();
- }
- return WalkResult::advance();
- });
- return noLiveInValue;
- };
-
- auto checkIsolatedRegion = [&](Region ®ionToCheck,
- StringRef regionName) -> LogicalResult {
- if (isIsolatedRegion(regionToCheck))
- return success();
- op->emitOpError()
- << "is not conformant to the TOSA specification. It requires the '"
- << regionName << "' region is isolated from above.\n";
- return failure();
- };
+bool checkErrorIfWhileLoop(Operation *op) {
+ auto whileOp = dyn_cast<tosa::WhileOp>(op);
+ if (!whileOp)
+ return true;
- return failed(checkIsolatedRegion(ifOp.getThenGraph(), "then")) ||
- failed(checkIsolatedRegion(ifOp.getElseGraph(), "else"));
+ return failed(checkIsolatedRegion(op, whileOp.getCondGraph(), "cond")) ||
+ failed(checkIsolatedRegion(op, whileOp.getBodyGraph(), "body"));
}
bool checkErrorIfScatter(Operation *op) {
@@ -1293,7 +1300,7 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
if (!checkErrorIfResize(op) || !checkErrorIfMul(op) ||
!checkErrorIfTable(op) || !checkErrorIfRescale(op) ||
!checkErrorIfPad(op) || !checkErrorIfCondIf(op) ||
- !checkErrorIfScatter(op))
+ !checkErrorIfWhileLoop(op) || !checkErrorIfScatter(op))
return failure();
return success();
}
diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index 8924dd9885827..eb25011ff3a9d 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -280,3 +280,60 @@ func.func @test_cond_if_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f3
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}
+
+// -----
+
+func.func @test_while_loop_cond_not_isolated_from_above(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<f32>) {
+ %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+ // expected-error@+1 {{'tosa.while_loop' op is not conformant to the TOSA specification. It requires the 'cond' region is isolated from above.}}
+ %1 = "tosa.while_loop"(%0) ({
+ ^bb0(%arg3: tensor<i32>):
+ %2 = "tosa.greater_equal"(%arg3, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ %3 = "tosa.logical_not"(%2) : (tensor<i1>) -> tensor<i1>
+ tosa.yield %3 : tensor<i1>
+ }, {
+ ^bb0(%arg3: tensor<i32>):
+ %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+ %3 = "tosa.add"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ tosa.yield %3 : tensor<i32>
+ }) : (tensor<i32>) -> (tensor<i32>)
+ return
+}
+
+// -----
+
+func.func @test_while_loop_body_not_isolated_from_above(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<f32>) {
+ %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+ // expected-error@+1 {{'tosa.while_loop' op is not conformant to the TOSA specification. It requires the 'body' region is isolated from above.}}
+ %1 = "tosa.while_loop"(%0) ({
+ ^bb0(%arg3: tensor<i32>):
+ %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+ %3 = "tosa.greater_equal"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ %4 = "tosa.logical_not"(%3) : (tensor<i1>) -> tensor<i1>
+ tosa.yield %4 : tensor<i1>
+ }, {
+ ^bb0(%arg3: tensor<i32>):
+ %3 = "tosa.add"(%arg3, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ tosa.yield %3 : tensor<i32>
+ }) : (tensor<i32>) -> (tensor<i32>)
+ return
+}
+
+// -----
+
+// Check isolated while_loops are valid
+func.func @test_while_loop_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<i32>) {
+ %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+ %1:3 = "tosa.while_loop"(%0, %arg0, %arg1) ({
+ ^bb0(%arg3: tensor<i32>, %arg4: tensor<f32>, %arg5: tensor<i32>):
+ %2 = "tosa.greater_equal"(%arg3, %arg5) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ %3 = "tosa.logical_not"(%2) : (tensor<i1>) -> tensor<i1>
+ "tosa.yield"(%3) : (tensor<i1>) -> ()
+ }, {
+ ^bb0(%arg3: tensor<i32>, %arg4: tensor<f32>, %arg5: tensor<i32>):
+ %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+ %3 = "tosa.add"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ "tosa.yield"(%3, %arg4, %arg5) : (tensor<i32>, tensor<f32>, tensor<i32>) -> ()
+ }) : (tensor<i32>, tensor<f32>, tensor<i32>) -> (tensor<i32>, tensor<f32>, tensor<i32>)
+ return
+}
|
|
@llvm/pr-subscribers-mlir-tosa Author: Luke Hutton (lhutton1) ChangesSimilarly to
Full diff: https://github.com/llvm/llvm-project/pull/144865.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 48ec28acfaaaa..32b5fb63a6ece 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1200,6 +1200,28 @@ static bool isOpIsolatedWithinRegion(Operation *op, Region *region) {
});
}
+static bool isRegionIsolatedFromAbove(Region ®ionToCheck) {
+ bool noLiveInValue = true;
+ regionToCheck.walk([&noLiveInValue, ®ionToCheck](Operation *op) {
+ if (!isOpIsolatedWithinRegion(op, ®ionToCheck)) {
+ noLiveInValue = false;
+ return WalkResult::interrupt();
+ }
+ return WalkResult::advance();
+ });
+ return noLiveInValue;
+}
+
+LogicalResult checkIsolatedRegion(Operation *op, Region ®ionToCheck,
+ StringRef regionName) {
+ if (isRegionIsolatedFromAbove(regionToCheck))
+ return success();
+ op->emitOpError()
+ << "is not conformant to the TOSA specification. It requires the '"
+ << regionName << "' region is isolated from above.\n";
+ return failure();
+}
+
bool checkErrorIfCondIf(Operation *op) {
auto ifOp = dyn_cast<tosa::IfOp>(op);
if (!ifOp)
@@ -1236,32 +1258,17 @@ bool checkErrorIfCondIf(Operation *op) {
// used in then/else regions (see 'simplified' example above), so it
// must be rewritten to use the generic syntax in order to be conformant
// to the specification.
+ return failed(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) ||
+ failed(checkIsolatedRegion(op, ifOp.getElseGraph(), "else"));
+}
- // Returns true if the region uses no external input operands.
- auto isIsolatedRegion = [](Region ®ionToCheck) -> bool {
- bool noLiveInValue = true;
- regionToCheck.walk([&noLiveInValue, ®ionToCheck](Operation *opInRegion) {
- if (!isOpIsolatedWithinRegion(opInRegion, ®ionToCheck)) {
- noLiveInValue = false;
- return WalkResult::interrupt();
- }
- return WalkResult::advance();
- });
- return noLiveInValue;
- };
-
- auto checkIsolatedRegion = [&](Region ®ionToCheck,
- StringRef regionName) -> LogicalResult {
- if (isIsolatedRegion(regionToCheck))
- return success();
- op->emitOpError()
- << "is not conformant to the TOSA specification. It requires the '"
- << regionName << "' region is isolated from above.\n";
- return failure();
- };
+bool checkErrorIfWhileLoop(Operation *op) {
+ auto whileOp = dyn_cast<tosa::WhileOp>(op);
+ if (!whileOp)
+ return true;
- return failed(checkIsolatedRegion(ifOp.getThenGraph(), "then")) ||
- failed(checkIsolatedRegion(ifOp.getElseGraph(), "else"));
+ return failed(checkIsolatedRegion(op, whileOp.getCondGraph(), "cond")) ||
+ failed(checkIsolatedRegion(op, whileOp.getBodyGraph(), "body"));
}
bool checkErrorIfScatter(Operation *op) {
@@ -1293,7 +1300,7 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
if (!checkErrorIfResize(op) || !checkErrorIfMul(op) ||
!checkErrorIfTable(op) || !checkErrorIfRescale(op) ||
!checkErrorIfPad(op) || !checkErrorIfCondIf(op) ||
- !checkErrorIfScatter(op))
+ !checkErrorIfWhileLoop(op) || !checkErrorIfScatter(op))
return failure();
return success();
}
diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index 8924dd9885827..eb25011ff3a9d 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -280,3 +280,60 @@ func.func @test_cond_if_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f3
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}
+
+// -----
+
+func.func @test_while_loop_cond_not_isolated_from_above(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<f32>) {
+ %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+ // expected-error@+1 {{'tosa.while_loop' op is not conformant to the TOSA specification. It requires the 'cond' region is isolated from above.}}
+ %1 = "tosa.while_loop"(%0) ({
+ ^bb0(%arg3: tensor<i32>):
+ %2 = "tosa.greater_equal"(%arg3, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ %3 = "tosa.logical_not"(%2) : (tensor<i1>) -> tensor<i1>
+ tosa.yield %3 : tensor<i1>
+ }, {
+ ^bb0(%arg3: tensor<i32>):
+ %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+ %3 = "tosa.add"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ tosa.yield %3 : tensor<i32>
+ }) : (tensor<i32>) -> (tensor<i32>)
+ return
+}
+
+// -----
+
+func.func @test_while_loop_body_not_isolated_from_above(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<f32>) {
+ %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+ // expected-error@+1 {{'tosa.while_loop' op is not conformant to the TOSA specification. It requires the 'body' region is isolated from above.}}
+ %1 = "tosa.while_loop"(%0) ({
+ ^bb0(%arg3: tensor<i32>):
+ %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+ %3 = "tosa.greater_equal"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ %4 = "tosa.logical_not"(%3) : (tensor<i1>) -> tensor<i1>
+ tosa.yield %4 : tensor<i1>
+ }, {
+ ^bb0(%arg3: tensor<i32>):
+ %3 = "tosa.add"(%arg3, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ tosa.yield %3 : tensor<i32>
+ }) : (tensor<i32>) -> (tensor<i32>)
+ return
+}
+
+// -----
+
+// Check isolated while_loops are valid
+func.func @test_while_loop_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<i32>) {
+ %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+ %1:3 = "tosa.while_loop"(%0, %arg0, %arg1) ({
+ ^bb0(%arg3: tensor<i32>, %arg4: tensor<f32>, %arg5: tensor<i32>):
+ %2 = "tosa.greater_equal"(%arg3, %arg5) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ %3 = "tosa.logical_not"(%2) : (tensor<i1>) -> tensor<i1>
+ "tosa.yield"(%3) : (tensor<i1>) -> ()
+ }, {
+ ^bb0(%arg3: tensor<i32>, %arg4: tensor<f32>, %arg5: tensor<i32>):
+ %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+ %3 = "tosa.add"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ "tosa.yield"(%3, %arg4, %arg5) : (tensor<i32>, tensor<f32>, tensor<i32>) -> ()
+ }) : (tensor<i32>, tensor<f32>, tensor<i32>) -> (tensor<i32>, tensor<f32>, tensor<i32>)
+ return
+}
|
|
Rebased after the cond_if change was merged. PTAL when you have some time @udaya-ranga, @FranklandJack |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Thanks for addressing my feedback!
…4865) Similarly to `tosa.cond_if`, this patch checks that the cond/body regions of `tosa.while_loop` are isolated from above. This is required since the specification requires all values used in the cond/body regions are explicitly declared within the regions.
Similarly to
tosa.cond_if, this patch checks that the cond/body regions oftosa.while_loopare isolated from above. This is required since the specification requires all values used in the cond/body regions are explicitly declared within the regions.Note: this change is dependent on #143772