Skip to content

Commit 2d47c0b

Browse files
authored
Implement more cases for getMaxBits (#2879)
- Complete 64-bit cases in range `AddInt64` ... `ShrSInt64` - `ExtendSInt32` and `ExtendUInt32` for unary cases - For binary cases - `AddInt32` / `AddInt64` - `MulInt32` / `MulInt64` - `RemUInt32` / `RemUInt64` - `RemSInt32` / `RemSInt64` - `DivUInt32` / `DivUInt64` - `DivSInt32` / `DivSInt64` - and more Also more fast paths for some getMaxBits calculations
1 parent 6116553 commit 2d47c0b

File tree

6 files changed

+719
-40
lines changed

6 files changed

+719
-40
lines changed

src/ir/bits.h

Lines changed: 172 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -128,35 +128,85 @@ struct DummyLocalInfoProvider {
128128
template<typename LocalInfoProvider = DummyLocalInfoProvider>
129129
Index getMaxBits(Expression* curr,
130130
LocalInfoProvider* localInfoProvider = nullptr) {
131-
if (auto* const_ = curr->dynCast<Const>()) {
131+
if (auto* c = curr->dynCast<Const>()) {
132132
switch (curr->type.getBasic()) {
133133
case Type::i32:
134-
return 32 - const_->value.countLeadingZeroes().geti32();
134+
return 32 - c->value.countLeadingZeroes().geti32();
135135
case Type::i64:
136-
return 64 - const_->value.countLeadingZeroes().geti64();
136+
return 64 - c->value.countLeadingZeroes().geti64();
137137
default:
138138
WASM_UNREACHABLE("invalid type");
139139
}
140140
} else if (auto* binary = curr->dynCast<Binary>()) {
141141
switch (binary->op) {
142142
// 32-bit
143-
case AddInt32:
144-
case SubInt32:
145-
case MulInt32:
146-
case DivSInt32:
147-
case DivUInt32:
148-
case RemSInt32:
149-
case RemUInt32:
150143
case RotLInt32:
151144
case RotRInt32:
145+
case SubInt32:
146+
return 32;
147+
case AddInt32: {
148+
auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
149+
auto maxBitsRight = getMaxBits(binary->right, localInfoProvider);
150+
return std::min(Index(32), std::max(maxBitsLeft, maxBitsRight) + 1);
151+
}
152+
case MulInt32: {
153+
auto maxBitsRight = getMaxBits(binary->right, localInfoProvider);
154+
auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
155+
return std::min(Index(32), maxBitsLeft + maxBitsRight);
156+
}
157+
case DivSInt32: {
158+
if (auto* c = binary->right->dynCast<Const>()) {
159+
int32_t maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
160+
// If either side might be negative, then the result will be negative
161+
if (maxBitsLeft == 32 || c->value.geti32() < 0) {
162+
return 32;
163+
}
164+
int32_t bitsRight = getMaxBits(c);
165+
return std::max(0, maxBitsLeft - bitsRight + 1);
166+
}
167+
return 32;
168+
}
169+
case DivUInt32: {
170+
int32_t maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
171+
if (auto* c = binary->right->dynCast<Const>()) {
172+
int32_t bitsRight = getMaxBits(c);
173+
return std::max(0, maxBitsLeft - bitsRight + 1);
174+
}
175+
return maxBitsLeft;
176+
}
177+
case RemSInt32: {
178+
if (auto* c = binary->right->dynCast<Const>()) {
179+
auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
180+
// if maxBitsLeft is negative
181+
if (maxBitsLeft == 32) {
182+
return 32;
183+
}
184+
auto bitsRight = Index(CeilLog2(c->value.geti32()));
185+
return std::min(maxBitsLeft, bitsRight);
186+
}
187+
return 32;
188+
}
189+
case RemUInt32: {
190+
if (auto* c = binary->right->dynCast<Const>()) {
191+
auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
192+
auto bitsRight = Index(CeilLog2(c->value.geti32()));
193+
return std::min(maxBitsLeft, bitsRight);
194+
}
152195
return 32;
153-
case AndInt32:
196+
}
197+
case AndInt32: {
154198
return std::min(getMaxBits(binary->left, localInfoProvider),
155199
getMaxBits(binary->right, localInfoProvider));
200+
}
156201
case OrInt32:
157-
case XorInt32:
158-
return std::max(getMaxBits(binary->left, localInfoProvider),
159-
getMaxBits(binary->right, localInfoProvider));
202+
case XorInt32: {
203+
auto maxBits = getMaxBits(binary->right, localInfoProvider);
204+
// if maxBits is negative
205+
if (maxBits == 32) {
206+
return 32;
207+
}
208+
return std::max(getMaxBits(binary->left, localInfoProvider), maxBits);
209+
}
160210
case ShlInt32: {
161211
if (auto* shifts = binary->right->dynCast<Const>()) {
162212
return std::min(Index(32),
@@ -178,6 +228,7 @@ Index getMaxBits(Expression* curr,
178228
case ShrSInt32: {
179229
if (auto* shift = binary->right->dynCast<Const>()) {
180230
auto maxBits = getMaxBits(binary->left, localInfoProvider);
231+
// if maxBits is negative
181232
if (maxBits == 32) {
182233
return 32;
183234
}
@@ -188,7 +239,105 @@ Index getMaxBits(Expression* curr,
188239
}
189240
return 32;
190241
}
191-
// 64-bit TODO
242+
case RotLInt64:
243+
case RotRInt64:
244+
case SubInt64:
245+
return 64;
246+
case AddInt64: {
247+
auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
248+
auto maxBitsRight = getMaxBits(binary->right, localInfoProvider);
249+
return std::min(Index(64), std::max(maxBitsLeft, maxBitsRight) + 1);
250+
}
251+
case MulInt64: {
252+
auto maxBitsRight = getMaxBits(binary->right, localInfoProvider);
253+
auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
254+
return std::min(Index(64), maxBitsLeft + maxBitsRight);
255+
}
256+
case DivSInt64: {
257+
if (auto* c = binary->right->dynCast<Const>()) {
258+
int32_t maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
259+
// if maxBitsLeft or right const value is negative
260+
if (maxBitsLeft == 64 || c->value.geti64() < 0) {
261+
return 64;
262+
}
263+
int32_t bitsRight = getMaxBits(c);
264+
return std::max(0, maxBitsLeft - bitsRight + 1);
265+
}
266+
return 64;
267+
}
268+
case DivUInt64: {
269+
int32_t maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
270+
if (auto* c = binary->right->dynCast<Const>()) {
271+
int32_t bitsRight = getMaxBits(c);
272+
return std::max(0, maxBitsLeft - bitsRight + 1);
273+
}
274+
return maxBitsLeft;
275+
}
276+
case RemSInt64: {
277+
if (auto* c = binary->right->dynCast<Const>()) {
278+
auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
279+
// if maxBitsLeft is negative
280+
if (maxBitsLeft == 64) {
281+
return 64;
282+
}
283+
auto bitsRight = Index(CeilLog2(c->value.geti64()));
284+
return std::min(maxBitsLeft, bitsRight);
285+
}
286+
return 64;
287+
}
288+
case RemUInt64: {
289+
if (auto* c = binary->right->dynCast<Const>()) {
290+
auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
291+
auto bitsRight = Index(CeilLog2(c->value.geti64()));
292+
return std::min(maxBitsLeft, bitsRight);
293+
}
294+
return 64;
295+
}
296+
case AndInt64: {
297+
auto maxBits = getMaxBits(binary->right, localInfoProvider);
298+
return std::min(getMaxBits(binary->left, localInfoProvider), maxBits);
299+
}
300+
case OrInt64:
301+
case XorInt64: {
302+
auto maxBits = getMaxBits(binary->right, localInfoProvider);
303+
// if maxBits is negative
304+
if (maxBits == 64) {
305+
return 64;
306+
}
307+
return std::max(getMaxBits(binary->left, localInfoProvider), maxBits);
308+
}
309+
case ShlInt64: {
310+
if (auto* shifts = binary->right->dynCast<Const>()) {
311+
auto maxBits = getMaxBits(binary->left, localInfoProvider);
312+
return std::min(Index(64),
313+
Bits::getEffectiveShifts(shifts) + maxBits);
314+
}
315+
return 64;
316+
}
317+
case ShrUInt64: {
318+
if (auto* shift = binary->right->dynCast<Const>()) {
319+
auto maxBits = getMaxBits(binary->left, localInfoProvider);
320+
auto shifts =
321+
std::min(Index(Bits::getEffectiveShifts(shift)),
322+
maxBits); // can ignore more shifts than zero us out
323+
return std::max(Index(0), maxBits - shifts);
324+
}
325+
return 64;
326+
}
327+
case ShrSInt64: {
328+
if (auto* shift = binary->right->dynCast<Const>()) {
329+
auto maxBits = getMaxBits(binary->left, localInfoProvider);
330+
// if maxBits is negative
331+
if (maxBits == 64) {
332+
return 64;
333+
}
334+
auto shifts =
335+
std::min(Index(Bits::getEffectiveShifts(shift)),
336+
maxBits); // can ignore more shifts than zero us out
337+
return std::max(Index(0), maxBits - shifts);
338+
}
339+
return 64;
340+
}
192341
// comparisons
193342
case EqInt32:
194343
case NeInt32:
@@ -200,6 +349,7 @@ Index getMaxBits(Expression* curr,
200349
case GtUInt32:
201350
case GeSInt32:
202351
case GeUInt32:
352+
203353
case EqInt64:
204354
case NeInt64:
205355
case LtSInt64:
@@ -210,12 +360,14 @@ Index getMaxBits(Expression* curr,
210360
case GtUInt64:
211361
case GeSInt64:
212362
case GeUInt64:
363+
213364
case EqFloat32:
214365
case NeFloat32:
215366
case LtFloat32:
216367
case LeFloat32:
217368
case GtFloat32:
218369
case GeFloat32:
370+
219371
case EqFloat64:
220372
case NeFloat64:
221373
case LtFloat64:
@@ -240,7 +392,12 @@ Index getMaxBits(Expression* curr,
240392
case EqZInt64:
241393
return 1;
242394
case WrapInt64:
395+
case ExtendUInt32:
243396
return std::min(Index(32), getMaxBits(unary->value, localInfoProvider));
397+
case ExtendSInt32: {
398+
auto maxBits = getMaxBits(unary->value, localInfoProvider);
399+
return maxBits == 32 ? Index(64) : maxBits;
400+
}
244401
default: {
245402
}
246403
}

src/support/bits.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,14 @@ template<> int CountLeadingZeroes<uint64_t>(uint64_t v) {
152152
#endif
153153
}
154154

155+
template<> int CeilLog2<uint32_t>(uint32_t v) {
156+
return 32 - CountLeadingZeroes(v - 1);
157+
}
158+
159+
template<> int CeilLog2<uint64_t>(uint64_t v) {
160+
return 64 - CountLeadingZeroes(v - 1);
161+
}
162+
155163
uint32_t Log2(uint32_t v) {
156164
switch (v) {
157165
default:

src/support/bits.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ template<typename T> int PopCount(T);
4040
template<typename T> uint32_t BitReverse(T);
4141
template<typename T> int CountTrailingZeroes(T);
4242
template<typename T> int CountLeadingZeroes(T);
43+
template<typename T> int CeilLog2(T);
4344

4445
#ifndef wasm_support_bits_definitions
4546
// The template specializations are provided elsewhere.
@@ -52,6 +53,8 @@ extern template int CountTrailingZeroes(uint32_t);
5253
extern template int CountTrailingZeroes(uint64_t);
5354
extern template int CountLeadingZeroes(uint32_t);
5455
extern template int CountLeadingZeroes(uint64_t);
56+
extern template int CeilLog2(uint32_t);
57+
extern template int CeilLog2(uint64_t);
5558
#endif
5659

5760
// Convenience signed -> unsigned. It usually doesn't make much sense to use bit
@@ -65,6 +68,9 @@ template<typename T> int CountTrailingZeroes(T v) {
6568
template<typename T> int CountLeadingZeroes(T v) {
6669
return CountLeadingZeroes(typename std::make_unsigned<T>::type(v));
6770
}
71+
template<typename T> int CeilLog2(T v) {
72+
return CeilLog2(typename std::make_unsigned<T>::type(v));
73+
}
6874
template<typename T> bool IsPowerOf2(T v) {
6975
return v != 0 && (v & (v - 1)) == 0;
7076
}

0 commit comments

Comments
 (0)