@@ -128,35 +128,85 @@ struct DummyLocalInfoProvider {
128
128
template <typename LocalInfoProvider = DummyLocalInfoProvider>
129
129
Index getMaxBits (Expression* curr,
130
130
LocalInfoProvider* localInfoProvider = nullptr ) {
131
- if (auto * const_ = curr->dynCast <Const>()) {
131
+ if (auto * c = curr->dynCast <Const>()) {
132
132
switch (curr->type .getBasic ()) {
133
133
case Type::i32 :
134
- return 32 - const_ ->value .countLeadingZeroes ().geti32 ();
134
+ return 32 - c ->value .countLeadingZeroes ().geti32 ();
135
135
case Type::i64 :
136
- return 64 - const_ ->value .countLeadingZeroes ().geti64 ();
136
+ return 64 - c ->value .countLeadingZeroes ().geti64 ();
137
137
default :
138
138
WASM_UNREACHABLE (" invalid type" );
139
139
}
140
140
} else if (auto * binary = curr->dynCast <Binary>()) {
141
141
switch (binary->op ) {
142
142
// 32-bit
143
- case AddInt32:
144
- case SubInt32:
145
- case MulInt32:
146
- case DivSInt32:
147
- case DivUInt32:
148
- case RemSInt32:
149
- case RemUInt32:
150
143
case RotLInt32:
151
144
case RotRInt32:
145
+ case SubInt32:
146
+ return 32 ;
147
+ case AddInt32: {
148
+ auto maxBitsLeft = getMaxBits (binary->left , localInfoProvider);
149
+ auto maxBitsRight = getMaxBits (binary->right , localInfoProvider);
150
+ return std::min (Index (32 ), std::max (maxBitsLeft, maxBitsRight) + 1 );
151
+ }
152
+ case MulInt32: {
153
+ auto maxBitsRight = getMaxBits (binary->right , localInfoProvider);
154
+ auto maxBitsLeft = getMaxBits (binary->left , localInfoProvider);
155
+ return std::min (Index (32 ), maxBitsLeft + maxBitsRight);
156
+ }
157
+ case DivSInt32: {
158
+ if (auto * c = binary->right ->dynCast <Const>()) {
159
+ int32_t maxBitsLeft = getMaxBits (binary->left , localInfoProvider);
160
+ // If either side might be negative, then the result will be negative
161
+ if (maxBitsLeft == 32 || c->value .geti32 () < 0 ) {
162
+ return 32 ;
163
+ }
164
+ int32_t bitsRight = getMaxBits (c);
165
+ return std::max (0 , maxBitsLeft - bitsRight + 1 );
166
+ }
167
+ return 32 ;
168
+ }
169
+ case DivUInt32: {
170
+ int32_t maxBitsLeft = getMaxBits (binary->left , localInfoProvider);
171
+ if (auto * c = binary->right ->dynCast <Const>()) {
172
+ int32_t bitsRight = getMaxBits (c);
173
+ return std::max (0 , maxBitsLeft - bitsRight + 1 );
174
+ }
175
+ return maxBitsLeft;
176
+ }
177
+ case RemSInt32: {
178
+ if (auto * c = binary->right ->dynCast <Const>()) {
179
+ auto maxBitsLeft = getMaxBits (binary->left , localInfoProvider);
180
+ // if maxBitsLeft is negative
181
+ if (maxBitsLeft == 32 ) {
182
+ return 32 ;
183
+ }
184
+ auto bitsRight = Index (CeilLog2 (c->value .geti32 ()));
185
+ return std::min (maxBitsLeft, bitsRight);
186
+ }
187
+ return 32 ;
188
+ }
189
+ case RemUInt32: {
190
+ if (auto * c = binary->right ->dynCast <Const>()) {
191
+ auto maxBitsLeft = getMaxBits (binary->left , localInfoProvider);
192
+ auto bitsRight = Index (CeilLog2 (c->value .geti32 ()));
193
+ return std::min (maxBitsLeft, bitsRight);
194
+ }
152
195
return 32 ;
153
- case AndInt32:
196
+ }
197
+ case AndInt32: {
154
198
return std::min (getMaxBits (binary->left , localInfoProvider),
155
199
getMaxBits (binary->right , localInfoProvider));
200
+ }
156
201
case OrInt32:
157
- case XorInt32:
158
- return std::max (getMaxBits (binary->left , localInfoProvider),
159
- getMaxBits (binary->right , localInfoProvider));
202
+ case XorInt32: {
203
+ auto maxBits = getMaxBits (binary->right , localInfoProvider);
204
+ // if maxBits is negative
205
+ if (maxBits == 32 ) {
206
+ return 32 ;
207
+ }
208
+ return std::max (getMaxBits (binary->left , localInfoProvider), maxBits);
209
+ }
160
210
case ShlInt32: {
161
211
if (auto * shifts = binary->right ->dynCast <Const>()) {
162
212
return std::min (Index (32 ),
@@ -178,6 +228,7 @@ Index getMaxBits(Expression* curr,
178
228
case ShrSInt32: {
179
229
if (auto * shift = binary->right ->dynCast <Const>()) {
180
230
auto maxBits = getMaxBits (binary->left , localInfoProvider);
231
+ // if maxBits is negative
181
232
if (maxBits == 32 ) {
182
233
return 32 ;
183
234
}
@@ -188,7 +239,105 @@ Index getMaxBits(Expression* curr,
188
239
}
189
240
return 32 ;
190
241
}
191
- // 64-bit TODO
242
+ case RotLInt64:
243
+ case RotRInt64:
244
+ case SubInt64:
245
+ return 64 ;
246
+ case AddInt64: {
247
+ auto maxBitsLeft = getMaxBits (binary->left , localInfoProvider);
248
+ auto maxBitsRight = getMaxBits (binary->right , localInfoProvider);
249
+ return std::min (Index (64 ), std::max (maxBitsLeft, maxBitsRight) + 1 );
250
+ }
251
+ case MulInt64: {
252
+ auto maxBitsRight = getMaxBits (binary->right , localInfoProvider);
253
+ auto maxBitsLeft = getMaxBits (binary->left , localInfoProvider);
254
+ return std::min (Index (64 ), maxBitsLeft + maxBitsRight);
255
+ }
256
+ case DivSInt64: {
257
+ if (auto * c = binary->right ->dynCast <Const>()) {
258
+ int32_t maxBitsLeft = getMaxBits (binary->left , localInfoProvider);
259
+ // if maxBitsLeft or right const value is negative
260
+ if (maxBitsLeft == 64 || c->value .geti64 () < 0 ) {
261
+ return 64 ;
262
+ }
263
+ int32_t bitsRight = getMaxBits (c);
264
+ return std::max (0 , maxBitsLeft - bitsRight + 1 );
265
+ }
266
+ return 64 ;
267
+ }
268
+ case DivUInt64: {
269
+ int32_t maxBitsLeft = getMaxBits (binary->left , localInfoProvider);
270
+ if (auto * c = binary->right ->dynCast <Const>()) {
271
+ int32_t bitsRight = getMaxBits (c);
272
+ return std::max (0 , maxBitsLeft - bitsRight + 1 );
273
+ }
274
+ return maxBitsLeft;
275
+ }
276
+ case RemSInt64: {
277
+ if (auto * c = binary->right ->dynCast <Const>()) {
278
+ auto maxBitsLeft = getMaxBits (binary->left , localInfoProvider);
279
+ // if maxBitsLeft is negative
280
+ if (maxBitsLeft == 64 ) {
281
+ return 64 ;
282
+ }
283
+ auto bitsRight = Index (CeilLog2 (c->value .geti64 ()));
284
+ return std::min (maxBitsLeft, bitsRight);
285
+ }
286
+ return 64 ;
287
+ }
288
+ case RemUInt64: {
289
+ if (auto * c = binary->right ->dynCast <Const>()) {
290
+ auto maxBitsLeft = getMaxBits (binary->left , localInfoProvider);
291
+ auto bitsRight = Index (CeilLog2 (c->value .geti64 ()));
292
+ return std::min (maxBitsLeft, bitsRight);
293
+ }
294
+ return 64 ;
295
+ }
296
+ case AndInt64: {
297
+ auto maxBits = getMaxBits (binary->right , localInfoProvider);
298
+ return std::min (getMaxBits (binary->left , localInfoProvider), maxBits);
299
+ }
300
+ case OrInt64:
301
+ case XorInt64: {
302
+ auto maxBits = getMaxBits (binary->right , localInfoProvider);
303
+ // if maxBits is negative
304
+ if (maxBits == 64 ) {
305
+ return 64 ;
306
+ }
307
+ return std::max (getMaxBits (binary->left , localInfoProvider), maxBits);
308
+ }
309
+ case ShlInt64: {
310
+ if (auto * shifts = binary->right ->dynCast <Const>()) {
311
+ auto maxBits = getMaxBits (binary->left , localInfoProvider);
312
+ return std::min (Index (64 ),
313
+ Bits::getEffectiveShifts (shifts) + maxBits);
314
+ }
315
+ return 64 ;
316
+ }
317
+ case ShrUInt64: {
318
+ if (auto * shift = binary->right ->dynCast <Const>()) {
319
+ auto maxBits = getMaxBits (binary->left , localInfoProvider);
320
+ auto shifts =
321
+ std::min (Index (Bits::getEffectiveShifts (shift)),
322
+ maxBits); // can ignore more shifts than zero us out
323
+ return std::max (Index (0 ), maxBits - shifts);
324
+ }
325
+ return 64 ;
326
+ }
327
+ case ShrSInt64: {
328
+ if (auto * shift = binary->right ->dynCast <Const>()) {
329
+ auto maxBits = getMaxBits (binary->left , localInfoProvider);
330
+ // if maxBits is negative
331
+ if (maxBits == 64 ) {
332
+ return 64 ;
333
+ }
334
+ auto shifts =
335
+ std::min (Index (Bits::getEffectiveShifts (shift)),
336
+ maxBits); // can ignore more shifts than zero us out
337
+ return std::max (Index (0 ), maxBits - shifts);
338
+ }
339
+ return 64 ;
340
+ }
192
341
// comparisons
193
342
case EqInt32:
194
343
case NeInt32:
@@ -200,6 +349,7 @@ Index getMaxBits(Expression* curr,
200
349
case GtUInt32:
201
350
case GeSInt32:
202
351
case GeUInt32:
352
+
203
353
case EqInt64:
204
354
case NeInt64:
205
355
case LtSInt64:
@@ -210,12 +360,14 @@ Index getMaxBits(Expression* curr,
210
360
case GtUInt64:
211
361
case GeSInt64:
212
362
case GeUInt64:
363
+
213
364
case EqFloat32:
214
365
case NeFloat32:
215
366
case LtFloat32:
216
367
case LeFloat32:
217
368
case GtFloat32:
218
369
case GeFloat32:
370
+
219
371
case EqFloat64:
220
372
case NeFloat64:
221
373
case LtFloat64:
@@ -240,7 +392,12 @@ Index getMaxBits(Expression* curr,
240
392
case EqZInt64:
241
393
return 1 ;
242
394
case WrapInt64:
395
+ case ExtendUInt32:
243
396
return std::min (Index (32 ), getMaxBits (unary->value , localInfoProvider));
397
+ case ExtendSInt32: {
398
+ auto maxBits = getMaxBits (unary->value , localInfoProvider);
399
+ return maxBits == 32 ? Index (64 ) : maxBits;
400
+ }
244
401
default : {
245
402
}
246
403
}
0 commit comments