Skip to content

[5.10][AutoDiff] Fix linear map tuple types computation #69499

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 28 additions & 11 deletions lib/SILOptimizer/Differentiation/LinearMapInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,12 @@ void LinearMapInfo::populateBranchingTraceDecl(SILBasicBlock *originalBB,
heapAllocatedContext = true;
decl->setInterfaceType(astCtx.TheRawPointerType);
} else { // Otherwise the payload is the linear map tuple.
auto linearMapStructTy = getLinearMapTupleType(predBB)->getCanonicalType();
auto *linearMapStructTy = getLinearMapTupleType(predBB);
assert(linearMapStructTy && "must have linear map struct type for predecessor BB");
auto canLinearMapStructTy = linearMapStructTy->getCanonicalType();
decl->setInterfaceType(
linearMapStructTy->hasArchetype()
? linearMapStructTy->mapTypeOutOfContext() : linearMapStructTy);
canLinearMapStructTy->hasArchetype()
? canLinearMapStructTy->mapTypeOutOfContext() : canLinearMapStructTy);
}
// Create enum element and enum case declarations.
auto *paramList = ParameterList::create(astCtx, {decl});
Expand Down Expand Up @@ -331,10 +333,28 @@ void LinearMapInfo::generateDifferentiationDataStructures(
}

// Add linear map fields to the linear map tuples.
for (auto &origBB : *original) {
//
// Now we need to be very careful as we're having a very subtle
// chicken-and-egg problem. We need the lowered branch trace enum type for the
// linear map tuple type. However, branch trace enum type lowering depends on
// the lowering of its elements (at the very least, the type classification of
// being trivial / non-trivial). As the lowering is cached we need to ensure
// we compute lowered type for the branch trace enum when the corresponding
// EnumDecl is fully complete: we cannot add more entries without causing some
// very subtle issues later on. However, the elements of the enum are linear
// map tuples of predecessors, that correspondingly may contain branch trace
// enums of corresponding predecessor BBs.
//
// Traverse all BBs in reverse post-order to ensure that (loop back edges
// aside) we process each BB after its predecessors, so the predecessors'
// linear map tuple types are already available.
llvm::ReversePostOrderTraversal<SILFunction *> RPOT(original);
for (auto Iter = RPOT.begin(), E = RPOT.end(); Iter != E; ++Iter) {
auto *origBB = *Iter;
SmallVector<TupleTypeElt, 4> linearTupleTypes;
if (!origBB.isEntry()) {
CanType traceEnumType = getBranchingTraceEnumLoweredType(&origBB).getASTType();
if (!origBB->isEntry()) {
populateBranchingTraceDecl(origBB, loopInfo);

CanType traceEnumType = getBranchingTraceEnumLoweredType(origBB).getASTType();
linearTupleTypes.emplace_back(traceEnumType,
astCtx.getIdentifier(traceEnumFieldName));
}
Expand All @@ -343,7 +363,7 @@ void LinearMapInfo::generateDifferentiationDataStructures(
// Do not add linear map fields for semantic member accessors, which have
// special-case pullback generation. Linear map tuples should be empty.
} else {
for (auto &inst : origBB) {
for (auto &inst : *origBB) {
if (auto *ai = dyn_cast<ApplyInst>(&inst)) {
// Add linear map field to struct for active `apply` instructions.
// Skip array literal intrinsic applications since array literal
Expand All @@ -363,12 +383,9 @@ void LinearMapInfo::generateDifferentiationDataStructures(
}
}

linearMapTuples.insert({&origBB, TupleType::get(linearTupleTypes, astCtx)});
linearMapTuples.insert({origBB, TupleType::get(linearTupleTypes, astCtx)});
}

for (auto &origBB : *original)
populateBranchingTraceDecl(&origBB, loopInfo);

// Print generated linear map structs and branching trace enums.
// These declarations do not show up with `-emit-sil` because they are
// implicit. Instead, use `-Xllvm -debug-only=differentiation` to test
Expand Down
39 changes: 36 additions & 3 deletions test/AutoDiff/SILOptimizer/differentiation_control_flow_sil.swift
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,15 @@ func cond(_ x: Float) -> Float {
// CHECK-SIL: [[BB3_PRED_PRED2:%.*]] = enum $_AD__cond_bb3__Pred__src_0_wrt_0, #_AD__cond_bb3__Pred__src_0_wrt_0.bb2!enumelt, [[BB2_PB_STRUCT]]
// CHECK-SIL: br bb3({{.*}} : $Float, [[BB3_PRED_PRED2]] : $_AD__cond_bb3__Pred__src_0_wrt_0)

// CHECK-SIL: bb3([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : $_AD__cond_bb3__Pred__src_0_wrt_0)
// CHECK-SIL: bb3([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : @owned $_AD__cond_bb3__Pred__src_0_wrt_0)
// CHECK-SIL: [[PULLBACK_REF:%.*]] = function_ref @condTJpSpSr
// CHECK-SIL: [[PB:%.*]] = partial_apply [callee_guaranteed] [[PULLBACK_REF]]([[BB3_PRED_ARG]])
// CHECK-SIL: [[VJP_RESULT:%.*]] = tuple ([[ORIG_RES]] : $Float, [[PB]] : $@callee_guaranteed (Float) -> Float)
// CHECK-SIL: return [[VJP_RESULT]]


// CHECK-SIL-LABEL: sil private [ossa] @condTJpSpSr : $@convention(thin) (Float, @owned _AD__cond_bb3__Pred__src_0_wrt_0) -> Float {
// CHECK-SIL: bb0([[SEED:%.*]] : $Float, [[BB3_PRED:%.*]] : $_AD__cond_bb3__Pred__src_0_wrt_0):
// CHECK-SIL: bb0([[SEED:%.*]] : $Float, [[BB3_PRED:%.*]] : @owned $_AD__cond_bb3__Pred__src_0_wrt_0):
// CHECK-SIL: switch_enum [[BB3_PRED]] : $_AD__cond_bb3__Pred__src_0_wrt_0, case #_AD__cond_bb3__Pred__src_0_wrt_0.bb2!enumelt: bb1, case #_AD__cond_bb3__Pred__src_0_wrt_0.bb1!enumelt: bb3

// CHECK-SIL: bb1([[BB3_PRED2_TRAMP_PB_STRUCT:%.*]] : @owned $(predecessor: _AD__cond_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> (Float, Float))):
Expand Down Expand Up @@ -132,6 +132,39 @@ func loop_generic<T : Differentiable & FloatingPoint>(_ x: T) -> T {
return result
}

// Test fixture: a reverse-mode differentiable function containing a loop.
// `y` is computed from `x` before the loop and returned after it; the loop
// itself iterates over `0 ..< 1` with an empty body, so it contributes
// control flow only. The fixed `@_silgen_name` gives the symbol name that the
// FileCheck expectations following this function refer to.
// NOTE(review): presumably the empty loop exists solely to force loop-style
// branching trace handling in the generated derivative; confirm against the
// expectations below before changing the body.
@differentiable(reverse)
@_silgen_name("loop_context")
func loop_context(x: Float) -> Float {
// Value computed ahead of the loop and returned unchanged after it.
let y = x + 1
// Empty-bodied loop over a single-element range.
for _ in 0 ..< 1 {}
return y
}

// CHECK-DATA-STRUCTURES-LABEL: Generated linear map tuples and branching trace enums for @loop_context:
// CHECK-DATA-STRUCTURES: (_: (Float) -> Float)
// CHECK-DATA-STRUCTURES: (predecessor: _AD__loop_context_bb1__Pred__src_0_wrt_0)
// CHECK-DATA-STRUCTURES: (predecessor: _AD__loop_context_bb2__Pred__src_0_wrt_0)
// CHECK-DATA-STRUCTURES: (predecessor: _AD__loop_context_bb3__Pred__src_0_wrt_0)
// CHECK-DATA-STRUCTURES: enum _AD__loop_context_bb0__Pred__src_0_wrt_0 {
// CHECK-DATA-STRUCTURES: }
// CHECK-DATA-STRUCTURES: enum _AD__loop_context_bb1__Pred__src_0_wrt_0 {
// CHECK-DATA-STRUCTURES: case bb2(Builtin.RawPointer)
// CHECK-DATA-STRUCTURES: case bb0((_: (Float) -> Float))
// CHECK-DATA-STRUCTURES: }
// CHECK-DATA-STRUCTURES: enum _AD__loop_context_bb2__Pred__src_0_wrt_0 {
// CHECK-DATA-STRUCTURES: case bb1(Builtin.RawPointer)
// CHECK-DATA-STRUCTURES: }
// CHECK-DATA-STRUCTURES: enum _AD__loop_context_bb3__Pred__src_0_wrt_0 {
// CHECK-DATA-STRUCTURES: case bb1(Builtin.RawPointer)
// CHECK-DATA-STRUCTURES: }

// CHECK-SIL-LABEL: sil private [ossa] @loop_contextTJpSpSr : $@convention(thin) (Float, @guaranteed Builtin.NativeObject) -> Float {
// CHECK-SIL: bb1([[LOOP_CONTEXT:%.*]] : $Builtin.RawPointer):
// CHECK-SIL: [[PB_TUPLE_ADDR:%.*]] = pointer_to_address [[LOOP_CONTEXT]] : $Builtin.RawPointer to [strict] $*(predecessor: _AD__loop_context_bb1__Pred__src_0_wrt_0)
// CHECK-SIL: [[PB_TUPLE_CPY:%.*]] = load [copy] [[PB_TUPLE_ADDR]] : $*(predecessor: _AD__loop_context_bb1__Pred__src_0_wrt_0)
// CHECK-SIL: br bb3({{.*}} : $Float, {{.*}} : $Float, [[PB_TUPLE_CPY]] : $(predecessor: _AD__loop_context_bb1__Pred__src_0_wrt_0))
// CHECK-SIL: bb3({{.*}} : $Float, {{.*}} : $Float, {{.*}} : @owned $(predecessor: _AD__loop_context_bb1__Pred__src_0_wrt_0)):

// Test `switch_enum`.

enum Enum {
Expand Down Expand Up @@ -164,7 +197,7 @@ func enum_notactive(_ e: Enum, _ x: Float) -> Float {
// CHECK-SIL: [[BB3_PRED_PRED2:%.*]] = enum $_AD__enum_notactive_bb3__Pred__src_0_wrt_1, #_AD__enum_notactive_bb3__Pred__src_0_wrt_1.bb2!enumelt, [[BB2_PB_STRUCT]] : $(predecessor: _AD__enum_notactive_bb2__Pred__src_0_wrt_1, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float)
// CHECK-SIL: br bb3({{.*}} : $Float, [[BB3_PRED_PRED2]] : $_AD__enum_notactive_bb3__Pred__src_0_wrt_1)

// CHECK-SIL: bb3([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : $_AD__enum_notactive_bb3__Pred__src_0_wrt_1)
// CHECK-SIL: bb3([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : @owned $_AD__enum_notactive_bb3__Pred__src_0_wrt_1)
// CHECK-SIL: [[PULLBACK_REF:%.*]] = function_ref @enum_notactiveTJpUSpSr
// CHECK-SIL: [[PB:%.*]] = partial_apply [callee_guaranteed] [[PULLBACK_REF]]([[BB3_PRED_ARG]])
// CHECK-SIL: [[VJP_RESULT:%.*]] = tuple ([[ORIG_RES]] : $Float, [[PB]] : $@callee_guaranteed (Float) -> Float)
Expand Down