Skip to content

Investigate possible optimization opportunities for autodiff code with control flow #68901

Open
@asl

Description

@asl

Recent autodiff changes that removed use of pullback structs in favor of direct pullback application to nested calls enabled lots of possible inlining and specialization opportunities in VJPs / pullbacks.

However, the situation is not as good for functions with control flow: the generated code with enums / tuples effectively prevents inlining and function substitution, because the forward path becomes quite "opaque".

Consider the following small example:

import _Differentiation
import Darwin

/// Example function with control flow, marked differentiable in reverse mode.
///
/// - Parameter x: The input value.
/// - Returns: `sin(x) * cos(x)` when `x > 0`, otherwise `sin(x) + cos(x)`.
@differentiable(reverse)
func f(_ x: Float) -> Float {
  // This `if` is the control flow under discussion: each branch makes its own
  // set of differentiable calls (`sin`, `cos`, then `*` or `+`).
  if (x > 0) {
    return sin(x) * cos(x)
  } else {
    return sin(x) + cos(x)
  }
}

/// Evaluates the reverse-mode gradient of `f` at the constant input `4`.
///
/// - Returns: The derivative of `f` with respect to its argument at `4`.
@inline(never)
func foo() -> Float {
  // Pass `f` itself: `gradient(at:of:)` expects the differentiable function,
  // not the result of calling it (and there is no `x` in scope here).
  gradient(at: Float(4), of: f)
}

We're currently generating:

// foo()
sil hidden [noinline] @$s6sincos3fooSfyF : $@convention(thin) () -> Float {
[global: read,write,copy,destroy,allocate,deinit_barrier]
bb0:
  %0 = float_literal $Builtin.FPIEEE32, 0x40800000 // 4 // users: %4, %8, %12, %1
  %1 = struct $Float (%0 : $Builtin.FPIEEE32)     // users: %25, %23, %15, %11, %2
  debug_value %1 : $Float, let, name "x", argno 1 // id: %2
  %3 = float_literal $Builtin.FPIEEE32, 0x0 // 0  // user: %4
  %4 = builtin "fcmp_olt_FPIEEE32"(%3 : $Builtin.FPIEEE32, %0 : $Builtin.FPIEEE32) : $Builtin.Int1 // user: %6
  %5 = tuple ()                                   // users: %21, %7
  cond_br %4, bb1, bb2                            // id: %6

bb1:                                              // Preds: bb0
  %7 = enum $_AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, #_AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0.bb0!enumelt, %5 : $() // user: %18
  %8 = builtin "int_sin_FPIEEE32"(%0 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %9
  %9 = struct $Float (%8 : $Builtin.FPIEEE32)     // user: %17
  // function_ref closure #1 in _vjpSin(_:)
  %10 = function_ref @$s16_Differentiation7_vjpSinySf5value_S2fc8pullbacktSfFS2fcfU_ : $@convention(thin) (Float, Float) -> Float // user: %11
  %11 = partial_apply [callee_guaranteed] %10(%1) : $@convention(thin) (Float, Float) -> Float // user: %18
  %12 = builtin "int_cos_FPIEEE32"(%0 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %13
  %13 = struct $Float (%12 : $Builtin.FPIEEE32)   // user: %17
  // function_ref closure #1 in _vjpCos(_:)
  %14 = function_ref @$s16_Differentiation7_vjpCosySf5value_S2fc8pullbacktSfFS2fcfU_ : $@convention(thin) (Float, Float) -> Float // user: %15
  %15 = partial_apply [callee_guaranteed] %14(%1) : $@convention(thin) (Float, Float) -> Float // user: %18
  // function_ref closure #1 in static Float._vjpMultiply(lhs:rhs:)
  %16 = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float) // user: %17
  %17 = partial_apply [callee_guaranteed] %16(%13, %9) : $@convention(thin) (Float, Float, Float) -> (Float, Float) // user: %18
  %18 = tuple $(predecessor: _AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)) (%7, %11, %15, %17) // user: %19
  %19 = enum $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0, #_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0.bb1!enumelt, %18 : $(predecessor: _AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)) // user: %20
  br bb3(%19 : $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) // id: %20

bb2:                                              // Preds: bb0
  %21 = enum $_AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, #_AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0.bb0!enumelt, %5 : $() // user: %28
  // function_ref closure #1 in _vjpSin(_:)
  %22 = function_ref @$s16_Differentiation7_vjpSinySf5value_S2fc8pullbacktSfFS2fcfU_ : $@convention(thin) (Float, Float) -> Float // user: %23
  %23 = partial_apply [callee_guaranteed] %22(%1) : $@convention(thin) (Float, Float) -> Float // user: %28
  // function_ref closure #1 in _vjpCos(_:)
  %24 = function_ref @$s16_Differentiation7_vjpCosySf5value_S2fc8pullbacktSfFS2fcfU_ : $@convention(thin) (Float, Float) -> Float // user: %25
  %25 = partial_apply [callee_guaranteed] %24(%1) : $@convention(thin) (Float, Float) -> Float // user: %28
  // function_ref closure #1 in static Float._vjpAdd(lhs:rhs:)
  %26 = function_ref @$sSf16_DifferentiationE7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float) -> (Float, Float) // user: %27
  %27 = thin_to_thick_function %26 : $@convention(thin) (Float) -> (Float, Float) to $@callee_guaranteed (Float) -> (Float, Float) // user: %28
  %28 = tuple $(predecessor: _AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)) (%21, %23, %25, %27) // user: %29
  %29 = enum $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0, #_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0.bb2!enumelt, %28 : $(predecessor: _AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)) // user: %30
  br bb3(%29 : $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) // id: %30

// %31                                            // user: %36
bb3(%31 : $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0): // Preds: bb1 bb2
  // function_ref pullback of f(_:)
  %32 = function_ref @$s6sincos1fyS2fFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) -> Float // user: %36
  %33 = integer_literal $Builtin.Int64, 1         // user: %34
  %34 = builtin "sitofp_Int64_FPIEEE32"(%33 : $Builtin.Int64) : $Builtin.FPIEEE32 // user: %35
  %35 = struct $Float (%34 : $Builtin.FPIEEE32)   // user: %36
  %36 = apply %32(%35, %31) : $@convention(thin) (Float, @owned _AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) -> Float // user: %37
  return %36 : $Float                             // id: %37
} // end sil function '$s6sincos3fooSfyF'

// pullback of f(_:)
sil private @$s6sincos1fyS2fFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) -> Float {
[%0: read v**.c*.v**, write v**.c*.v**, copy v**.c*.v**, destroy v**.c*.v**]
[%1: noescape v**, read v**.c*.v**, write v**.c*.v**, copy v**.c*.v**, destroy v**.c*.v**]
[global: read,write,copy,destroy,allocate,deinit_barrier]
// %0                                             // users: %35, %10
// %1                                             // user: %5
bb0(%0 : $Float, %1 : $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0):
  %2 = integer_literal $Builtin.Int64, 0          // user: %3
  %3 = builtin "sitofp_Int64_FPIEEE32"(%2 : $Builtin.Int64) : $Builtin.FPIEEE32 // users: %43, %40, %18, %15, %4, %48, %23
  debug_value %3 : $Builtin.FPIEEE32, let, name "x", argno 1, type $Float, expr op_fragment:#Float._value // id: %4
  switch_enum %1 : $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0, case #_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0.bb2!enumelt: bb1, case #_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0.bb1!enumelt: bb2 // id: %5

// %6                                             // users: %9, %8, %7
bb1(%6 : $(predecessor: _AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float))): // Preds: bb0
  %7 = tuple_extract %6 : $(predecessor: _AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)), 1 // users: %25, %24
  %8 = tuple_extract %6 : $(predecessor: _AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)), 2 // users: %21, %20
  %9 = tuple_extract %6 : $(predecessor: _AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)), 3 // users: %11, %10
  %10 = apply %9(%0) : $@callee_guaranteed (Float) -> (Float, Float) // users: %13, %12
  strong_release %9 : $@callee_guaranteed (Float) -> (Float, Float) // id: %11
  %12 = tuple_extract %10 : $(Float, Float), 0    // user: %14
  %13 = tuple_extract %10 : $(Float, Float), 1    // user: %17
  %14 = struct_extract %12 : $Float, #Float._value // user: %15
  %15 = builtin "fadd_FPIEEE32"(%14 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %16
  %16 = struct $Float (%15 : $Builtin.FPIEEE32)   // user: %24
  %17 = struct_extract %13 : $Float, #Float._value // user: %18
  %18 = builtin "fadd_FPIEEE32"(%17 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %19
  %19 = struct $Float (%18 : $Builtin.FPIEEE32)   // user: %20
  %20 = apply %8(%19) : $@callee_guaranteed (Float) -> Float // user: %22
  strong_release %8 : $@callee_guaranteed (Float) -> Float // id: %21
  %22 = struct_extract %20 : $Float, #Float._value // user: %23
  %23 = builtin "fadd_FPIEEE32"(%22 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %27
  %24 = apply %7(%16) : $@callee_guaranteed (Float) -> Float // user: %26
  strong_release %7 : $@callee_guaranteed (Float) -> Float // id: %25
  %26 = struct_extract %24 : $Float, #Float._value // user: %27
  %27 = builtin "fadd_FPIEEE32"(%26 : $Builtin.FPIEEE32, %23 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %28
  %28 = struct $Float (%27 : $Builtin.FPIEEE32)   // users: %30, %29
  debug_value %28 : $Float, let, name "x", argno 1 // id: %29
  br bb3(%28 : $Float)                            // id: %30

// %31                                            // users: %34, %33, %32
bb2(%31 : $(predecessor: _AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float))): // Preds: bb0
  %32 = tuple_extract %31 : $(predecessor: _AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)), 1 // users: %50, %49
  %33 = tuple_extract %31 : $(predecessor: _AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)), 2 // users: %46, %45
  %34 = tuple_extract %31 : $(predecessor: _AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)), 3 // users: %36, %35
  %35 = apply %34(%0) : $@callee_guaranteed (Float) -> (Float, Float) // users: %38, %37
  strong_release %34 : $@callee_guaranteed (Float) -> (Float, Float) // id: %36
  %37 = tuple_extract %35 : $(Float, Float), 0    // user: %39
  %38 = tuple_extract %35 : $(Float, Float), 1    // user: %42
  %39 = struct_extract %37 : $Float, #Float._value // user: %40
  %40 = builtin "fadd_FPIEEE32"(%39 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %41
  %41 = struct $Float (%40 : $Builtin.FPIEEE32)   // user: %49
  %42 = struct_extract %38 : $Float, #Float._value // user: %43
  %43 = builtin "fadd_FPIEEE32"(%42 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %44
  %44 = struct $Float (%43 : $Builtin.FPIEEE32)   // user: %45
  %45 = apply %33(%44) : $@callee_guaranteed (Float) -> Float // user: %47
  strong_release %33 : $@callee_guaranteed (Float) -> Float // id: %46
  %47 = struct_extract %45 : $Float, #Float._value // user: %48
  %48 = builtin "fadd_FPIEEE32"(%47 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %52
  %49 = apply %32(%41) : $@callee_guaranteed (Float) -> Float // user: %51
  strong_release %32 : $@callee_guaranteed (Float) -> Float // id: %50
  %51 = struct_extract %49 : $Float, #Float._value // user: %52
  %52 = builtin "fadd_FPIEEE32"(%51 : $Builtin.FPIEEE32, %48 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %53
  %53 = struct $Float (%52 : $Builtin.FPIEEE32)   // users: %55, %54
  debug_value %53 : $Float, let, name "x", argno 1 // id: %54
  br bb3(%53 : $Float)                            // id: %55

// %56                                            // users: %58, %57
bb3(%56 : $Float):                                // Preds: bb1 bb2
  debug_value %56 : $Float, let, name "x", argno 1 // id: %57
  return %56 : $Float                             // id: %58
} // end sil function '$s6sincos1fyS2fFTJpSpSr'

It would be great to find ways to simplify this further.

One immediate improvement would be better constant propagation: the condition in the function is obviously known to be true here. (This optimization is performed later, at the LLVM IR level; however, we still have 3 closures there and therefore 3 context allocations.)

However, this certainly won't help in the common case:

/// Evaluates the reverse-mode gradient of `f` at a caller-supplied point.
///
/// - Parameter x: The point at which to differentiate `f`.
/// - Returns: The derivative of `f` with respect to its argument at `x`.
@inline(never)
func foo(_ x: Float) -> Float {
  return gradient(at: x, of: f)
}

It looks like the apply of the pullback closure was correctly turned into an apply of the pullback itself. However, it was not inlined further, and it would be interesting to understand why. And even if the pullback were inlined, could we do some kind of CSE and optimize the diamond-shaped CFG by hoisting the code into the corresponding BBs?

Metadata

Metadata

Assignees

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions