Skip to content

Commit d12e423

Browse files
authored
Simplify ireduce of modular operators, when it lowers the total instructions (#7719)
1 parent 37c06fd commit d12e423

File tree

5 files changed

+242
-0
lines changed

5 files changed

+242
-0
lines changed

cranelift/codegen/src/opts/arithmetic.isle

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,3 +151,19 @@
151151
(subsume (bxor ty (bxor ty a b) (bxor ty c d))))
152152
(rule (simplify (bxor ty (bxor ty (bxor ty a b) c) d))
153153
(subsume (bxor ty (bxor ty a b) (bxor ty c d))))
154+
155+
;; Detect people open-coding `mulhi`: (x as big * y as big) >> bits
156+
;; LLVM doesn't have an intrinsic for it, so you'll see it in code like
157+
;; <https://github.com/rust-lang/rust/blob/767453eb7ca188e991ac5568c17b984dd4893e77/library/core/src/num/mod.rs#L174-L180>
158+
(rule (simplify (sshr ty (imul ty (sextend _ x@(value_type half_ty))
159+
(sextend _ y@(value_type half_ty)))
160+
(iconst_u _ k)))
161+
(if-let $true (ty_equal half_ty (ty_half_width ty)))
162+
(if-let $true (u64_eq k (ty_bits_u64 half_ty)))
163+
(sextend ty (smulhi half_ty x y)))
164+
(rule (simplify (ushr ty (imul ty (uextend _ x@(value_type half_ty))
165+
(uextend _ y@(value_type half_ty)))
166+
(iconst_u _ k)))
167+
(if-let $true (ty_equal half_ty (ty_half_width ty)))
168+
(if-let $true (u64_eq k (ty_bits_u64 half_ty)))
169+
(uextend ty (umulhi half_ty x y)))

cranelift/codegen/src/opts/extends.isle

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,43 @@
5858
(uextend bigty (bor smallty x y)))
5959
(rule (simplify (bxor bigty (uextend _ x@(value_type smallty)) (uextend _ y@(value_type smallty))))
6060
(uextend bigty (bxor smallty x y)))
61+
62+
;; Matches values where `ireducing` them will not actually introduce another
63+
;; instruction, since other rules will collapse them with the reduction.
64+
(decl pure multi will_simplify_with_ireduce (Value) Value)
65+
(rule (will_simplify_with_ireduce x@(uextend _ _)) x)
66+
(rule (will_simplify_with_ireduce x@(sextend _ _)) x)
67+
(rule (will_simplify_with_ireduce x@(iconst _ _)) x)
68+
(rule (will_simplify_with_ireduce x@(unary_op _ _ a))
69+
(if-let _ (will_simplify_with_ireduce a))
70+
x)
71+
(rule (will_simplify_with_ireduce x@(binary_op _ _ a b))
72+
(if-let _ (will_simplify_with_ireduce a))
73+
(if-let _ (will_simplify_with_ireduce b))
74+
x)
75+
76+
;; Matches values where the high bits of the input don't affect lower bits of
77+
;; the output, and thus the inputs can be reduced before the operation rather
78+
;; than doing the wide operation then reducing afterwards.
79+
(decl pure multi reducible_modular_op (Value) Value)
80+
(rule (reducible_modular_op x@(ineg _ _)) x)
81+
(rule (reducible_modular_op x@(bnot _ _)) x)
82+
(rule (reducible_modular_op x@(iadd _ _ _)) x)
83+
(rule (reducible_modular_op x@(isub _ _ _)) x)
84+
(rule (reducible_modular_op x@(imul _ _ _)) x)
85+
(rule (reducible_modular_op x@(bor _ _ _)) x)
86+
(rule (reducible_modular_op x@(bxor _ _ _)) x)
87+
(rule (reducible_modular_op x@(band _ _ _)) x)
88+
89+
;; Replace `(small)(x OP y)` with `(small)x OP (small)y` in cases where that's
90+
;; legal and it reduces the total number of instructions since the reductions
91+
;; to the arguments simplify further.
92+
(rule (simplify (ireduce smallty val@(unary_op _ op x)))
93+
(if-let _ (reducible_modular_op val))
94+
(if-let _ (will_simplify_with_ireduce x))
95+
(unary_op smallty op (ireduce smallty x)))
96+
(rule (simplify (ireduce smallty val@(binary_op _ op x y)))
97+
(if-let _ (reducible_modular_op val))
98+
(if-let _ (will_simplify_with_ireduce x))
99+
(if-let _ (will_simplify_with_ireduce y))
100+
(binary_op smallty op (ireduce smallty x) (ireduce smallty y)))

cranelift/codegen/src/prelude_opt.isle

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,3 +120,15 @@
120120
(extractor (sextend_maybe ty val) (sextend_maybe_etor ty val))
121121
(rule 0 (sextend_maybe ty val) (sextend ty val))
122122
(rule 1 (sextend_maybe ty val@(value_type ty)) val)
123+
124+
(decl unary_op (Type Opcode Value) Value)
125+
(extractor (unary_op ty opcode x)
126+
(inst_data ty (InstructionData.Unary opcode x)))
127+
(rule (unary_op ty opcode x)
128+
(make_inst ty (InstructionData.Unary opcode x)))
129+
130+
(decl binary_op (Type Opcode Value Value) Value)
131+
(extractor (binary_op ty opcode x y)
132+
(inst_data ty (InstructionData.Binary opcode (value_array_2 x y))))
133+
(rule (binary_op ty opcode x y)
134+
(make_inst ty (InstructionData.Binary opcode (value_array_2_ctor x y))))

cranelift/filetests/filetests/egraph/arithmetic.clif

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,3 +250,108 @@ block0(v1: f32, v2: f32):
250250

251251
; check: v6 = fmul v1, v2
252252
; check: return v6
253+
254+
function %manual_smulhi_32(i32, i32) -> i32 {
255+
block0(v0: i32, v1: i32):
256+
v2 = sextend.i64 v0
257+
v3 = sextend.i64 v1
258+
v4 = imul v2, v3
259+
v5 = iconst.i32 32
260+
v6 = sshr v4, v5
261+
v7 = ireduce.i32 v6
262+
return v7
263+
}
264+
265+
; check: v8 = smulhi v0, v1
266+
; check: return v8
267+
268+
function %manual_smulhi_64(i64, i64) -> i64 {
269+
block0(v0: i64, v1: i64):
270+
v2 = sextend.i128 v0
271+
v3 = sextend.i128 v1
272+
v4 = imul v2, v3
273+
v5 = iconst.i32 64
274+
v6 = sshr v4, v5
275+
v7 = ireduce.i64 v6
276+
return v7
277+
}
278+
279+
; check: v8 = smulhi v0, v1
280+
; check: return v8
281+
282+
function %manual_umulhi_32(i32, i32) -> i32 {
283+
block0(v0: i32, v1: i32):
284+
v2 = uextend.i64 v0
285+
v3 = uextend.i64 v1
286+
v4 = imul v2, v3
287+
v5 = iconst.i32 32
288+
v6 = ushr v4, v5
289+
v7 = ireduce.i32 v6
290+
return v7
291+
}
292+
293+
; check: v8 = umulhi v0, v1
294+
; check: return v8
295+
296+
function %manual_umulhi_64(i64, i64) -> i64 {
297+
block0(v0: i64, v1: i64):
298+
v2 = uextend.i128 v0
299+
v3 = uextend.i128 v1
300+
v4 = imul v2, v3
301+
v5 = iconst.i32 64
302+
v6 = ushr v4, v5
303+
v7 = ireduce.i64 v6
304+
return v7
305+
}
306+
307+
; check: v8 = umulhi v0, v1
308+
; check: return v8
309+
310+
function %u64_widening_mul(i64, i64, i64) {
311+
block0(v0: i64, v1: i64, v2: i64):
312+
v3 = uextend.i128 v1
313+
v4 = uextend.i128 v2
314+
v5 = imul v3, v4
315+
v6 = iconst.i32 64
316+
v7 = ushr v5, v6
317+
v8 = ireduce.i64 v7
318+
v9 = ireduce.i64 v5
319+
store.i64 v9, v0
320+
store.i64 v8, v0+8
321+
return
322+
}
323+
324+
; check: v18 = imul v1, v2
325+
; check: store v18, v0
326+
; check: v10 = umulhi v1, v2
327+
; check: store v10, v0+8
328+
329+
function %char_plus_one(i8) -> i8 {
330+
block0(v0: i8):
331+
v1 = sextend.i32 v0
332+
v2 = iconst.i32 257
333+
v3 = iadd v1, v2
334+
v4 = ireduce.i8 v3
335+
return v4
336+
}
337+
338+
; check: v8 = iconst.i8 1
339+
; check: v9 = iadd v0, v8 ; v8 = 1
340+
; check: return v9
341+
342+
;; Adding three `short`s together and storing them in a `short`,
343+
;; which in C involves extending them to `int`s in the middle.
344+
function %extend_iadd_iadd_reduce(i16, i16, i16) -> i16 {
345+
block0(v0: i16, v1: i16, v2: i16):
346+
v3 = sextend.i32 v0
347+
v4 = sextend.i32 v1
348+
v5 = sextend.i32 v2
349+
v6 = iadd v3, v4
350+
v7 = iadd v6, v5
351+
v8 = ireduce.i16 v7
352+
return v8
353+
}
354+
355+
; check: v14 = iadd v0, v1
356+
; check: v18 = iadd v14, v2
357+
; check: return v18

cranelift/filetests/filetests/egraph/extends.clif

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,3 +118,72 @@ block0(v0: i8):
118118
; check: v5 = icmp ne v0, v4
119119
; check: return v5
120120

121+
function %extend_imul_reduce(i64, i64) -> i64 {
122+
block0(v0: i64, v1: i64):
123+
v2 = uextend.i128 v0
124+
v3 = uextend.i128 v1
125+
v4 = imul v2, v3
126+
v5 = ireduce.i64 v4
127+
return v5
128+
}
129+
130+
; check: v10 = imul v0, v1
131+
; check: return v10
132+
133+
function %extend_iadd_reduce(i16, i16) -> i16 {
134+
block0(v0: i16, v1: i16):
135+
v2 = sextend.i32 v0
136+
v3 = sextend.i32 v1
137+
v4 = iadd v2, v3
138+
v5 = ireduce.i16 v4
139+
return v5
140+
}
141+
142+
; check: v10 = iadd v0, v1
143+
; check: return v10
144+
145+
function %extend_bxor_reduce(i64, i64) -> i64 {
146+
block0(v0: i64, v1: i64):
147+
v2 = uextend.i128 v0
148+
v3 = uextend.i128 v1
149+
v4 = bxor v2, v3
150+
v5 = ireduce.i64 v4
151+
return v5
152+
}
153+
154+
; check: v6 = bxor v0, v1
155+
; check: return v6
156+
157+
function %extend_band_reduce(i16, i16) -> i16 {
158+
block0(v0: i16, v1: i16):
159+
v2 = sextend.i32 v0
160+
v3 = sextend.i32 v1
161+
v4 = band v2, v3
162+
v5 = ireduce.i16 v4
163+
return v5
164+
}
165+
166+
; check: v10 = band v0, v1
167+
; check: return v10
168+
169+
function %extend_ineg_reduce(i64) -> i64 {
170+
block0(v0: i64):
171+
v1 = sextend.i128 v0
172+
v2 = ineg v1
173+
v3 = ireduce.i64 v2
174+
return v3
175+
}
176+
177+
; check: v6 = ineg v0
178+
; check: return v6
179+
180+
function %extend_bnot_reduce(i16) -> i16 {
181+
block0(v0: i16):
182+
v1 = uextend.i32 v0
183+
v2 = bnot v1
184+
v3 = ireduce.i16 v2
185+
return v3
186+
}
187+
188+
; check: v6 = bnot v0
189+
; check: return v6

0 commit comments

Comments
 (0)