Skip to content

Commit bf38dfa

Browse files
committed
Fix quite subtle but nasty bug in linear map tuple types computation:
we need the lowered type of the branch trace enum in order to compute the linear map tuple type. However, the lowering of the branch trace enum type depends on the types of its elements (the payloads are the linear map tuples of predecessor BBs). As lowered types are cached, we cannot populate branch trace enum entries at the end as we did before: by then we have already used wrong lowered types for linear map tuples. Traverse basic blocks in reverse post-order, building linear map tuples and branch tracing enums in one go, ensuring that we are done with predecessor BBs before processing the BB itself.
1 parent 118fd86 commit bf38dfa

File tree

2 files changed

+60
-14
lines changed

2 files changed

+60
-14
lines changed

lib/SILOptimizer/Differentiation/LinearMapInfo.cpp

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -142,10 +142,12 @@ void LinearMapInfo::populateBranchingTraceDecl(SILBasicBlock *originalBB,
142142
heapAllocatedContext = true;
143143
decl->setInterfaceType(astCtx.TheRawPointerType);
144144
} else { // Otherwise the payload is the linear map tuple.
145-
auto linearMapStructTy = getLinearMapTupleType(predBB)->getCanonicalType();
145+
auto *linearMapStructTy = getLinearMapTupleType(predBB);
146+
assert(linearMapStructTy && "must have linear map struct type for predecessor BB");
147+
auto canLinearMapStructTy = linearMapStructTy->getCanonicalType();
146148
decl->setInterfaceType(
147-
linearMapStructTy->hasArchetype()
148-
? linearMapStructTy->mapTypeOutOfContext() : linearMapStructTy);
149+
canLinearMapStructTy->hasArchetype()
150+
? canLinearMapStructTy->mapTypeOutOfContext() : canLinearMapStructTy);
149151
}
150152
// Create enum element and enum case declarations.
151153
auto *paramList = ParameterList::create(astCtx, {decl});
@@ -331,10 +333,24 @@ void LinearMapInfo::generateDifferentiationDataStructures(
331333
}
332334

333335
// Add linear map fields to the linear map tuples.
334-
for (auto &origBB : *original) {
336+
// Now we need to be very careful as we're having a very subtle
337+
// chicken-and-egg problem. We need lowered branch trace enum type for the
338+
// linear map tuple type. However its type lowering depends on its elements
339+
// (at very least, the type classification of being trivial / non-trivial). As
340+
// the lowering is cached we need to ensure we compute lowered type for branch
341+
// trace enum when the corresponding EnumDecl is complete: we cannot add more
342+
// entries without causing some very subtle issues later on. However, the
343+
// elements of the enum are linear map tuples of predecessors. Traverse all
344+
// BBs in reverse post-order traversal order to ensure we process each BB
345+
// after its predecessors.
346+
llvm::ReversePostOrderTraversal<SILFunction *> RPOT(original);
347+
for (auto Iter = RPOT.begin(), E = RPOT.end(); Iter != E; ++Iter) {
348+
auto *origBB = *Iter;
335349
SmallVector<TupleTypeElt, 4> linearTupleTypes;
336-
if (!origBB.isEntry()) {
337-
CanType traceEnumType = getBranchingTraceEnumLoweredType(&origBB).getASTType();
350+
if (!origBB->isEntry()) {
351+
populateBranchingTraceDecl(origBB, loopInfo);
352+
353+
CanType traceEnumType = getBranchingTraceEnumLoweredType(origBB).getASTType();
338354
linearTupleTypes.emplace_back(traceEnumType,
339355
astCtx.getIdentifier(traceEnumFieldName));
340356
}
@@ -343,7 +359,7 @@ void LinearMapInfo::generateDifferentiationDataStructures(
343359
// Do not add linear map fields for semantic member accessors, which have
344360
// special-case pullback generation. Linear map tuples should be empty.
345361
} else {
346-
for (auto &inst : origBB) {
362+
for (auto &inst : *origBB) {
347363
if (auto *ai = dyn_cast<ApplyInst>(&inst)) {
348364
// Add linear map field to struct for active `apply` instructions.
349365
// Skip array literal intrinsic applications since array literal
@@ -363,12 +379,9 @@ void LinearMapInfo::generateDifferentiationDataStructures(
363379
}
364380
}
365381

366-
linearMapTuples.insert({&origBB, TupleType::get(linearTupleTypes, astCtx)});
382+
linearMapTuples.insert({origBB, TupleType::get(linearTupleTypes, astCtx)});
367383
}
368384

369-
for (auto &origBB : *original)
370-
populateBranchingTraceDecl(&origBB, loopInfo);
371-
372385
// Print generated linear map structs and branching trace enums.
373386
// These declarations do not show up with `-emit-sil` because they are
374387
// implicit. Instead, use `-Xllvm -debug-only=differentiation` to test

test/AutoDiff/SILOptimizer/differentiation_control_flow_sil.swift

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,15 @@ func cond(_ x: Float) -> Float {
5656
// CHECK-SIL: [[BB3_PRED_PRED2:%.*]] = enum $_AD__cond_bb3__Pred__src_0_wrt_0, #_AD__cond_bb3__Pred__src_0_wrt_0.bb2!enumelt, [[BB2_PB_STRUCT]]
5757
// CHECK-SIL: br bb3({{.*}} : $Float, [[BB3_PRED_PRED2]] : $_AD__cond_bb3__Pred__src_0_wrt_0)
5858

59-
// CHECK-SIL: bb3([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : $_AD__cond_bb3__Pred__src_0_wrt_0)
59+
// CHECK-SIL: bb3([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : @owned $_AD__cond_bb3__Pred__src_0_wrt_0)
6060
// CHECK-SIL: [[PULLBACK_REF:%.*]] = function_ref @condTJpSpSr
6161
// CHECK-SIL: [[PB:%.*]] = partial_apply [callee_guaranteed] [[PULLBACK_REF]]([[BB3_PRED_ARG]])
6262
// CHECK-SIL: [[VJP_RESULT:%.*]] = tuple ([[ORIG_RES]] : $Float, [[PB]] : $@callee_guaranteed (Float) -> Float)
6363
// CHECK-SIL: return [[VJP_RESULT]]
6464

6565

6666
// CHECK-SIL-LABEL: sil private [ossa] @condTJpSpSr : $@convention(thin) (Float, @owned _AD__cond_bb3__Pred__src_0_wrt_0) -> Float {
67-
// CHECK-SIL: bb0([[SEED:%.*]] : $Float, [[BB3_PRED:%.*]] : $_AD__cond_bb3__Pred__src_0_wrt_0):
67+
// CHECK-SIL: bb0([[SEED:%.*]] : $Float, [[BB3_PRED:%.*]] : @owned $_AD__cond_bb3__Pred__src_0_wrt_0):
6868
// CHECK-SIL: switch_enum [[BB3_PRED]] : $_AD__cond_bb3__Pred__src_0_wrt_0, case #_AD__cond_bb3__Pred__src_0_wrt_0.bb2!enumelt: bb1, case #_AD__cond_bb3__Pred__src_0_wrt_0.bb1!enumelt: bb3
6969

7070
// CHECK-SIL: bb1([[BB3_PRED2_TRAMP_PB_STRUCT:%.*]] : @owned $(predecessor: _AD__cond_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> (Float, Float))):
@@ -132,6 +132,39 @@ func loop_generic<T : Differentiable & FloatingPoint>(_ x: T) -> T {
132132
return result
133133
}
134134

135+
@differentiable(reverse)
136+
@_silgen_name("loop_context")
137+
func loop_context(x: Float) -> Float {
138+
let y = x + 1
139+
for _ in 0 ..< 1 {}
140+
return y
141+
}
142+
143+
// CHECK-DATA-STRUCTURES-LABEL: Generated linear map tuples and branching trace enums for @loop_context:
144+
// CHECK-DATA-STRUCTURES: (_: (Float) -> Float)
145+
// CHECK-DATA-STRUCTURES: (predecessor: _AD__loop_context_bb1__Pred__src_0_wrt_0)
146+
// CHECK-DATA-STRUCTURES: (predecessor: _AD__loop_context_bb2__Pred__src_0_wrt_0)
147+
// CHECK-DATA-STRUCTURES: (predecessor: _AD__loop_context_bb3__Pred__src_0_wrt_0)
148+
// CHECK-DATA-STRUCTURES: enum _AD__loop_context_bb0__Pred__src_0_wrt_0 {
149+
// CHECK-DATA-STRUCTURES: }
150+
// CHECK-DATA-STRUCTURES: enum _AD__loop_context_bb1__Pred__src_0_wrt_0 {
151+
// CHECK-DATA-STRUCTURES: case bb2(Builtin.RawPointer)
152+
// CHECK-DATA-STRUCTURES: case bb0((_: (Float) -> Float))
153+
// CHECK-DATA-STRUCTURES: }
154+
// CHECK-DATA-STRUCTURES: enum _AD__loop_context_bb2__Pred__src_0_wrt_0 {
155+
// CHECK-DATA-STRUCTURES: case bb1(Builtin.RawPointer)
156+
// CHECK-DATA-STRUCTURES: }
157+
// CHECK-DATA-STRUCTURES: enum _AD__loop_context_bb3__Pred__src_0_wrt_0 {
158+
// CHECK-DATA-STRUCTURES: case bb1(Builtin.RawPointer)
159+
// CHECK-DATA-STRUCTURES: }
160+
161+
// CHECK-SIL-LABEL: sil private [ossa] @loop_contextTJpSpSr : $@convention(thin) (Float, @guaranteed Builtin.NativeObject) -> Float {
162+
// CHECK-SIL: bb1([[LOOP_CONTEXT:%.*]] : $Builtin.RawPointer):
163+
// CHECK-SIL: [[PB_TUPLE_ADDR:%.*]] = pointer_to_address [[LOOP_CONTEXT]] : $Builtin.RawPointer to [strict] $*(predecessor: _AD__loop_context_bb1__Pred__src_0_wrt_0)
164+
// CHECK-SIL: [[PB_TUPLE_CPY:%.*]] = load [copy] [[PB_TUPLE_ADDR]] : $*(predecessor: _AD__loop_context_bb1__Pred__src_0_wrt_0)
165+
// CHECK-SIL: br bb3({{.*}} : $Float, {{.*}} : $Float, [[PB_TUPLE_CPY]] : $(predecessor: _AD__loop_context_bb1__Pred__src_0_wrt_0))
166+
// CHECK-SIL: bb3({{.*}} : $Float, {{.*}} : $Float, {{.*}} : @owned $(predecessor: _AD__loop_context_bb1__Pred__src_0_wrt_0)):
167+
135168
// Test `switch_enum`.
136169

137170
enum Enum {
@@ -164,7 +197,7 @@ func enum_notactive(_ e: Enum, _ x: Float) -> Float {
164197
// CHECK-SIL: [[BB3_PRED_PRED2:%.*]] = enum $_AD__enum_notactive_bb3__Pred__src_0_wrt_1, #_AD__enum_notactive_bb3__Pred__src_0_wrt_1.bb2!enumelt, [[BB2_PB_STRUCT]] : $(predecessor: _AD__enum_notactive_bb2__Pred__src_0_wrt_1, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float)
165198
// CHECK-SIL: br bb3({{.*}} : $Float, [[BB3_PRED_PRED2]] : $_AD__enum_notactive_bb3__Pred__src_0_wrt_1)
166199

167-
// CHECK-SIL: bb3([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : $_AD__enum_notactive_bb3__Pred__src_0_wrt_1)
200+
// CHECK-SIL: bb3([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : @owned $_AD__enum_notactive_bb3__Pred__src_0_wrt_1)
168201
// CHECK-SIL: [[PULLBACK_REF:%.*]] = function_ref @enum_notactiveTJpUSpSr
169202
// CHECK-SIL: [[PB:%.*]] = partial_apply [callee_guaranteed] [[PULLBACK_REF]]([[BB3_PRED_ARG]])
170203
// CHECK-SIL: [[VJP_RESULT:%.*]] = tuple ([[ORIG_RES]] : $Float, [[PB]] : $@callee_guaranteed (Float) -> Float)

0 commit comments

Comments
 (0)