Commit 1c4a085

[Stack Switching] Make continuations non-castable (#7980)

This recently changed in the spec.

* Add new `isCastable()` on types.
* Avoid adding casts on uncastable things in GUFA.
* Avoid and fix uncastable things in the fuzzer.
* Handle `br_if`s in binary writing using scratch locals when needed.

The `br_if` change is the only major work. Before, we would use casts to fix
things up, as follows:

* Single values were just cast after the `br_if`.
* Tuples were stashed to locals after the `br_if`, then reloaded and cast.

After, we do this:

* Single castable values are just cast after the `br_if`.
* Anything else - uncastable values, or tuples - is handled with locals. We
  stash BEFORE the `br_if` now, then drop the `br_if` output, then reload.
  (This is longer than before due to the drops, but avoids casts.)
1 parent 0f59f55 commit 1c4a085
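
To make the br_if rewrite concrete, here is a sketch in wasm text format of
the two lowerings for a single value (the names $val, $cond, and $target are
illustrative, not from the commit; the exact bytes depend on the module):

  ;; Binaryen IR knows the value as (ref $T); the branch target only
  ;; declares (ref null $T).
  ;;
  ;; Old lowering (works only for castable values):
  ;;   br_if $target
  ;;   ref.cast (ref $T)      ;; re-refine the fallthrough value
  ;;
  ;; New lowering (needed for uncastable values, e.g. continuations):
  ;;   local.set $cond        ;; stash the stack before the br_if
  ;;   local.set $val
  ;;   local.get $val         ;; restore it for the br_if to consume
  ;;   local.get $cond
  ;;   br_if $target
  ;;   drop                   ;; drop the unrefined br_if result
  ;;   local.get $val         ;; reload the refined (ref $T) value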

12 files changed: +591 −131 lines changed

src/passes/GUFA.cpp

Lines changed: 2 additions & 2 deletions

@@ -371,8 +371,8 @@ struct GUFAOptimizer
   bool optimized = false;

   void visitExpression(Expression* curr) {
-    if (!curr->type.isRef()) {
-      // Ignore anything we cannot infer a type for.
+    // Ignore anything we cannot emit a cast for.
+    if (!curr->type.isCastable()) {
       return;
     }

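For context, GUFA refines an expression's type by wrapping it in a cast, so it
can only act where a cast can be emitted. A minimal sketch in wasm text format
(the types $A and $B are illustrative):

  ;; GUFA can refine (ref null $A) to an inferred (ref $B) by emitting:
  ;;   (ref.cast (ref $B) (call $get))
  ;; No such cast exists for a continuation type, so expressions whose type
  ;; is not castable are now skipped entirely.
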
src/tools/fuzzing.h

Lines changed: 1 addition & 0 deletions

@@ -542,6 +542,7 @@ class TranslateToFuzzReader {
   // Getters for Types
   Type getSingleConcreteType();
   Type getReferenceType();
+  Type getCastableReferenceType();
   Type getEqReferenceType();
   Type getMVPType();
   Type getTupleType();

src/tools/fuzzing/fuzzing.cpp

Lines changed: 36 additions & 10 deletions

@@ -2424,10 +2424,12 @@ Expression* TranslateToFuzzReader::_makeConcrete(Type type) {
     options.add(FeatureSet::ReferenceTypes | FeatureSet::GC,
                 &Self::makeCompoundRef);
   }
-  // Exact casts are only allowed with custom descriptors enabled.
-  if (type.isInexact() || wasm.features.hasCustomDescriptors()) {
-    options.add(FeatureSet::ReferenceTypes | FeatureSet::GC,
-                &Self::makeRefCast);
+  if (type.isCastable()) {
+    // Exact casts are only allowed with custom descriptors enabled.
+    if (type.isInexact() || wasm.features.hasCustomDescriptors()) {
+      options.add(FeatureSet::ReferenceTypes | FeatureSet::GC,
+                  &Self::makeRefCast);
+    }
   }
   if (heapType.getDescribedType()) {
     options.add(FeatureSet::ReferenceTypes | FeatureSet::GC,
@@ -5054,8 +5056,8 @@ Expression* TranslateToFuzzReader::makeRefTest(Type type) {
   switch (upTo(3)) {
     case 0:
       // Totally random.
-      refType = getReferenceType();
-      castType = getReferenceType();
+      refType = getCastableReferenceType();
+      castType = getCastableReferenceType();
       // They must share a bottom type in order to validate.
       if (refType.getHeapType().getBottom() ==
           castType.getHeapType().getBottom()) {
@@ -5066,12 +5068,12 @@ Expression* TranslateToFuzzReader::makeRefTest(Type type) {
       [[fallthrough]];
     case 1:
       // Cast is a subtype of ref.
-      refType = getReferenceType();
+      refType = getCastableReferenceType();
       castType = getSubType(refType);
       break;
     case 2:
       // Ref is a subtype of cast.
-      castType = getReferenceType();
+      castType = getCastableReferenceType();
       refType = getSubType(castType);
       break;
     default:
@@ -5095,7 +5097,7 @@ Expression* TranslateToFuzzReader::makeRefCast(Type type) {
   switch (upTo(3)) {
     case 0:
       // Totally random.
-      refType = getReferenceType();
+      refType = getCastableReferenceType();
       // They must share a bottom type in order to validate.
       if (refType.getHeapType().getBottom() == type.getHeapType().getBottom()) {
         break;
@@ -5200,7 +5202,11 @@ Expression* TranslateToFuzzReader::makeBrOn(Type type) {
   // We are sending a reference type to the target. All other BrOn variants can
   // do that.
   assert(targetType.isRef());
-  auto op = pick(BrOnNonNull, BrOnCast, BrOnCastFail);
+  // BrOnNonNull can handle sending any reference. The casts are more limited.
+  auto op = BrOnNonNull;
+  if (targetType.isCastable()) {
+    op = pick(BrOnNonNull, BrOnCast, BrOnCastFail);
+  }
   Type castType = Type::none;
   Type refType;
   switch (op) {
@@ -5645,6 +5651,26 @@ Type TranslateToFuzzReader::getReferenceType() {
                      Type(HeapType::string, NonNullable)));
 }

+Type TranslateToFuzzReader::getCastableReferenceType() {
+  int tries = fuzzParams->TRIES;
+  while (tries-- > 0) {
+    auto type = getReferenceType();
+    if (type.isCastable()) {
+      return type;
+    }
+  }
+  // We failed to find a type using fair sampling. Do something simple that
+  // must work.
+  Type type;
+  if (oneIn(4)) {
+    type = getSubType(Type(HeapType::func, Nullable));
+  } else {
+    type = getSubType(Type(HeapType::any, Nullable));
+  }
+  assert(type.isCastable());
+  return type;
+}
+
 Type TranslateToFuzzReader::getEqReferenceType() {
   if (oneIn(2) && !interestingHeapTypes.empty()) {
     // Try to find an interesting eq-compatible type.
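
The makeBrOn change reflects that br_on_non_null only tests for null, while
br_on_cast and br_on_cast_fail perform a real cast. A sketch in wasm text
format (the names $target, $ct, and $c are illustrative):

  ;; Fine for any reference, including a continuation:
  ;;   (br_on_non_null $target (local.get $c))
  ;; No longer valid when the type involved is a continuation:
  ;;   (br_on_cast $target (ref null $ct) (ref $ct) (local.get $c))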

src/wasm-stack.h

Lines changed: 5 additions & 10 deletions

@@ -167,16 +167,11 @@ class BinaryInstWriter : public OverriddenVisitor<BinaryInstWriter> {
   // when they have a value that is more refined than the wasm type system
   // allows atm (and they are not dropped, in which case the type would not
   // matter). See https://github.com/WebAssembly/binaryen/pull/6390 for more on
-  // the difference. As a result of the difference, we will insert extra casts
-  // to ensure validation in the wasm spec. The wasm spec will hopefully
-  // improve to use the more refined type as well, which would remove the need
-  // for this hack.
-  //
-  // Each br_if present as a key here is mapped to the unrefined type for it.
-  // That is, the br_if has a type in Binaryen IR that is too refined, and the
-  // map contains the unrefined one (which we need to know the local types, as
-  // we'll stash the unrefined values and then cast them).
-  std::unordered_map<Break*, Type> brIfsNeedingHandling;
+  // the difference. As a result of the difference, we must fix things up for
+  // the spec. (The wasm spec might - hopefully - improve to use the more
+  // refined type as well, which would remove the need for this hack, and
+  // improve code size in general.)
+  std::unordered_set<Break*> brIfsNeedingHandling;
 };

 // Takes binaryen IR and converts it to something else (binary or stack IR)

src/wasm-type.h

Lines changed: 3 additions & 0 deletions

@@ -184,6 +184,8 @@ class HeapType {
     return isBasic() && getBasic(Unshared) == type;
   }

+  bool isCastable();
+
   Signature getSignature() const;
   Continuation getContinuation() const;

@@ -415,6 +417,7 @@ class Type {
     return isRef() && getHeapType().isContinuation();
   }
   bool isDefaultable() const;
+  bool isCastable();

   // TODO: Allow this only for reference types.
   Nullability getNullability() const {

src/wasm/wasm-stack.cpp

Lines changed: 89 additions & 49 deletions

@@ -63,56 +63,88 @@ void BinaryInstWriter::visitLoop(Loop* curr) {
 }

 void BinaryInstWriter::visitBreak(Break* curr) {
+  auto type = curr->type;
+
+  // See comment on |brIfsNeedingHandling| for the extra handling we need to
+  // emit here for certain br_ifs. If we need that handling, we either use a
+  // cast in simple cases, or scratch locals otherwise. We use the scratch
+  // locals to stash the stack before the br_if (which contains the refined
+  // types), then restore it later from those locals.
+  bool needScratchLocals = false;
+  // If we need locals, we must track how many we've used from each type as we
+  // go, as a type might appear multiple times in the tuple. We know we have
+  // enough of a range allocated for them, so we just increment as we go.
+  std::unordered_map<Type, Index> scratchTypeUses;
+  // Logic to stash and restore the stack, given a vector of types we are
+  // stashing/restoring. We will first stash the entire stack, including the
+  // i32 condition, and after the br_if, restore the value (without the
+  // condition).
+  auto stashStack = [&](const std::vector<Type>& types) {
+    for (Index i = 0; i < types.size(); i++) {
+      auto t = types[types.size() - i - 1];
+      assert(scratchLocals.find(t) != scratchLocals.end());
+      auto localIndex = scratchLocals[t] + scratchTypeUses[t]++;
+      o << int8_t(BinaryConsts::LocalSet) << U32LEB(localIndex);
+    }
+  };
+  auto restoreStack = [&](const std::vector<Type>& types) {
+    // Use a copy of this data, as we will restore twice.
+    auto currScratchTypeUses = scratchTypeUses;
+    for (Index i = 0; i < types.size(); i++) {
+      auto t = types[i];
+      auto localIndex = scratchLocals[t] + --currScratchTypeUses[t];
+      o << int8_t(BinaryConsts::LocalGet) << U32LEB(localIndex);
+    }
+  };
+
+  // The types on the stack before the br_if. We need this if we use locals to
+  // stash the stack.
+  std::vector<Type> typesOnStack;
+
+  auto needHandling = brIfsNeedingHandling.count(curr);
+  if (needHandling) {
+    // Tuples always need scratch locals. Uncastable types do as well, as we
+    // can't fix them up below with a simple cast.
+    needScratchLocals = type.isTuple() || !type.isCastable();
+    if (needScratchLocals) {
+      // Stash all the values on the stack to those locals, then reload them
+      // for the br_if to consume. Later, we can reload the refined values
+      // after the br_if, for its parent to consume.
+      typesOnStack = std::vector<Type>(type.begin(), type.end());
+      typesOnStack.push_back(Type::i32);
+
+      stashStack(typesOnStack);
+      restoreStack(typesOnStack);
+      // The stack is now in the same state as before, but we have copies in
+      // locals for later.
+    }
+  }
+
   o << int8_t(curr->condition ? BinaryConsts::BrIf : BinaryConsts::Br)
     << U32LEB(getBreakIndex(curr->name));

-  // See comment on |brIfsNeedingHandling| for the extra casts we need to emit
-  // here for certain br_ifs.
-  auto iter = brIfsNeedingHandling.find(curr);
-  if (iter != brIfsNeedingHandling.end()) {
-    auto unrefinedType = iter->second;
-    auto type = curr->type;
-    assert(type.size() == unrefinedType.size());
+  if (needHandling) {
+    if (!needScratchLocals) {
+      // We can just cast here, avoiding scratch locals. (Casting adds
+      // overhead, but this is very rare, and it avoids adding locals, which
+      // would keep growing the wasm with each roundtrip.)

-    assert(curr->type.hasRef());
-
-    auto emitCast = [&](Type to) {
       // Shim a tiny bit of IR, just enough to get visitRefCast to see what we
       // are casting, and to emit the proper thing.
       RefCast cast;
-      cast.type = to;
+      cast.type = type;
       cast.ref = cast.desc = nullptr;
       visitRefCast(&cast);
-    };
-
-    if (!type.isTuple()) {
-      // Simple: Just emit a cast, and then the type matches Binaryen IR's.
-      emitCast(type);
     } else {
-      // Tuples are trickier to handle, and we need to use scratch locals.
-      // Stash all the values on the stack to those locals, then reload them,
-      // casting as we go.
-      //
-      // We must track how many scratch locals we've used from each type as we
-      // go, as a type might appear multiple times in the tuple. We allocated
-      // enough for each, in a contiguous range, so we just increment as we go.
-      std::unordered_map<Type, Index> scratchTypeUses;
-      for (Index i = 0; i < unrefinedType.size(); i++) {
-        auto t = unrefinedType[unrefinedType.size() - i - 1];
-        assert(scratchLocals.find(t) != scratchLocals.end());
-        auto localIndex = scratchLocals[t] + scratchTypeUses[t]++;
-        o << int8_t(BinaryConsts::LocalSet) << U32LEB(localIndex);
-      }
-      for (Index i = 0; i < unrefinedType.size(); i++) {
-        auto t = unrefinedType[i];
-        auto localIndex = scratchLocals[t] + --scratchTypeUses[t];
-        o << int8_t(BinaryConsts::LocalGet) << U32LEB(localIndex);
-        if (t.isRef()) {
-          // Note that we cast all types here, when perhaps only some of the
-          // tuple's lanes need that. This is simpler.
-          emitCast(type[i]);
-        }
+      // We need locals. Earlier we stashed the stack, so we just need to
+      // restore the value from there (note we don't restore the condition),
+      // after dropping the br_if's unrefined values.
+      for (Index i = 0; i < type.size(); ++i) {
+        o << int8_t(BinaryConsts::Drop);
       }
+      assert(typesOnStack.back() == Type::i32);
+      typesOnStack.pop_back();
+      restoreStack(typesOnStack);
     }
   }
 }
@@ -3094,8 +3126,9 @@ InsertOrderedMap<Type, Index> BinaryInstWriter::countScratchLocals() {
       : writer(writer), finder(finder) {}

     void visitBreak(Break* curr) {
+      auto type = curr->type;
       // See if this is one of the dangerous br_ifs we must handle.
-      if (!curr->type.hasRef()) {
+      if (!type.hasRef()) {
        // Not even a reference.
        return;
      }
@@ -3106,7 +3139,7 @@ InsertOrderedMap<Type, Index> BinaryInstWriter::countScratchLocals() {
        return;
      }
      if (auto* cast = parent->dynCast<RefCast>()) {
-        if (Type::isSubType(cast->type, curr->type)) {
+        if (Type::isSubType(cast->type, type)) {
          // It is cast to the same type or a better one. In particular this
          // handles the case of repeated roundtripping: After the first
          // roundtrip we emit a cast that we'll identify here, and not emit
@@ -3117,23 +3150,30 @@ InsertOrderedMap<Type, Index> BinaryInstWriter::countScratchLocals() {
      }
      auto* breakTarget = findBreakTarget(curr->name);
      auto unrefinedType = breakTarget->type;
-      if (unrefinedType == curr->type) {
+      if (unrefinedType == type) {
        // It has the proper type anyhow.
        return;
      }

      // Mark the br_if as needing handling, and add the type to the set of
      // types we need scratch tuple locals for (if relevant).
-      writer.brIfsNeedingHandling[curr] = unrefinedType;
-
-      if (unrefinedType.isTuple()) {
-        // We must allocate enough scratch locals for this tuple. Note that we
-        // may need more than one per type in the tuple, if a type appears
-        // more than once, so we count their appearances.
+      writer.brIfsNeedingHandling.insert(curr);
+
+      // Simple cases can be handled by a cast. However, tuples and uncastable
+      // types require us to use locals too.
+      if (type.isTuple() || !type.isCastable()) {
+        // We must allocate enough scratch locals for this tuple, plus the i32
+        // of the condition, as we will stash it all so that we can restore
+        // the fully refined value after the br_if.
+        //
+        // Note that we may need more than one per type in the tuple, if a
+        // type appears more than once, so we count their appearances.
        InsertOrderedMap<Type, Index> scratchTypeUses;
-        for (auto t : unrefinedType) {
+        for (auto t : type) {
          scratchTypeUses[t]++;
        }
+        // The condition.
+        scratchTypeUses[Type::i32]++;
        for (auto& [type, uses] : scratchTypeUses) {
          auto& count = finder.scratches[type];
          count = std::max(count, uses);

src/wasm/wasm-type.cpp

Lines changed: 7 additions & 0 deletions

@@ -623,6 +623,8 @@ bool Type::isDefaultable() const {
   return isConcrete() && !isNonNullable();
 }

+bool Type::isCastable() { return isRef() && getHeapType().isCastable(); }
+
 unsigned Type::getByteSize() const {
   // TODO: alignment?
   auto getSingleByteSize = [](Type t) {
@@ -889,6 +891,11 @@ Shareability HeapType::getShared() const {
   }
 }

+bool HeapType::isCastable() {
+  return !isContinuation() && !isMaybeShared(HeapType::cont) &&
+         !isMaybeShared(HeapType::nocont);
+}
+
 Signature HeapType::getSignature() const {
   assert(isSignature());
   return getHeapTypeInfo(*this)->signature;