Skip to content

Generated Result propagation code is needlessly complex #81146

Open
@tgnottingham

Description

@tgnottingham

This case was found in #80463. Replacing some infallible code with code that needed to propagate Results led to some performance issues. It looked like the Result propagation code was less efficient than it could be. The code below approximately reproduces the original issue.

use std::fs::File;
use std::io::Write;

/// Error type propagated by `bar` and `foo` in this reproduction.
///
/// Both variants carry an integer payload; only `A` is ever constructed by
/// the code shown here.
#[derive(Debug)]
enum Error {
    A(u32),
    B(u64),
}

/// Prints an empty line to stdout.
///
/// Marked `#[inline(never)]` so that it stays an actual call inside `foo`.
#[inline(never)]
fn println() {
    println!();
}

/// Creates `x.bin`, writes a single byte to it, and reports the outcome.
///
/// Panics if the file cannot be created. A failed write becomes
/// `Err(Error::A(10))` only when the I/O error's message is empty; every
/// other outcome, including a write error with a non-empty message, yields
/// `Ok(())`.
///
/// Marked `#[inline(never)]` so that `foo` has to make a real call.
#[inline(never)]
fn bar() -> Result<(), Error> {
    let mut file = File::create("x.bin").unwrap();
    let result = file.write_all(&[1]);

    // Only a write error whose rendered message is empty is reported.
    if let Err(err) = result {
        if err.to_string().len() == 0 {
            return Err(Error::A(10))
        }
    }

    return Ok(())
}

/// Function whose generated code is under examination.
///
/// Calls `bar`, forwards its `Err` to the caller via `?`, and otherwise calls
/// `println` and returns `Ok(())`.
#[inline(never)]
fn foo() -> Result<(), Error> {
    // `?` returns early with the error produced by `bar`.
    bar()?;
    println();
    Ok(())
}

/// Entry point: runs `foo` and panics if it returns an `Err`.
fn main() {
    foo().unwrap();
}

The assembly for foo (release build for x86_64 Linux target) is what we're interested in:

<x::foo>:
	       push   %rbx              ; We'll be using rbx, and it's callee-save, so save it
	       callq  <x::bar>          ; Call bar
	       mov    %rax,%rbx         ; Preserve rax before println
	       cmp    $0x2,%ebx         ; Did bar return Ok?
	   /-- jne                      ; If not, take jump
	   |   callq  <x::println>      ; Otherwise, call println
	/--|-- jmp                      ; And jump
	|  \-> mov    %rbx,%rcx         ; Silliness ensues
	|      shrd   $0x20,%rdx,%rcx   ; rcx = bits 32..95 of the 128-bit value bar returned
	|      shr    $0x20,%rdx        ; rdx = bits 96..127 of that value
	\----> shld   $0x20,%rcx,%rdx   ; Both paths merge here; rdx = (rdx << 32) | top 32 bits of rcx
	       shl    $0x20,%rcx        ; rcx = its low 32 bits moved to the upper half, lower half zeroed
	       mov    %ebx,%eax         ; rax = low 32 bits of bar's rax (the tag), zero-extended
	       or     %rcx,%rax         ; rax and rdx now contain Result to return
	       pop    %rbx              ; Restore rbx
	       retq                     ; Return

I believe this could optimize down to something like this instead:

<x::foo>:
	       push   %rcx              ; Ensure 16 byte stack alignment for calls
	       callq  <x::bar>          ; Call bar; its Result comes back in rax and rdx
	       cmp    $0x2,%eax         ; Did bar return Ok?
	   /-- jne                      ; If not, return the Err exactly as received
	   |   callq  <x::println>      ; Otherwise, call println
	   |   mov    $0x2,%eax         ; And set the Ok tag in the return value
	   \-> pop    %rcx              ; Undo the alignment push
	       retq                     ; Return

When bar returns an Err, all we need to do is propagate the 16 byte Result, already contained in rax and rdx, which is exactly where foo needs to return it from.

When bar succeeds, all we need to do is call println and then return Result::Ok (set the appropriate enum tag in rax).

You can step through the original assembly and see that this all happens, but that it's done in a really convoluted way:

  • When bar returns Err, the shifting, etc. leaves rax and rdx holding the same Err value that bar returned, as desired. But rax and rdx already held that value right after the call to bar, so all of that shifting, etc. was extra work.

  • When bar returns Ok, we do set the enum tag to that of the Ok variant by way of mov %ebx,%eax, but we also modify rdx needlessly (in the Err(value) case, rdx contains the value, but it's not needed in the Ok case), and we modify rcx and OR it into rax needlessly (rcx contains junk in this case, but ORing it in doesn't affect the tag value, since the lower bits of rcx have been shifted out). All of this could have been replaced with a mov $0x2,%eax.

LLVM is either missing the optimization or we are not providing enough information to enable it to do the optimization.

LLVM IR after optimization passes
; main::foo
; Function Attrs: noinline nonlazybind uwtable
; The Result<(), Error> is passed around as a single i128 value.
define internal fastcc i128 @_ZN4main3foo17h790eccd96863c176E() unnamed_addr #0 {
start:
; call main::bar
  %0 = tail call fastcc i128 @_ZN4main3bar17h12f571ad2d6bd47aE()
; The low 32 bits are compared against 2; the println (Ok) path is taken on equality.
  %.sroa.020.0.extract.trunc = trunc i128 %0 to i32
  %1 = icmp eq i32 %.sroa.020.0.extract.trunc, 2
  br i1 %1, label %bb3, label %bb5

bb3:                                              ; preds = %start
; call main::println
  tail call fastcc void @_ZN4main7println17h4b182729f4c1c974E()
  br label %bb9

bb5:                                              ; preds = %start
; Err path: split off the upper 96 bits of bar's return value.
  %.sroa.4.0.extract.shift33 = lshr i128 %0, 32
  %.sroa.4.0.extract.trunc34 = trunc i128 %.sroa.4.0.extract.shift33 to i96
  br label %bb9

bb9:                                              ; preds = %bb3, %bb5
; The upper 96 bits are undef on the Ok path and bar's own bits on the Err path.
; The return value is then rebuilt as (upper96 << 32) | (low 32 bits of %0).
  %.sroa.3.sroa.0.0 = phi i96 [ undef, %bb3 ], [ %.sroa.4.0.extract.trunc34, %bb5 ]
  %.sroa.3.0.insert.ext = zext i96 %.sroa.3.sroa.0.0 to i128
  %.sroa.3.0.insert.shift = shl nuw i128 %.sroa.3.0.insert.ext, 32
  %.sroa.0.0.insert.ext = and i128 %0, 4294967295
  %.sroa.0.0.insert.insert = or i128 %.sroa.3.0.insert.shift, %.sroa.0.0.insert.ext
  ret i128 %.sroa.0.0.insert.insert
}
LLVM IR before optimization passes
; main::foo
; Function Attrs: noinline nonlazybind uwtable
; Unoptimized form: every intermediate value lives in its own stack slot and is
; moved between slots with 16-byte memcpy calls.
define internal i128 @_ZN4main3foo17h790eccd96863c176E() unnamed_addr #2 {
start:
  %0 = alloca i128, align 8
  %1 = alloca i128, align 8
  %2 = alloca i128, align 8
  %3 = alloca i128, align 8
  %_6 = alloca %Error, align 8
  %_5 = alloca %Error, align 8
  %err = alloca %Error, align 8
  %_2 = alloca %"std::result::Result<(), Error>", align 8
  %_1 = alloca %"std::result::Result<(), Error>", align 8
  %4 = alloca %"std::result::Result<(), Error>", align 8
  %5 = bitcast %"std::result::Result<(), Error>"* %_1 to i8*
  call void @llvm.lifetime.start.p0i8(i64 16, i8* %5)
  %6 = bitcast %"std::result::Result<(), Error>"* %_2 to i8*
  call void @llvm.lifetime.start.p0i8(i64 16, i8* %6)
; call main::bar
; The i128 result is stored to the temporary %3 and then copied into %_2.
  %7 = call i128 @_ZN4main3bar17h12f571ad2d6bd47aE()
  %8 = bitcast i128* %3 to i8*
  call void @llvm.lifetime.start.p0i8(i64 16, i8* %8)
  store i128 %7, i128* %3, align 8
  %9 = bitcast %"std::result::Result<(), Error>"* %_2 to i8*
  %10 = bitcast i128* %3 to i8*
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %9, i8* align 8 %10, i64 16, i1 false)
  %11 = bitcast i128* %3 to i8*
  call void @llvm.lifetime.end.p0i8(i64 16, i8* %11)
  br label %bb1

bb1:                                              ; preds = %start
; call <core::result::Result<T,E> as core::ops::try::Try>::into_result
; Its result is stored to the temporary %2 and then copied into %_1.
  %12 = call i128 @"_ZN73_$LT$core..result..Result$LT$T$C$E$GT$$u20$as$u20$core..ops..try..Try$GT$11into_result17hef5a3592992010b5E"(%"std::result::Result<(), Error>"* noalias nocapture dereferenceable(16) %_2)
  %13 = bitcast i128* %2 to i8*
  call void @llvm.lifetime.start.p0i8(i64 16, i8* %13)
  store i128 %12, i128* %2, align 8
  %14 = bitcast %"std::result::Result<(), Error>"* %_1 to i8*
  %15 = bitcast i128* %2 to i8*
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %14, i8* align 8 %15, i64 16, i1 false)
  %16 = bitcast i128* %2 to i8*
  call void @llvm.lifetime.end.p0i8(i64 16, i8* %16)
  br label %bb2

bb2:                                              ; preds = %bb1
; Load the leading i32 of %_1; %_3 is 0 when it equals 2 and 1 otherwise.
; 0 branches to bb3 (the println path), 1 to bb5 (the error path).
  %17 = bitcast %"std::result::Result<(), Error>"* %_2 to i8*
  call void @llvm.lifetime.end.p0i8(i64 16, i8* %17)
  %18 = bitcast %"std::result::Result<(), Error>"* %_1 to i32*
  %19 = load i32, i32* %18, align 8, !range !11
  %20 = sub i32 %19, 2
  %21 = icmp eq i32 %20, 0
  %_3 = select i1 %21, i64 0, i64 1
  switch i64 %_3, label %bb4 [
    i64 0, label %bb3
    i64 1, label %bb5
  ]

bb3:                                              ; preds = %bb2
  %22 = bitcast %"std::result::Result<(), Error>"* %_1 to i8*
  call void @llvm.lifetime.end.p0i8(i64 16, i8* %22)
; call main::println
  call void @_ZN4main7println17h4b182729f4c1c974E()
  br label %bb8

bb4:                                              ; preds = %bb2
; Default switch target; %_3 can only be 0 or 1.
  unreachable

bb5:                                              ; preds = %bb2
; Error path: copy the 16 bytes at %_1 into %err, then into %_6 as the
; argument for From::from; its result ends up in %_5.
  %23 = bitcast %Error* %err to i8*
  call void @llvm.lifetime.start.p0i8(i64 16, i8* %23)
  %24 = bitcast %"std::result::Result<(), Error>"* %_1 to %"std::result::Result<(), Error>::Err"*
  %25 = bitcast %"std::result::Result<(), Error>::Err"* %24 to %Error*
  %26 = bitcast %Error* %err to i8*
  %27 = bitcast %Error* %25 to i8*
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %26, i8* align 8 %27, i64 16, i1 false)
  %28 = bitcast %Error* %_5 to i8*
  call void @llvm.lifetime.start.p0i8(i64 16, i8* %28)
  %29 = bitcast %Error* %_6 to i8*
  call void @llvm.lifetime.start.p0i8(i64 16, i8* %29)
  %30 = bitcast %Error* %_6 to i8*
  %31 = bitcast %Error* %err to i8*
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %30, i8* align 8 %31, i64 16, i1 false)
; call <T as core::convert::From<T>>::from
  %32 = call i128 @"_ZN50_$LT$T$u20$as$u20$core..convert..From$LT$T$GT$$GT$4from17h9c2936a7ad0bb0a3E"(%Error* noalias nocapture dereferenceable(16) %_6)
  %33 = bitcast i128* %1 to i8*
  call void @llvm.lifetime.start.p0i8(i64 16, i8* %33)
  store i128 %32, i128* %1, align 8
  %34 = bitcast %Error* %_5 to i8*
  %35 = bitcast i128* %1 to i8*
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %34, i8* align 8 %35, i64 16, i1 false)
  %36 = bitcast i128* %1 to i8*
  call void @llvm.lifetime.end.p0i8(i64 16, i8* %36)
  br label %bb6

bb6:                                              ; preds = %bb5
  %37 = bitcast %Error* %_6 to i8*
  call void @llvm.lifetime.end.p0i8(i64 16, i8* %37)
; call <core::result::Result<T,E> as core::ops::try::Try>::from_error
; Its result is copied into %4, the slot that bb9 loads and returns.
  %38 = call i128 @"_ZN73_$LT$core..result..Result$LT$T$C$E$GT$$u20$as$u20$core..ops..try..Try$GT$10from_error17hc0c7c23bfd8ab5f4E"(%Error* noalias nocapture dereferenceable(16) %_5)
  %39 = bitcast i128* %0 to i8*
  call void @llvm.lifetime.start.p0i8(i64 16, i8* %39)
  store i128 %38, i128* %0, align 8
  %40 = bitcast %"std::result::Result<(), Error>"* %4 to i8*
  %41 = bitcast i128* %0 to i8*
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %40, i8* align 8 %41, i64 16, i1 false)
  %42 = bitcast i128* %0 to i8*
  call void @llvm.lifetime.end.p0i8(i64 16, i8* %42)
  br label %bb7

bb7:                                              ; preds = %bb6
; End the lifetimes of the error-path temporaries.
  %43 = bitcast %Error* %_5 to i8*
  call void @llvm.lifetime.end.p0i8(i64 16, i8* %43)
  %44 = bitcast %Error* %err to i8*
  call void @llvm.lifetime.end.p0i8(i64 16, i8* %44)
  %45 = bitcast %"std::result::Result<(), Error>"* %_1 to i8*
  call void @llvm.lifetime.end.p0i8(i64 16, i8* %45)
  br label %bb9

bb8:                                              ; preds = %bb3
; println path: only the leading i32 of the return slot %4 is written, with 2.
  %46 = bitcast %"std::result::Result<(), Error>"* %4 to %"std::result::Result<(), Error>::Ok"*
  %47 = bitcast %"std::result::Result<(), Error>::Ok"* %46 to {}*
  %48 = bitcast %"std::result::Result<(), Error>"* %4 to i32*
  store i32 2, i32* %48, align 8
  br label %bb9

bb9:                                              ; preds = %bb8, %bb7
; Load the return slot %4 as one i128 and return it.
  %49 = bitcast %"std::result::Result<(), Error>"* %4 to i128*
  %50 = load i128, i128* %49, align 8
  ret i128 %50
}

Also, replacing ? with manual propagation makes a difference, though it still produces less than optimal code:

/// Variant of `foo` that propagates the error by hand instead of using `?`.
///
/// Returns `bar`'s result unchanged when it is an `Err`; otherwise calls
/// `println` and returns `Ok(())`.
#[inline(never)]
fn foo() -> Result<(), Error> {
    let r = bar();

    // Hand the very same `Result` value back to the caller on failure.
    if r.is_err() {
        return r;
    }

    println();
    Ok(())
}
<x::foo>:
	    push   %r14            ; Save r14; it is used below
	    push   %rbx            ; Save rbx; it is used below
	    push   %rax            ; Take 8 more bytes of stack (presumably for call alignment)
	    callq  <x::bar>        ; Call bar
	    mov    %rax,%rbx       ; Keep bar's rax across the println call
	    mov    %rdx,%r14       ; Keep bar's rdx across the println call
	    cmp    $0x2,%ebx       ; Compare the low 32 bits of the result with 2
	/-- jne    <x::foo+0x19>   ; If different, skip println
	|   callq  <x::println>    ; Otherwise, call println
	\-> mov    %rbx,%rax       ; Return bar's rax unchanged
	    mov    %r14,%rdx       ; Return bar's rdx unchanged
	    add    $0x8,%rsp       ; Release the 8 bytes taken by push %rax
	    pop    %rbx            ; Restore rbx
	    pop    %r14            ; Restore r14
	    retq                   ; Return
rustc 1.49.0 (e1884a8e3 2020-12-29)
binary: rustc
commit-hash: e1884a8e3c3e813aada8254edfa120e85bf5ffca
commit-date: 2020-12-29
host: x86_64-unknown-linux-gnu
release: 1.49.0

@rustbot label T-compiler A-LLVM I-slow

Metadata

Metadata

Assignees

No one assigned

    Labels

    A-LLVM — Area: Code generation parts specific to LLVM. Both correctness bugs and optimization-related issues.
    A-result-option — Area: Result and Option combinators
    C-optimization — Category: An issue highlighting optimization opportunities or PRs implementing such
    I-slow — Issue: Problems and improvements with respect to performance of generated code.
    T-compiler — Relevant to the compiler team, which will review and decide on the PR/issue.

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions