Description
Recent autodiff changes that removed use of pullback structs in favor of direct pullback application to nested calls enabled lots of possible inlining and specialization opportunities in VJPs / pullbacks.
However, the situation is not that great for functions with control flow: the generated code with enums / tuples effectively prevents inlining and function substitution, as the forward path becomes quite "opaque".
Consider the following small example:
import _Differentiation
import Darwin
/// Piecewise example function: `sin(x) * cos(x)` for positive `x`,
/// `sin(x) + cos(x)` otherwise. The branch gives the function control flow.
@differentiable(reverse)
func f(_ x: Float) -> Float {
    if x > 0 {
        return sin(x) * cos(x)
    } else {
        return sin(x) + cos(x)
    }
}
/// Computes the gradient of `f` at the constant point 4.
/// `@inline(never)` keeps this function from being inlined into its callers.
@inline(never)
func foo() -> Float {
    // Pass the function itself: `gradient(at:of:)` expects a function value,
    // and there is no `x` in scope here to call `f` with.
    gradient(at: Float(4), of: f)
}
We're currently generating:
// foo()
sil hidden [noinline] @$s6sincos3fooSfyF : $@convention(thin) () -> Float {
[global: read,write,copy,destroy,allocate,deinit_barrier]
bb0:
%0 = float_literal $Builtin.FPIEEE32, 0x40800000 // 4 // users: %4, %8, %12, %1
%1 = struct $Float (%0 : $Builtin.FPIEEE32) // users: %25, %23, %15, %11, %2
debug_value %1 : $Float, let, name "x", argno 1 // id: %2
%3 = float_literal $Builtin.FPIEEE32, 0x0 // 0 // user: %4
%4 = builtin "fcmp_olt_FPIEEE32"(%3 : $Builtin.FPIEEE32, %0 : $Builtin.FPIEEE32) : $Builtin.Int1 // user: %6
%5 = tuple () // users: %21, %7
cond_br %4, bb1, bb2 // id: %6
bb1: // Preds: bb0
%7 = enum $_AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, #_AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0.bb0!enumelt, %5 : $() // user: %18
%8 = builtin "int_sin_FPIEEE32"(%0 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %9
%9 = struct $Float (%8 : $Builtin.FPIEEE32) // user: %17
// function_ref closure #1 in _vjpSin(_:)
%10 = function_ref @$s16_Differentiation7_vjpSinySf5value_S2fc8pullbacktSfFS2fcfU_ : $@convention(thin) (Float, Float) -> Float // user: %11
%11 = partial_apply [callee_guaranteed] %10(%1) : $@convention(thin) (Float, Float) -> Float // user: %18
%12 = builtin "int_cos_FPIEEE32"(%0 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %13
%13 = struct $Float (%12 : $Builtin.FPIEEE32) // user: %17
// function_ref closure #1 in _vjpCos(_:)
%14 = function_ref @$s16_Differentiation7_vjpCosySf5value_S2fc8pullbacktSfFS2fcfU_ : $@convention(thin) (Float, Float) -> Float // user: %15
%15 = partial_apply [callee_guaranteed] %14(%1) : $@convention(thin) (Float, Float) -> Float // user: %18
// function_ref closure #1 in static Float._vjpMultiply(lhs:rhs:)
%16 = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float) // user: %17
%17 = partial_apply [callee_guaranteed] %16(%13, %9) : $@convention(thin) (Float, Float, Float) -> (Float, Float) // user: %18
%18 = tuple $(predecessor: _AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)) (%7, %11, %15, %17) // user: %19
%19 = enum $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0, #_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0.bb1!enumelt, %18 : $(predecessor: _AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)) // user: %20
br bb3(%19 : $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) // id: %20
bb2: // Preds: bb0
%21 = enum $_AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, #_AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0.bb0!enumelt, %5 : $() // user: %28
// function_ref closure #1 in _vjpSin(_:)
%22 = function_ref @$s16_Differentiation7_vjpSinySf5value_S2fc8pullbacktSfFS2fcfU_ : $@convention(thin) (Float, Float) -> Float // user: %23
%23 = partial_apply [callee_guaranteed] %22(%1) : $@convention(thin) (Float, Float) -> Float // user: %28
// function_ref closure #1 in _vjpCos(_:)
%24 = function_ref @$s16_Differentiation7_vjpCosySf5value_S2fc8pullbacktSfFS2fcfU_ : $@convention(thin) (Float, Float) -> Float // user: %25
%25 = partial_apply [callee_guaranteed] %24(%1) : $@convention(thin) (Float, Float) -> Float // user: %28
// function_ref closure #1 in static Float._vjpAdd(lhs:rhs:)
%26 = function_ref @$sSf16_DifferentiationE7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float) -> (Float, Float) // user: %27
%27 = thin_to_thick_function %26 : $@convention(thin) (Float) -> (Float, Float) to $@callee_guaranteed (Float) -> (Float, Float) // user: %28
%28 = tuple $(predecessor: _AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)) (%21, %23, %25, %27) // user: %29
%29 = enum $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0, #_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0.bb2!enumelt, %28 : $(predecessor: _AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)) // user: %30
br bb3(%29 : $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) // id: %30
// %31 // user: %36
bb3(%31 : $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0): // Preds: bb1 bb2
// function_ref pullback of f(_:)
%32 = function_ref @$s6sincos1fyS2fFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) -> Float // user: %36
%33 = integer_literal $Builtin.Int64, 1 // user: %34
%34 = builtin "sitofp_Int64_FPIEEE32"(%33 : $Builtin.Int64) : $Builtin.FPIEEE32 // user: %35
%35 = struct $Float (%34 : $Builtin.FPIEEE32) // user: %36
%36 = apply %32(%35, %31) : $@convention(thin) (Float, @owned _AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) -> Float // user: %37
return %36 : $Float // id: %37
} // end sil function '$s6sincos3fooSfyF'
// pullback of f(_:)
sil private @$s6sincos1fyS2fFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) -> Float {
[%0: read v**.c*.v**, write v**.c*.v**, copy v**.c*.v**, destroy v**.c*.v**]
[%1: noescape v**, read v**.c*.v**, write v**.c*.v**, copy v**.c*.v**, destroy v**.c*.v**]
[global: read,write,copy,destroy,allocate,deinit_barrier]
// %0 // users: %35, %10
// %1 // user: %5
bb0(%0 : $Float, %1 : $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0):
%2 = integer_literal $Builtin.Int64, 0 // user: %3
%3 = builtin "sitofp_Int64_FPIEEE32"(%2 : $Builtin.Int64) : $Builtin.FPIEEE32 // users: %43, %40, %18, %15, %4, %48, %23
debug_value %3 : $Builtin.FPIEEE32, let, name "x", argno 1, type $Float, expr op_fragment:#Float._value // id: %4
switch_enum %1 : $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0, case #_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0.bb2!enumelt: bb1, case #_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0.bb1!enumelt: bb2 // id: %5
// %6 // users: %9, %8, %7
bb1(%6 : $(predecessor: _AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float))): // Preds: bb0
%7 = tuple_extract %6 : $(predecessor: _AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)), 1 // users: %25, %24
%8 = tuple_extract %6 : $(predecessor: _AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)), 2 // users: %21, %20
%9 = tuple_extract %6 : $(predecessor: _AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)), 3 // users: %11, %10
%10 = apply %9(%0) : $@callee_guaranteed (Float) -> (Float, Float) // users: %13, %12
strong_release %9 : $@callee_guaranteed (Float) -> (Float, Float) // id: %11
%12 = tuple_extract %10 : $(Float, Float), 0 // user: %14
%13 = tuple_extract %10 : $(Float, Float), 1 // user: %17
%14 = struct_extract %12 : $Float, #Float._value // user: %15
%15 = builtin "fadd_FPIEEE32"(%14 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %16
%16 = struct $Float (%15 : $Builtin.FPIEEE32) // user: %24
%17 = struct_extract %13 : $Float, #Float._value // user: %18
%18 = builtin "fadd_FPIEEE32"(%17 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %19
%19 = struct $Float (%18 : $Builtin.FPIEEE32) // user: %20
%20 = apply %8(%19) : $@callee_guaranteed (Float) -> Float // user: %22
strong_release %8 : $@callee_guaranteed (Float) -> Float // id: %21
%22 = struct_extract %20 : $Float, #Float._value // user: %23
%23 = builtin "fadd_FPIEEE32"(%22 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %27
%24 = apply %7(%16) : $@callee_guaranteed (Float) -> Float // user: %26
strong_release %7 : $@callee_guaranteed (Float) -> Float // id: %25
%26 = struct_extract %24 : $Float, #Float._value // user: %27
%27 = builtin "fadd_FPIEEE32"(%26 : $Builtin.FPIEEE32, %23 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %28
%28 = struct $Float (%27 : $Builtin.FPIEEE32) // users: %30, %29
debug_value %28 : $Float, let, name "x", argno 1 // id: %29
br bb3(%28 : $Float) // id: %30
// %31 // users: %34, %33, %32
bb2(%31 : $(predecessor: _AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float))): // Preds: bb0
%32 = tuple_extract %31 : $(predecessor: _AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)), 1 // users: %50, %49
%33 = tuple_extract %31 : $(predecessor: _AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)), 2 // users: %46, %45
%34 = tuple_extract %31 : $(predecessor: _AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)), 3 // users: %36, %35
%35 = apply %34(%0) : $@callee_guaranteed (Float) -> (Float, Float) // users: %38, %37
strong_release %34 : $@callee_guaranteed (Float) -> (Float, Float) // id: %36
%37 = tuple_extract %35 : $(Float, Float), 0 // user: %39
%38 = tuple_extract %35 : $(Float, Float), 1 // user: %42
%39 = struct_extract %37 : $Float, #Float._value // user: %40
%40 = builtin "fadd_FPIEEE32"(%39 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %41
%41 = struct $Float (%40 : $Builtin.FPIEEE32) // user: %49
%42 = struct_extract %38 : $Float, #Float._value // user: %43
%43 = builtin "fadd_FPIEEE32"(%42 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %44
%44 = struct $Float (%43 : $Builtin.FPIEEE32) // user: %45
%45 = apply %33(%44) : $@callee_guaranteed (Float) -> Float // user: %47
strong_release %33 : $@callee_guaranteed (Float) -> Float // id: %46
%47 = struct_extract %45 : $Float, #Float._value // user: %48
%48 = builtin "fadd_FPIEEE32"(%47 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %52
%49 = apply %32(%41) : $@callee_guaranteed (Float) -> Float // user: %51
strong_release %32 : $@callee_guaranteed (Float) -> Float // id: %50
%51 = struct_extract %49 : $Float, #Float._value // user: %52
%52 = builtin "fadd_FPIEEE32"(%51 : $Builtin.FPIEEE32, %48 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %53
%53 = struct $Float (%52 : $Builtin.FPIEEE32) // users: %55, %54
debug_value %53 : $Float, let, name "x", argno 1 // id: %54
br bb3(%53 : $Float) // id: %55
// %56 // users: %58, %57
bb3(%56 : $Float): // Preds: bb1 bb2
debug_value %56 : $Float, let, name "x", argno 1 // id: %57
return %56 : $Float // id: %58
} // end sil function '$s6sincos1fyS2fFTJpSpSr'
It would be great to find ways to simplify this further.
One immediate thing is to improve constant propagation: obviously, the condition in the function is known to be true here (this optimization is performed later, at the LLVM IR level; however, we still have 3 closures there and therefore 3 context allocations for them).
Though, this certainly won't help the common case:
/// Computes the gradient of `f` at the caller-supplied point `x`.
/// `@inline(never)` keeps this function from being inlined into its callers.
@inline(never)
func foo(_ x: Float) -> Float {
    return gradient(at: x, of: f)
}
It looks like the `apply` of the pullback closure was correctly turned into an `apply` of the pullback itself. However, it was not inlined further. It would be interesting to understand why. And even if the pullback were inlined, could we do some kind of CSE and optimize the diamond-shaped CFG by hoisting the code into the corresponding BBs?