Status: Open

Description

Consider the following example:
#![feature(let_chains)]
/// Sums `value * query` over `query_ids`, pairing each query with the next
/// value in `value_ids` (scanning forward only) that is not smaller than it.
///
/// The value cursor never moves backwards and is kept across queries, so a
/// single value can be paired with several consecutive queries. The scan stops
/// as soon as `value_ids` is exhausted; later queries contribute nothing.
/// Presumably both slices are sorted ascending (the "merge" in the name) —
/// this is not checked.
///
/// Arithmetic wraps on overflow in every build profile. This matches what
/// `+=`/`*` do in a release build, instead of panicking in debug builds.
///
/// The inner scan uses a plain `while let` with an early `break`. That is
/// equivalent to `while let Some(&v_id) = current && v_id < q_id`, but it
/// does not need the `let_chains` feature.
#[no_mangle]
pub fn kinda_sort_merge(query_ids: &[u32], value_ids: &[u32]) -> u32 {
    let mut result: u32 = 0;
    let mut v_iter = value_ids.iter();
    // `None` here means `value_ids` is empty.
    let mut current = v_iter.next();
    for &q_id in query_ids.iter() {
        // Skip values smaller than the current query.
        while let Some(&v_id) = current {
            if v_id >= q_id {
                break;
            }
            current = v_iter.next();
        }
        match current {
            Some(&v_id) => result = result.wrapping_add(v_id.wrapping_mul(q_id)),
            // `value_ids` is exhausted: no later query can be matched.
            None => break,
        }
        // After the first iteration, `current` is guaranteed to be `Some` at
        // the start of every iteration: a `None` always breaks out above.
    }
    result
}
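You would expect the hot loop of the resulting assembly to contain only two possible jumps to the return path: one for reaching the end of `query_ids`, and one for reaching the end of `value_ids`. Once the first iteration has run, `current` is always `Some` at the top of the loop, because a `None` breaks out of it. However, there are three such jumps: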
; Compiler output for the example above. Assumes the SysV x86-64 calling convention:
; rdi/rsi = query_ids ptr/len, rdx/rcx = value_ids ptr/len.
kinda_sort_merge:
test rsi, rsi
je .LBB0_8 ; if query_ids.len() == 0, return 0
lea r8, [rdx + 4*rcx] ; r8 = end pointer of value_ids
lea rsi, [rdi + 4*rsi] ; rsi = end pointer of query_ids
xor eax, eax
test rcx, rcx
setne al
lea r9, [rdx + 4*rax] ; first v_iter.next(): advance one element only if value_ids is non-empty
cmove rdx, rcx ; rdx = current as a pointer; null (None) when value_ids is empty
xor ecx, ecx ; ecx = result = 0
.LBB0_2: ; Hot loop (one iteration per query id)
test rdx, rdx
je .LBB0_10 ; 1, check for v_iter before while loop (current == None); redundant after the first iteration
mov r10d, dword ptr [rdi] ; r10d = q_id
add rdi, 4
mov rax, r9
.LBB0_4: ; inner while loop: skip values < q_id
mov r9, rax
mov eax, dword ptr [rdx] ; eax = v_id = *current
cmp eax, r10d
jae .LBB0_6 ; v_id >= q_id: leave the while loop
lea rax, [r9 + 4]
mov rdx, r9 ; current = v_iter.next()
cmp r9, r8
jne .LBB0_4
jmp .LBB0_10 ; 2, check for v_iter during while loop (value_ids exhausted)
.LBB0_6:
imul eax, r10d
add eax, ecx
mov ecx, eax ; result += v_id * q_id (eax also holds result)
cmp rdi, rsi
jne .LBB0_2
jmp .LBB0_9 ; 3, check for query_ids (query_ids exhausted; eax = result)
.LBB0_10:
mov eax, ecx ; return result
ret
.LBB0_8:
xor eax, eax ; return 0
.LBB0_9:
ret
By adding `assert_unchecked()` hints that tell the optimizer whether `current` is `Some` or `None`:
#![feature(let_chains)]
use std::hint::assert_unchecked;
/// Sums `value * query` over `query_ids`, pairing each query with the next
/// value in `value_ids` (scanning forward only) that is not smaller than it.
///
/// The value cursor never moves backwards and is kept across queries, so a
/// single value can be paired with several consecutive queries. The scan stops
/// as soon as `value_ids` is exhausted; later queries contribute nothing.
/// Presumably both slices are sorted ascending — this is not checked.
///
/// The two `assert_unchecked` calls state loop invariants about `current`.
/// They are intended to let the optimizer drop the redundant
/// "is `current` still `Some`?" test at the top of the outer loop.
///
/// Arithmetic wraps on overflow in every build profile. This matches what
/// `+=`/`*` do in a release build, instead of panicking in debug builds.
#[no_mangle]
pub fn kinda_sort_merge(query_ids: &[u32], value_ids: &[u32]) -> u32 {
    let mut result: u32 = 0;
    let mut v_iter = value_ids.iter();
    let mut current = v_iter.next();
    // Fixed for the whole call: whether `value_ids` has at least one element.
    let has_values = current.is_some();
    for &q_id in query_ids.iter() {
        if has_values {
            // SAFETY: `current` starts as `Some` when `value_ids` is non-empty.
            // Every later iteration is only reached through the `Some` arm of
            // the `match` below, because a `None` breaks out of the loop.
            unsafe { assert_unchecked(current.is_some()) }
        }
        // Skip values smaller than the current query. Equivalent to
        // `while let Some(&v_id) = current && v_id < q_id`, without needing
        // the `let_chains` feature.
        while let Some(&v_id) = current {
            if v_id >= q_id {
                break;
            }
            current = v_iter.next();
        }
        if !has_values {
            // SAFETY: `value_ids` is empty, so `current` started as `None`.
            // The scan above only assigns `current` while it is `Some`, so it
            // is still `None`.
            unsafe { assert_unchecked(current.is_none()) }
        }
        match current {
            Some(&v_id) => result = result.wrapping_add(v_id.wrapping_mul(q_id)),
            // `value_ids` is exhausted: no later query can be matched.
            None => break,
        }
    }
    result
}
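The assembly improves: the hot loop now has only the two expected exits, and the emptiness checks for both slices are done once, before the loop: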
; Compiler output for the example with assert_unchecked. Assumes the SysV x86-64 calling convention:
; rdi/rsi = query_ids ptr/len, rdx/rcx = value_ids ptr/len.
kinda_sort_merge:
test rcx, rcx
sete r8b ; r8b = (value_ids is empty)
test rsi, rsi
sete r9b ; r9b = (query_ids is empty)
xor eax, eax ; eax = 0, the return value for the early exit
or r9b, r8b
jne .LBB0_7 ; Return 0 early if either slice is empty
lea rcx, [rdx + 4*rcx] ; rcx = end pointer of value_ids
lea rsi, [rdi + 4*rsi] ; rsi = end pointer of query_ids
lea r9, [rdx + 4] ; value_ids is known non-empty here: no null (None) handling for current
xor r8d, r8d ; r8d = result = 0
.LBB0_2: ; Hot loop: no current == None check at the top any more
mov r10d, dword ptr [rdi] ; r10d = q_id
add rdi, 4
mov rax, r9
.LBB0_3: ; inner while loop: skip values < q_id
mov r9, rax
mov eax, dword ptr [rdx] ; eax = v_id = *current
cmp eax, r10d
jae .LBB0_4 ; v_id >= q_id: leave the while loop
lea rax, [r9 + 4]
mov rdx, r9 ; current = v_iter.next()
cmp r9, rcx
jne .LBB0_3
jmp .LBB0_6 ; 1, check for v_iter during while loop (value_ids exhausted)
.LBB0_4:
imul eax, r10d
add eax, r8d
mov r8d, eax ; result += v_id * q_id (eax also holds result)
cmp rdi, rsi
jne .LBB0_2
jmp .LBB0_7 ; 2, check for query_ids (rdi is the query pointer; eax = result)
.LBB0_6:
mov eax, r8d ; return result
.LBB0_7:
ret