-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[Coro][WebAssembly] Add tail-call check for async lowering #81481
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Coro][WebAssembly] Add tail-call check for async lowering #81481
Conversation
@llvm/pr-subscribers-coroutines @llvm/pr-subscribers-llvm-analysis Author: Yuta Saito (kateinoigakukun) ChangesThis patch fixes a verifier error when async lowering is used for WebAssembly target without tail-call feature. This missing check was revealed by b1ac052, which removed inlining of the musttail'ed call and it started leaving the invalid call at the verification stage. Additionally, Full diff: https://github.com/llvm/llvm-project/pull/81481.diff 5 Files Affected:
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 3d5db96e86b804..13379cc126a40c 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -367,10 +367,6 @@ class TargetTransformInfoImplBase {
bool supportsTailCalls() const { return true; }
- bool supportsTailCallFor(const CallBase *CB) const {
- return supportsTailCalls();
- }
-
bool enableAggressiveInterleaving(bool LoopHasReductions) const {
return false;
}
@@ -1427,6 +1423,10 @@ class TargetTransformInfoImplCRTPBase : public TargetTransformInfoImplBase {
I, Ops, TargetTransformInfo::TCK_SizeAndLatency);
return Cost >= TargetTransformInfo::TCC_Expensive;
}
+
+ bool supportsTailCallFor(const CallBase *CB) const {
+ return static_cast<const T *>(this)->supportsTailCalls();
+ }
};
} // namespace llvm
diff --git a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
index e69c718f0ae3ac..994871eb126884 100644
--- a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
@@ -3064,7 +3064,7 @@ static void doRematerializations(
}
void coro::buildCoroutineFrame(
- Function &F, Shape &Shape,
+ Function &F, Shape &Shape, TargetTransformInfo &TTI,
const std::function<bool(Instruction &)> &MaterializableCallback) {
// Don't eliminate swifterror in async functions that won't be split.
if (Shape.ABI != coro::ABI::Async || !Shape.CoroSuspends.empty())
@@ -3100,7 +3100,7 @@ void coro::buildCoroutineFrame(
SmallVector<Value *, 8> Args(AsyncEnd->args());
auto Arguments = ArrayRef<Value *>(Args).drop_front(3);
auto *Call = createMustTailCall(AsyncEnd->getDebugLoc(), MustTailCallFn,
- Arguments, Builder);
+ TTI, Arguments, Builder);
splitAround(Call, "MustTailCall.Before.CoroEnd");
}
}
diff --git a/llvm/lib/Transforms/Coroutines/CoroInternal.h b/llvm/lib/Transforms/Coroutines/CoroInternal.h
index fb16a4090689b4..388cf8d2aee71c 100644
--- a/llvm/lib/Transforms/Coroutines/CoroInternal.h
+++ b/llvm/lib/Transforms/Coroutines/CoroInternal.h
@@ -12,6 +12,7 @@
#define LLVM_LIB_TRANSFORMS_COROUTINES_COROINTERNAL_H
#include "CoroInstr.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/IRBuilder.h"
namespace llvm {
@@ -272,9 +273,10 @@ struct LLVM_LIBRARY_VISIBILITY Shape {
bool defaultMaterializable(Instruction &V);
void buildCoroutineFrame(
- Function &F, Shape &Shape,
+ Function &F, Shape &Shape, TargetTransformInfo &TTI,
const std::function<bool(Instruction &)> &MaterializableCallback);
CallInst *createMustTailCall(DebugLoc Loc, Function *MustTailCallFn,
+ TargetTransformInfo &TTI,
ArrayRef<Value *> Arguments, IRBuilder<> &);
} // End namespace coro.
} // End namespace llvm
diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
index aed4cd027d0338..47367d0b84edec 100644
--- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
@@ -1746,6 +1746,7 @@ static void coerceArguments(IRBuilder<> &Builder, FunctionType *FnTy,
}
CallInst *coro::createMustTailCall(DebugLoc Loc, Function *MustTailCallFn,
+ TargetTransformInfo &TTI,
ArrayRef<Value *> Arguments,
IRBuilder<> &Builder) {
auto *FnTy = MustTailCallFn->getFunctionType();
@@ -1755,14 +1756,18 @@ CallInst *coro::createMustTailCall(DebugLoc Loc, Function *MustTailCallFn,
coerceArguments(Builder, FnTy, Arguments, CallArgs);
auto *TailCall = Builder.CreateCall(FnTy, MustTailCallFn, CallArgs);
- TailCall->setTailCallKind(CallInst::TCK_MustTail);
+ // Skip targets which don't support tail call.
+ if (TTI.supportsTailCallFor(TailCall)) {
+ TailCall->setTailCallKind(CallInst::TCK_MustTail);
+ }
TailCall->setDebugLoc(Loc);
TailCall->setCallingConv(MustTailCallFn->getCallingConv());
return TailCall;
}
static void splitAsyncCoroutine(Function &F, coro::Shape &Shape,
- SmallVectorImpl<Function *> &Clones) {
+ SmallVectorImpl<Function *> &Clones,
+ TargetTransformInfo &TTI) {
assert(Shape.ABI == coro::ABI::Async);
assert(Clones.empty());
// Reset various things that the optimizer might have decided it
@@ -1837,7 +1842,7 @@ static void splitAsyncCoroutine(Function &F, coro::Shape &Shape,
SmallVector<Value *, 8> Args(Suspend->args());
auto FnArgs = ArrayRef<Value *>(Args).drop_front(
CoroSuspendAsyncInst::MustTailCallFuncArg + 1);
- coro::createMustTailCall(Suspend->getDebugLoc(), Fn, FnArgs, Builder);
+ coro::createMustTailCall(Suspend->getDebugLoc(), Fn, TTI, FnArgs, Builder);
Builder.CreateRetVoid();
// Replace the lvm.coro.async.resume intrisic call.
@@ -2010,7 +2015,7 @@ splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones,
return Shape;
simplifySuspendPoints(Shape);
- buildCoroutineFrame(F, Shape, MaterializableCallback);
+ buildCoroutineFrame(F, Shape, TTI, MaterializableCallback);
replaceFrameSizeAndAlignment(Shape);
// If there are no suspend points, no split required, just remove
@@ -2023,7 +2028,7 @@ splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones,
SwitchCoroutineSplitter::split(F, Shape, Clones, TTI);
break;
case coro::ABI::Async:
- splitAsyncCoroutine(F, Shape, Clones);
+ splitAsyncCoroutine(F, Shape, Clones, TTI);
break;
case coro::ABI::Retcon:
case coro::ABI::RetconOnce:
diff --git a/llvm/test/Transforms/Coroutines/coro-async-notail-wasm.ll b/llvm/test/Transforms/Coroutines/coro-async-notail-wasm.ll
new file mode 100644
index 00000000000000..a64c6b34243ec3
--- /dev/null
+++ b/llvm/test/Transforms/Coroutines/coro-async-notail-wasm.ll
@@ -0,0 +1,36 @@
+; Tests that coro-split will convert coro.resume followed by a suspend to a
+; musttail call.
+; RUN: opt < %s -O0 -S -mtriple=wasm32-unknown-unknown | FileCheck %s
+
+target datalayout = "e-m:e-p:32:32-p10:8:8-p20:8:8-i64:64-n32:64-S128-ni:1:10:20"
+target triple = "wasm32-unknown-wasi"
+
+%swift.async_func_pointer = type <{ i32, i32 }>
+
+@checkTu = global %swift.async_func_pointer <{ i32 ptrtoint (ptr @check to i32), i32 8 }>
+
+define swiftcc void @check(ptr %0) {
+entry:
+ %1 = call token @llvm.coro.id.async(i32 0, i32 0, i32 0, ptr @checkTu)
+ %2 = call ptr @llvm.coro.begin(token %1, ptr null)
+ %3 = call ptr @llvm.coro.async.resume()
+ store ptr %3, ptr %0, align 4
+ %4 = call { ptr, i32 } (i32, ptr, ptr, ...) @llvm.coro.suspend.async.sl_p0i32s(i32 0, ptr %3, ptr @__swift_async_resume_project_context, ptr @check.0, ptr null, ptr null)
+ ret void
+}
+
+declare swiftcc void @check.0()
+declare { ptr, i32 } @llvm.coro.suspend.async.sl_p0i32s(i32, ptr, ptr, ...)
+declare token @llvm.coro.id.async(i32, i32, i32, ptr)
+declare ptr @llvm.coro.begin(token, ptr writeonly)
+declare ptr @llvm.coro.async.resume()
+
+define ptr @__swift_async_resume_project_context(ptr %0) {
+entry:
+ ret ptr null
+}
+
+; Verify that the resume call is not marked as musttail.
+; CHECK-LABEL: define swiftcc void @check(
+; CHECK-NOT: musttail call swiftcc void @check.0()
+; CHECK: call swiftcc void @check.0()
|
@llvm/pr-subscribers-coroutines Author: Yuta Saito (kateinoigakukun) ChangesThis patch fixes a verifier error when async lowering is used for WebAssembly target without tail-call feature. This missing check was revealed by b1ac052, which removed inlining of the musttail'ed call and it started leaving the invalid call at the verification stage. Additionally, Full diff: https://github.com/llvm/llvm-project/pull/81481.diff 5 Files Affected:
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 3d5db96e86b804..13379cc126a40c 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -367,10 +367,6 @@ class TargetTransformInfoImplBase {
bool supportsTailCalls() const { return true; }
- bool supportsTailCallFor(const CallBase *CB) const {
- return supportsTailCalls();
- }
-
bool enableAggressiveInterleaving(bool LoopHasReductions) const {
return false;
}
@@ -1427,6 +1423,10 @@ class TargetTransformInfoImplCRTPBase : public TargetTransformInfoImplBase {
I, Ops, TargetTransformInfo::TCK_SizeAndLatency);
return Cost >= TargetTransformInfo::TCC_Expensive;
}
+
+ bool supportsTailCallFor(const CallBase *CB) const {
+ return static_cast<const T *>(this)->supportsTailCalls();
+ }
};
} // namespace llvm
diff --git a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
index e69c718f0ae3ac..994871eb126884 100644
--- a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
@@ -3064,7 +3064,7 @@ static void doRematerializations(
}
void coro::buildCoroutineFrame(
- Function &F, Shape &Shape,
+ Function &F, Shape &Shape, TargetTransformInfo &TTI,
const std::function<bool(Instruction &)> &MaterializableCallback) {
// Don't eliminate swifterror in async functions that won't be split.
if (Shape.ABI != coro::ABI::Async || !Shape.CoroSuspends.empty())
@@ -3100,7 +3100,7 @@ void coro::buildCoroutineFrame(
SmallVector<Value *, 8> Args(AsyncEnd->args());
auto Arguments = ArrayRef<Value *>(Args).drop_front(3);
auto *Call = createMustTailCall(AsyncEnd->getDebugLoc(), MustTailCallFn,
- Arguments, Builder);
+ TTI, Arguments, Builder);
splitAround(Call, "MustTailCall.Before.CoroEnd");
}
}
diff --git a/llvm/lib/Transforms/Coroutines/CoroInternal.h b/llvm/lib/Transforms/Coroutines/CoroInternal.h
index fb16a4090689b4..388cf8d2aee71c 100644
--- a/llvm/lib/Transforms/Coroutines/CoroInternal.h
+++ b/llvm/lib/Transforms/Coroutines/CoroInternal.h
@@ -12,6 +12,7 @@
#define LLVM_LIB_TRANSFORMS_COROUTINES_COROINTERNAL_H
#include "CoroInstr.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/IRBuilder.h"
namespace llvm {
@@ -272,9 +273,10 @@ struct LLVM_LIBRARY_VISIBILITY Shape {
bool defaultMaterializable(Instruction &V);
void buildCoroutineFrame(
- Function &F, Shape &Shape,
+ Function &F, Shape &Shape, TargetTransformInfo &TTI,
const std::function<bool(Instruction &)> &MaterializableCallback);
CallInst *createMustTailCall(DebugLoc Loc, Function *MustTailCallFn,
+ TargetTransformInfo &TTI,
ArrayRef<Value *> Arguments, IRBuilder<> &);
} // End namespace coro.
} // End namespace llvm
diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
index aed4cd027d0338..47367d0b84edec 100644
--- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
@@ -1746,6 +1746,7 @@ static void coerceArguments(IRBuilder<> &Builder, FunctionType *FnTy,
}
CallInst *coro::createMustTailCall(DebugLoc Loc, Function *MustTailCallFn,
+ TargetTransformInfo &TTI,
ArrayRef<Value *> Arguments,
IRBuilder<> &Builder) {
auto *FnTy = MustTailCallFn->getFunctionType();
@@ -1755,14 +1756,18 @@ CallInst *coro::createMustTailCall(DebugLoc Loc, Function *MustTailCallFn,
coerceArguments(Builder, FnTy, Arguments, CallArgs);
auto *TailCall = Builder.CreateCall(FnTy, MustTailCallFn, CallArgs);
- TailCall->setTailCallKind(CallInst::TCK_MustTail);
+ // Skip targets which don't support tail call.
+ if (TTI.supportsTailCallFor(TailCall)) {
+ TailCall->setTailCallKind(CallInst::TCK_MustTail);
+ }
TailCall->setDebugLoc(Loc);
TailCall->setCallingConv(MustTailCallFn->getCallingConv());
return TailCall;
}
static void splitAsyncCoroutine(Function &F, coro::Shape &Shape,
- SmallVectorImpl<Function *> &Clones) {
+ SmallVectorImpl<Function *> &Clones,
+ TargetTransformInfo &TTI) {
assert(Shape.ABI == coro::ABI::Async);
assert(Clones.empty());
// Reset various things that the optimizer might have decided it
@@ -1837,7 +1842,7 @@ static void splitAsyncCoroutine(Function &F, coro::Shape &Shape,
SmallVector<Value *, 8> Args(Suspend->args());
auto FnArgs = ArrayRef<Value *>(Args).drop_front(
CoroSuspendAsyncInst::MustTailCallFuncArg + 1);
- coro::createMustTailCall(Suspend->getDebugLoc(), Fn, FnArgs, Builder);
+ coro::createMustTailCall(Suspend->getDebugLoc(), Fn, TTI, FnArgs, Builder);
Builder.CreateRetVoid();
// Replace the lvm.coro.async.resume intrisic call.
@@ -2010,7 +2015,7 @@ splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones,
return Shape;
simplifySuspendPoints(Shape);
- buildCoroutineFrame(F, Shape, MaterializableCallback);
+ buildCoroutineFrame(F, Shape, TTI, MaterializableCallback);
replaceFrameSizeAndAlignment(Shape);
// If there are no suspend points, no split required, just remove
@@ -2023,7 +2028,7 @@ splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones,
SwitchCoroutineSplitter::split(F, Shape, Clones, TTI);
break;
case coro::ABI::Async:
- splitAsyncCoroutine(F, Shape, Clones);
+ splitAsyncCoroutine(F, Shape, Clones, TTI);
break;
case coro::ABI::Retcon:
case coro::ABI::RetconOnce:
diff --git a/llvm/test/Transforms/Coroutines/coro-async-notail-wasm.ll b/llvm/test/Transforms/Coroutines/coro-async-notail-wasm.ll
new file mode 100644
index 00000000000000..a64c6b34243ec3
--- /dev/null
+++ b/llvm/test/Transforms/Coroutines/coro-async-notail-wasm.ll
@@ -0,0 +1,36 @@
+; Tests that coro-split will convert coro.resume followed by a suspend to a
+; musttail call.
+; RUN: opt < %s -O0 -S -mtriple=wasm32-unknown-unknown | FileCheck %s
+
+target datalayout = "e-m:e-p:32:32-p10:8:8-p20:8:8-i64:64-n32:64-S128-ni:1:10:20"
+target triple = "wasm32-unknown-wasi"
+
+%swift.async_func_pointer = type <{ i32, i32 }>
+
+@checkTu = global %swift.async_func_pointer <{ i32 ptrtoint (ptr @check to i32), i32 8 }>
+
+define swiftcc void @check(ptr %0) {
+entry:
+ %1 = call token @llvm.coro.id.async(i32 0, i32 0, i32 0, ptr @checkTu)
+ %2 = call ptr @llvm.coro.begin(token %1, ptr null)
+ %3 = call ptr @llvm.coro.async.resume()
+ store ptr %3, ptr %0, align 4
+ %4 = call { ptr, i32 } (i32, ptr, ptr, ...) @llvm.coro.suspend.async.sl_p0i32s(i32 0, ptr %3, ptr @__swift_async_resume_project_context, ptr @check.0, ptr null, ptr null)
+ ret void
+}
+
+declare swiftcc void @check.0()
+declare { ptr, i32 } @llvm.coro.suspend.async.sl_p0i32s(i32, ptr, ptr, ...)
+declare token @llvm.coro.id.async(i32, i32, i32, ptr)
+declare ptr @llvm.coro.begin(token, ptr writeonly)
+declare ptr @llvm.coro.async.resume()
+
+define ptr @__swift_async_resume_project_context(ptr %0) {
+entry:
+ ret ptr null
+}
+
+; Verify that the resume call is not marked as musttail.
+; CHECK-LABEL: define swiftcc void @check(
+; CHECK-NOT: musttail call swiftcc void @check.0()
+; CHECK: call swiftcc void @check.0()
|
b25a97e
to
f37434b
Compare
This patch fixes a verifier error when async lowering is used for WebAssembly target without tail-call feature. This missing check was revealed by b1ac052, which removed inlining of the musttail'ed call and it started leaving the invalid call at the verification stage. Additionally, `TTI::supportsTailCallFor` did not respect the concrete TTI's `supportsTailCalls` implementation, so it always returned true even though `supportsTailCalls` returned false, so this patch also fixes the wrong CRTP base class implementation.
f37434b
to
34784e0
Compare
Gentle ping? @aschwaighofer |
This PR is blocking Swift CI and this change only affects Swift with WebAssembly, so I think it's safe to merge this. Please let me know if you have any issues. |
This patch fixes a verifier error when async lowering is used for WebAssembly target without tail-call feature. This missing check was revealed by b1ac052, which removed inlining of the musttail'ed call and it started leaving the invalid call at the verification stage. Additionally, `TTI::supportsTailCallFor` did not respect the concrete TTI's `supportsTailCalls` implementation, so it always returned true even though `supportsTailCalls` returned false, so this patch also fixes the wrong CRTP base class implementation.
This patch fixes a verifier error when async lowering is used for WebAssembly target without tail-call feature. This missing check was revealed by b1ac052, which removed inlining of the musttail'ed call and it started leaving the invalid call at the verification stage. Additionally,
TTI::supportsTailCallFor
did not respect the concrete TTI'ssupportsTailCalls
implementation, so it always returned true even thoughsupportsTailCalls
returned false, so this patch also fixes the wrong CRTP base class implementation.