-
Notifications
You must be signed in to change notification settings - Fork 787
Implement more cases for getMaxBits #2879
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ee08839
374c9b6
3a0f72e
091e15e
18a2aa2
168fed4
d04ae7d
c08376c
2dc792c
56146d3
3f68768
36f73e5
c77d5d4
d60b53e
0bba5b6
f43c8db
a6415e1
d979620
6685e6d
7726f18
a39091a
908eb96
df767f4
73ad57d
6b8d232
f387148
38ee4e7
74f10d5
fa13337
2e9ab76
3ea4530
252ce01
fd4e8cf
60b7b55
d93f65d
c9fe2cb
14ae90c
90f39d4
5a99214
64aded9
4b2061a
216b6fa
c631b94
d0a4938
543d239
5673101
ae827a8
8567487
957e945
d5d6d2b
709b8ed
569362f
9545f9e
9085d8e
422e46f
551a675
d7335a8
3a27ca3
e0739c3
74e568c
ae75346
6883ce7
0cc2e85
acd6518
df40a08
da177ae
d84f9ba
5c1d1c4
f319270
55d9824
f924008
790a05c
30b1842
341a3ef
b91ee8c
719d8c3
146c966
a74f10e
b4d614d
9b2fe4d
162c94d
8404d91
4be0956
ec62420
5fc687a
0be9076
33c9eab
277d159
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -128,35 +128,85 @@ struct DummyLocalInfoProvider { | |
template<typename LocalInfoProvider = DummyLocalInfoProvider> | ||
Index getMaxBits(Expression* curr, | ||
LocalInfoProvider* localInfoProvider = nullptr) { | ||
if (auto* const_ = curr->dynCast<Const>()) { | ||
if (auto* c = curr->dynCast<Const>()) { | ||
switch (curr->type.getBasic()) { | ||
case Type::i32: | ||
return 32 - const_->value.countLeadingZeroes().geti32(); | ||
return 32 - c->value.countLeadingZeroes().geti32(); | ||
case Type::i64: | ||
return 64 - const_->value.countLeadingZeroes().geti64(); | ||
return 64 - c->value.countLeadingZeroes().geti64(); | ||
default: | ||
WASM_UNREACHABLE("invalid type"); | ||
} | ||
} else if (auto* binary = curr->dynCast<Binary>()) { | ||
switch (binary->op) { | ||
// 32-bit | ||
case AddInt32: | ||
case SubInt32: | ||
case MulInt32: | ||
case DivSInt32: | ||
case DivUInt32: | ||
case RemSInt32: | ||
case RemUInt32: | ||
case RotLInt32: | ||
case RotRInt32: | ||
case SubInt32: | ||
return 32; | ||
case AddInt32: { | ||
auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider); | ||
auto maxBitsRight = getMaxBits(binary->right, localInfoProvider); | ||
return std::min(Index(32), std::max(maxBitsLeft, maxBitsRight) + 1); | ||
} | ||
case MulInt32: { | ||
auto maxBitsRight = getMaxBits(binary->right, localInfoProvider); | ||
auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider); | ||
return std::min(Index(32), maxBitsLeft + maxBitsRight); | ||
} | ||
case DivSInt32: { | ||
if (auto* c = binary->right->dynCast<Const>()) { | ||
int32_t maxBitsLeft = getMaxBits(binary->left, localInfoProvider); | ||
// If either side might be negative, then the result will be negative | ||
if (maxBitsLeft == 32 || c->value.geti32() < 0) { | ||
return 32; | ||
} | ||
int32_t bitsRight = getMaxBits(c); | ||
return std::max(0, maxBitsLeft - bitsRight + 1); | ||
} | ||
return 32; | ||
} | ||
case DivUInt32: { | ||
int32_t maxBitsLeft = getMaxBits(binary->left, localInfoProvider); | ||
if (auto* c = binary->right->dynCast<Const>()) { | ||
int32_t bitsRight = getMaxBits(c); | ||
return std::max(0, maxBitsLeft - bitsRight + 1); | ||
} | ||
return maxBitsLeft; | ||
} | ||
case RemSInt32: { | ||
if (auto* c = binary->right->dynCast<Const>()) { | ||
auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider); | ||
// if maxBitsLeft is negative | ||
if (maxBitsLeft == 32) { | ||
return 32; | ||
} | ||
auto bitsRight = Index(CeilLog2(c->value.geti32())); | ||
return std::min(maxBitsLeft, bitsRight); | ||
} | ||
return 32; | ||
} | ||
case RemUInt32: { | ||
if (auto* c = binary->right->dynCast<Const>()) { | ||
auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider); | ||
auto bitsRight = Index(CeilLog2(c->value.geti32())); | ||
return std::min(maxBitsLeft, bitsRight); | ||
} | ||
return 32; | ||
case AndInt32: | ||
} | ||
case AndInt32: { | ||
return std::min(getMaxBits(binary->left, localInfoProvider), | ||
getMaxBits(binary->right, localInfoProvider)); | ||
} | ||
case OrInt32: | ||
case XorInt32: | ||
return std::max(getMaxBits(binary->left, localInfoProvider), | ||
getMaxBits(binary->right, localInfoProvider)); | ||
case XorInt32: { | ||
auto maxBits = getMaxBits(binary->right, localInfoProvider); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there reason to believe that checking the right side first will save more work than checking the left side first, or is it an arbitrary choice? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, right side could be cheaper due to canonization which always force constant on the right. Also |
||
// if maxBits is negative | ||
if (maxBits == 32) { | ||
return 32; | ||
} | ||
return std::max(getMaxBits(binary->left, localInfoProvider), maxBits); | ||
} | ||
case ShlInt32: { | ||
if (auto* shifts = binary->right->dynCast<Const>()) { | ||
return std::min(Index(32), | ||
|
@@ -178,6 +228,7 @@ Index getMaxBits(Expression* curr, | |
case ShrSInt32: { | ||
if (auto* shift = binary->right->dynCast<Const>()) { | ||
auto maxBits = getMaxBits(binary->left, localInfoProvider); | ||
// if maxBits is negative | ||
if (maxBits == 32) { | ||
return 32; | ||
} | ||
|
@@ -188,7 +239,105 @@ Index getMaxBits(Expression* curr, | |
} | ||
return 32; | ||
} | ||
// 64-bit TODO | ||
case RotLInt64: | ||
case RotRInt64: | ||
case SubInt64: | ||
return 64; | ||
case AddInt64: { | ||
auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider); | ||
auto maxBitsRight = getMaxBits(binary->right, localInfoProvider); | ||
return std::min(Index(64), std::max(maxBitsLeft, maxBitsRight) + 1); | ||
MaxGraey marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
case MulInt64: { | ||
auto maxBitsRight = getMaxBits(binary->right, localInfoProvider); | ||
auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider); | ||
return std::min(Index(64), maxBitsLeft + maxBitsRight); | ||
} | ||
case DivSInt64: { | ||
if (auto* c = binary->right->dynCast<Const>()) { | ||
int32_t maxBitsLeft = getMaxBits(binary->left, localInfoProvider); | ||
// if maxBitsLeft or right const value is negative | ||
if (maxBitsLeft == 64 || c->value.geti64() < 0) { | ||
return 64; | ||
} | ||
int32_t bitsRight = getMaxBits(c); | ||
return std::max(0, maxBitsLeft - bitsRight + 1); | ||
} | ||
return 64; | ||
} | ||
case DivUInt64: { | ||
int32_t maxBitsLeft = getMaxBits(binary->left, localInfoProvider); | ||
if (auto* c = binary->right->dynCast<Const>()) { | ||
int32_t bitsRight = getMaxBits(c); | ||
return std::max(0, maxBitsLeft - bitsRight + 1); | ||
} | ||
return maxBitsLeft; | ||
} | ||
case RemSInt64: { | ||
if (auto* c = binary->right->dynCast<Const>()) { | ||
auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider); | ||
// if maxBitsLeft is negative | ||
if (maxBitsLeft == 64) { | ||
return 64; | ||
} | ||
auto bitsRight = Index(CeilLog2(c->value.geti64())); | ||
return std::min(maxBitsLeft, bitsRight); | ||
} | ||
return 64; | ||
} | ||
case RemUInt64: { | ||
if (auto* c = binary->right->dynCast<Const>()) { | ||
auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider); | ||
auto bitsRight = Index(CeilLog2(c->value.geti64())); | ||
return std::min(maxBitsLeft, bitsRight); | ||
} | ||
return 64; | ||
} | ||
case AndInt64: { | ||
auto maxBits = getMaxBits(binary->right, localInfoProvider); | ||
return std::min(getMaxBits(binary->left, localInfoProvider), maxBits); | ||
} | ||
case OrInt64: | ||
case XorInt64: { | ||
auto maxBits = getMaxBits(binary->right, localInfoProvider); | ||
// if maxBits is negative | ||
if (maxBits == 64) { | ||
return 64; | ||
} | ||
return std::max(getMaxBits(binary->left, localInfoProvider), maxBits); | ||
} | ||
case ShlInt64: { | ||
if (auto* shifts = binary->right->dynCast<Const>()) { | ||
auto maxBits = getMaxBits(binary->left, localInfoProvider); | ||
return std::min(Index(64), | ||
Bits::getEffectiveShifts(shifts) + maxBits); | ||
} | ||
return 64; | ||
} | ||
case ShrUInt64: { | ||
if (auto* shift = binary->right->dynCast<Const>()) { | ||
auto maxBits = getMaxBits(binary->left, localInfoProvider); | ||
auto shifts = | ||
std::min(Index(Bits::getEffectiveShifts(shift)), | ||
maxBits); // can ignore more shifts than zero us out | ||
return std::max(Index(0), maxBits - shifts); | ||
} | ||
return 64; | ||
} | ||
case ShrSInt64: { | ||
if (auto* shift = binary->right->dynCast<Const>()) { | ||
auto maxBits = getMaxBits(binary->left, localInfoProvider); | ||
// if maxBits is negative | ||
if (maxBits == 64) { | ||
return 64; | ||
} | ||
auto shifts = | ||
std::min(Index(Bits::getEffectiveShifts(shift)), | ||
maxBits); // can ignore more shifts than zero us out | ||
return std::max(Index(0), maxBits - shifts); | ||
} | ||
return 64; | ||
} | ||
// comparisons | ||
case EqInt32: | ||
case NeInt32: | ||
|
@@ -200,6 +349,7 @@ Index getMaxBits(Expression* curr, | |
case GtUInt32: | ||
case GeSInt32: | ||
case GeUInt32: | ||
|
||
case EqInt64: | ||
case NeInt64: | ||
case LtSInt64: | ||
|
@@ -210,12 +360,14 @@ Index getMaxBits(Expression* curr, | |
case GtUInt64: | ||
case GeSInt64: | ||
case GeUInt64: | ||
|
||
case EqFloat32: | ||
case NeFloat32: | ||
case LtFloat32: | ||
case LeFloat32: | ||
case GtFloat32: | ||
case GeFloat32: | ||
|
||
case EqFloat64: | ||
case NeFloat64: | ||
case LtFloat64: | ||
|
@@ -240,7 +392,12 @@ Index getMaxBits(Expression* curr, | |
case EqZInt64: | ||
return 1; | ||
case WrapInt64: | ||
case ExtendUInt32: | ||
return std::min(Index(32), getMaxBits(unary->value, localInfoProvider)); | ||
case ExtendSInt32: { | ||
auto maxBits = getMaxBits(unary->value, localInfoProvider); | ||
return maxBits == 32 ? Index(64) : maxBits; | ||
} | ||
default: { | ||
} | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -40,6 +40,7 @@ template<typename T> int PopCount(T); | |
template<typename T> uint32_t BitReverse(T); | ||
template<typename T> int CountTrailingZeroes(T); | ||
template<typename T> int CountLeadingZeroes(T); | ||
template<typename T> int CeilLog2(T); | ||
|
||
#ifndef wasm_support_bits_definitions | ||
// The template specializations are provided elsewhere. | ||
|
@@ -52,6 +53,8 @@ extern template int CountTrailingZeroes(uint32_t); | |
extern template int CountTrailingZeroes(uint64_t); | ||
extern template int CountLeadingZeroes(uint32_t); | ||
extern template int CountLeadingZeroes(uint64_t); | ||
extern template int CeilLog2(uint32_t); | ||
extern template int CeilLog2(uint64_t); | ||
Comment on lines
+56
to
+57
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks like it would be a lot simpler to just overload the CeilLog2 for uint32_t and uint64_t without using any templates. That probably applies to all these functions, but since they're already like that, this can be left for a follow up. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It was my first attempt, but templates in C++ still a mystery to me sometimes=) I didn’t manage to make friends with template<typename T>
int CeilLog2(T v) {
return sizeof(T) * 8 - CountLeadingZeroes(v - 1);
} Also it quite hard to restrict There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, no problem, let's leave cleaning that up to a follow-up PR. |
||
#endif | ||
|
||
// Convenience signed -> unsigned. It usually doesn't make much sense to use bit | ||
|
@@ -65,6 +68,9 @@ template<typename T> int CountTrailingZeroes(T v) { | |
template<typename T> int CountLeadingZeroes(T v) { | ||
return CountLeadingZeroes(typename std::make_unsigned<T>::type(v)); | ||
} | ||
template<typename T> int CeilLog2(T v) { | ||
return CeilLog2(typename std::make_unsigned<T>::type(v)); | ||
} | ||
template<typename T> bool IsPowerOf2(T v) { | ||
return v != 0 && (v & (v - 1)) == 0; | ||
} | ||
|
Uh oh!
There was an error while loading. Please reload this page.