Description
This case was found in #80463. Replacing some infallible code with code that needed to propagate `Result`s led to some performance issues. It looked like the `Result` propagation code was less efficient than it could be. The code below approximately reproduces the original issue.
use std::fs::File;
use std::io::Write;
/// Error type returned by the reproduction functions `bar` and `foo`.
///
/// The two variants carry integer payloads of different widths.
#[derive(Debug)]
enum Error {
    /// Variant with a 32-bit unsigned payload.
    A(u32),
    /// Variant with a 64-bit unsigned payload.
    B(u64),
}
/// Writes a single empty line to standard output.
///
/// `#[inline(never)]` asks the compiler not to inline this function into
/// its callers.
#[inline(never)]
fn println() {
    println!();
}
#[inline(never)]
fn bar() -> Result<(), Error> {
let mut file = File::create("x.bin").unwrap();
let result = file.write_all(&[1]);
if let Err(err) = result {
if err.to_string().len() == 0 {
return Err(Error::A(10))
}
}
return Ok(())
}
/// Calls `bar`, propagating any `Err` to the caller via `?`; on success it
/// calls `println` and returns `Ok(())`.
///
/// This is the function whose generated assembly the report examines.
/// `#[inline(never)]` asks the compiler not to inline it into `main`.
#[inline(never)]
fn foo() -> Result<(), Error> {
// `?` returns early with the error when `bar` fails.
bar()?;
println();
Ok(())
}
/// Entry point: runs `foo` and panics if it returned an `Err`.
fn main() {
    foo().unwrap();
}
The assembly for `foo` (release build for x86_64 Linux target) is what we're interested in:
<x::foo>:
push %rbx ; We'll be using rbx, and it's callee-save, so save it
callq <x::bar> ; Call bar
mov %rax,%rbx ; Preserve rax before println
cmp $0x2,%ebx ; Did bar return Ok?
/-- jne ; If not, take jump
| callq <x::println> ; Otherwise, call println
/--|-- jmp ; And jump
| \-> mov %rbx,%rcx ; Silliness ensues
| shrd $0x20,%rdx,%rcx
| shr $0x20,%rdx
\----> shld $0x20,%rcx,%rdx
shl $0x20,%rcx
mov %ebx,%eax
or %rcx,%rax ; rax and rdx now contain Result to return
pop %rbx ; Restore rbx
retq ; Return
I believe this could optimize down to something like this instead:
<x::foo>:
push %rcx ; Ensure 16 byte stack alignment for calls
callq <x::bar>
cmp $0x2,%eax
/-- jne
| callq <x::println>
| mov $0x2,%eax
\-> pop %rcx
retq
When `bar` returns an `Err`, all we need to do is propagate the 16 byte `Result`, already contained in `rax` and `rdx`, which is exactly where `foo` needs to return it from.
When `bar` succeeds, all we need to do is call `println` and then return `Result::Ok` (set the appropriate enum tag in `rax`).
You can step through the original assembly and see that this all happens, but that it's done in a really convoluted way:
- In the `bar` returns `Err` case, the shifting, etc. is such that `rax` and `rdx` contain the same `Err` value returned by `bar`, as desired. But `rax` and `rdx` already contained those values after the `bar` call, so all of that shifting, etc. was extra work.
- In the `bar` returns `Ok` case, we do set the enum tag to that of the `Ok` variant by way of `mov %ebx, %eax`, but we also modify `rdx` needlessly (in the `Err(value)` case, `rdx` contains the `value`, but it's not needed in the `Ok` case), and we modify `rcx` and `or` it into `rax` needlessly (`rcx` contains junk in this case, but `or`ing it in doesn't affect the tag value, since the lower bits of `rcx` have been shifted out). All of this could have been replaced with a `mov $0x2,%eax`.
LLVM is either missing the optimization or we are not providing enough information to enable it to do the optimization.
LLVM IR after optimization passes
; main::foo
; Function Attrs: noinline nonlazybind uwtable
define internal fastcc i128 @_ZN4main3foo17h790eccd96863c176E() unnamed_addr #0 {
start:
; call main::bar
%0 = tail call fastcc i128 @_ZN4main3bar17h12f571ad2d6bd47aE()
%.sroa.020.0.extract.trunc = trunc i128 %0 to i32
%1 = icmp eq i32 %.sroa.020.0.extract.trunc, 2
br i1 %1, label %bb3, label %bb5
bb3: ; preds = %start
; call main::println
tail call fastcc void @_ZN4main7println17h4b182729f4c1c974E()
br label %bb9
bb5: ; preds = %start
%.sroa.4.0.extract.shift33 = lshr i128 %0, 32
%.sroa.4.0.extract.trunc34 = trunc i128 %.sroa.4.0.extract.shift33 to i96
br label %bb9
bb9: ; preds = %bb3, %bb5
%.sroa.3.sroa.0.0 = phi i96 [ undef, %bb3 ], [ %.sroa.4.0.extract.trunc34, %bb5 ]
%.sroa.3.0.insert.ext = zext i96 %.sroa.3.sroa.0.0 to i128
%.sroa.3.0.insert.shift = shl nuw i128 %.sroa.3.0.insert.ext, 32
%.sroa.0.0.insert.ext = and i128 %0, 4294967295
%.sroa.0.0.insert.insert = or i128 %.sroa.3.0.insert.shift, %.sroa.0.0.insert.ext
ret i128 %.sroa.0.0.insert.insert
}
LLVM IR before optimization passes
; main::foo
; Function Attrs: noinline nonlazybind uwtable
define internal i128 @_ZN4main3foo17h790eccd96863c176E() unnamed_addr #2 {
start:
%0 = alloca i128, align 8
%1 = alloca i128, align 8
%2 = alloca i128, align 8
%3 = alloca i128, align 8
%_6 = alloca %Error, align 8
%_5 = alloca %Error, align 8
%err = alloca %Error, align 8
%_2 = alloca %"std::result::Result<(), Error>", align 8
%_1 = alloca %"std::result::Result<(), Error>", align 8
%4 = alloca %"std::result::Result<(), Error>", align 8
%5 = bitcast %"std::result::Result<(), Error>"* %_1 to i8*
call void @llvm.lifetime.start.p0i8(i64 16, i8* %5)
%6 = bitcast %"std::result::Result<(), Error>"* %_2 to i8*
call void @llvm.lifetime.start.p0i8(i64 16, i8* %6)
; call main::bar
%7 = call i128 @_ZN4main3bar17h12f571ad2d6bd47aE()
%8 = bitcast i128* %3 to i8*
call void @llvm.lifetime.start.p0i8(i64 16, i8* %8)
store i128 %7, i128* %3, align 8
%9 = bitcast %"std::result::Result<(), Error>"* %_2 to i8*
%10 = bitcast i128* %3 to i8*
call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %9, i8* align 8 %10, i64 16, i1 false)
%11 = bitcast i128* %3 to i8*
call void @llvm.lifetime.end.p0i8(i64 16, i8* %11)
br label %bb1
bb1: ; preds = %start
; call <core::result::Result<T,E> as core::ops::try::Try>::into_result
%12 = call i128 @"_ZN73_$LT$core..result..Result$LT$T$C$E$GT$$u20$as$u20$core..ops..try..Try$GT$11into_result17hef5a3592992010b5E"(%"std::result::Result<(), Error>"* noalias nocapture dereferenceable(16) %_2)
%13 = bitcast i128* %2 to i8*
call void @llvm.lifetime.start.p0i8(i64 16, i8* %13)
store i128 %12, i128* %2, align 8
%14 = bitcast %"std::result::Result<(), Error>"* %_1 to i8*
%15 = bitcast i128* %2 to i8*
call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %14, i8* align 8 %15, i64 16, i1 false)
%16 = bitcast i128* %2 to i8*
call void @llvm.lifetime.end.p0i8(i64 16, i8* %16)
br label %bb2
bb2: ; preds = %bb1
%17 = bitcast %"std::result::Result<(), Error>"* %_2 to i8*
call void @llvm.lifetime.end.p0i8(i64 16, i8* %17)
%18 = bitcast %"std::result::Result<(), Error>"* %_1 to i32*
%19 = load i32, i32* %18, align 8, !range !11
%20 = sub i32 %19, 2
%21 = icmp eq i32 %20, 0
%_3 = select i1 %21, i64 0, i64 1
switch i64 %_3, label %bb4 [
i64 0, label %bb3
i64 1, label %bb5
]
bb3: ; preds = %bb2
%22 = bitcast %"std::result::Result<(), Error>"* %_1 to i8*
call void @llvm.lifetime.end.p0i8(i64 16, i8* %22)
; call main::println
call void @_ZN4main7println17h4b182729f4c1c974E()
br label %bb8
bb4: ; preds = %bb2
unreachable
bb5: ; preds = %bb2
%23 = bitcast %Error* %err to i8*
call void @llvm.lifetime.start.p0i8(i64 16, i8* %23)
%24 = bitcast %"std::result::Result<(), Error>"* %_1 to %"std::result::Result<(), Error>::Err"*
%25 = bitcast %"std::result::Result<(), Error>::Err"* %24 to %Error*
%26 = bitcast %Error* %err to i8*
%27 = bitcast %Error* %25 to i8*
call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %26, i8* align 8 %27, i64 16, i1 false)
%28 = bitcast %Error* %_5 to i8*
call void @llvm.lifetime.start.p0i8(i64 16, i8* %28)
%29 = bitcast %Error* %_6 to i8*
call void @llvm.lifetime.start.p0i8(i64 16, i8* %29)
%30 = bitcast %Error* %_6 to i8*
%31 = bitcast %Error* %err to i8*
call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %30, i8* align 8 %31, i64 16, i1 false)
; call <T as core::convert::From<T>>::from
%32 = call i128 @"_ZN50_$LT$T$u20$as$u20$core..convert..From$LT$T$GT$$GT$4from17h9c2936a7ad0bb0a3E"(%Error* noalias nocapture dereferenceable(16) %_6)
%33 = bitcast i128* %1 to i8*
call void @llvm.lifetime.start.p0i8(i64 16, i8* %33)
store i128 %32, i128* %1, align 8
%34 = bitcast %Error* %_5 to i8*
%35 = bitcast i128* %1 to i8*
call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %34, i8* align 8 %35, i64 16, i1 false)
%36 = bitcast i128* %1 to i8*
call void @llvm.lifetime.end.p0i8(i64 16, i8* %36)
br label %bb6
bb6: ; preds = %bb5
%37 = bitcast %Error* %_6 to i8*
call void @llvm.lifetime.end.p0i8(i64 16, i8* %37)
; call <core::result::Result<T,E> as core::ops::try::Try>::from_error
%38 = call i128 @"_ZN73_$LT$core..result..Result$LT$T$C$E$GT$$u20$as$u20$core..ops..try..Try$GT$10from_error17hc0c7c23bfd8ab5f4E"(%Error* noalias nocapture dereferenceable(16) %_5)
%39 = bitcast i128* %0 to i8*
call void @llvm.lifetime.start.p0i8(i64 16, i8* %39)
store i128 %38, i128* %0, align 8
%40 = bitcast %"std::result::Result<(), Error>"* %4 to i8*
%41 = bitcast i128* %0 to i8*
call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %40, i8* align 8 %41, i64 16, i1 false)
%42 = bitcast i128* %0 to i8*
call void @llvm.lifetime.end.p0i8(i64 16, i8* %42)
br label %bb7
bb7: ; preds = %bb6
%43 = bitcast %Error* %_5 to i8*
call void @llvm.lifetime.end.p0i8(i64 16, i8* %43)
%44 = bitcast %Error* %err to i8*
call void @llvm.lifetime.end.p0i8(i64 16, i8* %44)
%45 = bitcast %"std::result::Result<(), Error>"* %_1 to i8*
call void @llvm.lifetime.end.p0i8(i64 16, i8* %45)
br label %bb9
bb8: ; preds = %bb3
%46 = bitcast %"std::result::Result<(), Error>"* %4 to %"std::result::Result<(), Error>::Ok"*
%47 = bitcast %"std::result::Result<(), Error>::Ok"* %46 to {}*
%48 = bitcast %"std::result::Result<(), Error>"* %4 to i32*
store i32 2, i32* %48, align 8
br label %bb9
bb9: ; preds = %bb8, %bb7
%49 = bitcast %"std::result::Result<(), Error>"* %4 to i128*
%50 = load i128, i128* %49, align 8
ret i128 %50
}
Also, replacing `?` with manual propagation makes a difference, though it still produces less than optimal code:
/// Variant of `foo` that propagates the error by hand instead of using `?`:
/// it returns `bar`'s result unchanged when that result is an `Err`, and
/// otherwise calls `println` and returns `Ok(())`.
#[inline(never)]
fn foo() -> Result<(), Error> {
let r = bar();
// Hand the original `Result` straight back on failure.
if r.is_err() {
return r;
}
println();
Ok(())
}
<x::foo>:
push %r14
push %rbx
push %rax
callq <x::bar>
mov %rax,%rbx
mov %rdx,%r14
cmp $0x2,%ebx
/-- jne <x::foo+0x19>
| callq <x::println>
\-> mov %rbx,%rax
mov %r14,%rdx
add $0x8,%rsp
pop %rbx
pop %r14
retq
rustc 1.49.0 (e1884a8e3 2020-12-29)
binary: rustc
commit-hash: e1884a8e3c3e813aada8254edfa120e85bf5ffca
commit-date: 2020-12-29
host: x86_64-unknown-linux-gnu
release: 1.49.0
@rustbot label T-compiler A-LLVM I-slow