From 1883ac162441c90d294aebbd076206a9fb1b9811 Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Mon, 19 Aug 2019 09:28:45 -0700 Subject: [PATCH] Make changes for ownership change. --- .../Mandatory/Differentiation.cpp | 182 ++++++++++++------ test/AutoDiff/forward_mode_sil.swift | 45 ++--- 2 files changed, 142 insertions(+), 85 deletions(-) diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 836a3a7dda87d..d64090c458a05 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -4005,6 +4005,12 @@ class JVPEmitter final /// Mapping from differential basic blocks to differential struct arguments. DenseMap differentialStructArguments; + /// Mapping from differential struct field declarations to differential struct + /// elements destructured from the linear map basic block argument. In the + /// beginning of each differential basic block, the block's differential struct is + /// destructured into individual elements stored here. + DenseMap differentialStructElements; + /// Mapping from original basic blocks and original values to corresponding /// tangent values. DenseMap tangentValueMap; @@ -4079,6 +4085,35 @@ class JVPEmitter final return SILBuilder(*differential); } + //--------------------------------------------------------------------------// + // Differential struct mapping + //--------------------------------------------------------------------------// + + void initializeDifferentialStructElements(SILBasicBlock *origBB, + SILInstructionResultArray values) { + auto *diffStructDecl = differentialInfo.getLinearMapStruct(origBB); + assert(diffStructDecl->getStoredProperties().size() == values.size() && + "The number of differential struct fields must equal the number of " + "differential struct element values"); + for (auto pair : llvm::zip(diffStructDecl->getStoredProperties(), values)) { + assert( + std::get<1>(pair).getOwnershipKind() != ValueOwnershipKind::Guaranteed + && "Differential struct elements must be @owned"); + auto insertion = differentialStructElements.insert({std::get<0>(pair), + std::get<1>(pair)}); + (void)insertion; + assert(insertion.second && "A differential struct element already exists!"); + } + } + + SILValue getDifferentialStructElement(SILBasicBlock *origBB, VarDecl *field) { + assert(differentialInfo.getLinearMapStruct(origBB) == + cast(field->getDeclContext())); + assert(differentialStructElements.count(field) && + "Differential struct element for this field does not exist!"); + return differentialStructElements.lookup(field); + } + //--------------------------------------------------------------------------// // General utilities //--------------------------------------------------------------------------// @@ -4209,8 +4244,8 @@ class JVPEmitter final type, ResilienceExpansion::Minimal); auto *buffer = diffBuilder.createAllocStack(loc, silType); emitZeroIndirect(type, buffer, loc); - auto *loaded = diffBuilder.createLoad( - loc, buffer, LoadOwnershipQualifier::Unqualified); + auto loaded = diffBuilder.emitLoadValueOperation( + loc, buffer, LoadOwnershipQualifier::Take); diffBuilder.createDeallocStack(loc, buffer); return loaded; } @@ -4256,38 +4291,12 @@ class JVPEmitter final } SILValue &getTangentBuffer(SILBasicBlock *origBB, SILValue originalBuffer) { - auto diffBuilder = getDifferentialBuilder(); assert(originalBuffer->getType().isAddress()); assert(originalBuffer->getFunction() == original); auto insertion = bufferMap.try_emplace({origBB, originalBuffer}, SILValue()); - if (!insertion.second) // not inserted. - return insertion.first->getSecond(); - - // Set insertion point for local allocation builder: before the last local - // allocation, or at the start of the tangent function's entry if no local - // allocations exist yet. - diffLocalAllocBuilder.setInsertionPoint( - getDifferential().getEntryBlock(), - getNextDifferentialLocalAllocationInsertionPoint()); - - // Allocate local buffer and initialize to zero. - auto bufObjectType = getRemappedTangentType(originalBuffer->getType()); - auto *newBuf = diffLocalAllocBuilder.createAllocStack( - originalBuffer.getLoc(), bufObjectType); - - // Temporarily change global builder insertion point and emit zero into the - // local buffer. - auto insertionPoint = diffLocalAllocBuilder.getInsertionBB(); - diffBuilder.setInsertionPoint( - diffLocalAllocBuilder.getInsertionBB(), - diffLocalAllocBuilder.getInsertionPoint()); - emitZeroIndirect(bufObjectType.getASTType(), newBuf, newBuf->getLoc()); - diffBuilder.setInsertionPoint(insertionPoint); - - // Create cleanup for local buffer. - differentialLocalAllocations.push_back(newBuf); - return (insertion.first->getSecond() = newBuf); + assert(!insertion.second && "tangent buffer should already exist"); + return insertion.first->getSecond(); } //--------------------------------------------------------------------------// @@ -4345,6 +4354,38 @@ class JVPEmitter final // Tangent emission helpers //--------------------------------------------------------------------------// + void emitTangentForDestroyValueInst(DestroyValueInst *dvi) { + auto &diffBuilder = getDifferentialBuilder(); + auto loc = dvi->getLoc(); + auto tanVal = materializeTangent(getTangentValue(dvi->getOperand()), loc); + diffBuilder.emitDestroyValue(loc, tanVal); + } + + void emitTangentForBeginBorrow(BeginBorrowInst *bbi) { + auto &diffBuilder = getDifferentialBuilder(); + auto loc = bbi->getLoc(); + auto tanVal = materializeTangent(getTangentValue(bbi->getOperand()), loc); + auto tanValBorrow = diffBuilder.emitBeginBorrowOperation(loc, tanVal); + setTangentValue(bbi->getParent(), bbi, + makeConcreteTangentValue(tanValBorrow)); + } + + void emitTangentForEndBorrow(EndBorrowInst *ebi) { + auto &diffBuilder = getDifferentialBuilder(); + auto loc = ebi->getLoc(); + auto tanVal = materializeTangent(getTangentValue(ebi->getOperand()), loc); + diffBuilder.emitEndBorrowOperation(loc, tanVal); + } + + void emitTangentForCopyValueInst(CopyValueInst *cvi) { + auto &diffBuilder = getDifferentialBuilder(); + auto tan = getTangentValue(cvi->getOperand()); + auto tanVal = materializeTangent(tan, cvi->getLoc()); + auto tanValCopy = diffBuilder.emitCopyValueOperation(cvi->getLoc(), tanVal); + setTangentValue(cvi->getParent(), cvi, + makeConcreteTangentValue(tanValCopy)); + } + void emitTangentForReturnInst(ReturnInst *ri) { auto loc = ri->getOperand().getLoc(); auto diffBuilder = getDifferentialBuilder(); @@ -4366,8 +4407,7 @@ class JVPEmitter final // Get the differential. auto *field = differentialInfo.lookUpLinearMapDecl(ai); assert(field); - SILValue differential = diffBuilder.createStructExtract( - loc, getDifferentialStructArgument(ai->getParent()), field); + SILValue differential = getDifferentialStructElement(bb, field); SmallVector diffArgs; for (auto origArg : ai->getArguments()) { @@ -4400,6 +4440,7 @@ class JVPEmitter final auto *differentialCall = diffBuilder.createApply( loc, differential, SubstitutionMap(), diffArgs, /*isNonThrowing*/ false); + diffBuilder.emitDestroyValueOperation(loc, differential); assert(differentialCall->getNumResults() == 1 && "Expected differential to return one result"); @@ -4423,6 +4464,8 @@ class JVPEmitter final } void startDifferentialGeneration() { + auto &diffBuilder = getDifferentialBuilder(); + // Create differential blocks and arguments. // TODO: Consider visiting original blocks in pre-order (dominance) order. SmallVector preOrderDomOrder; @@ -4439,9 +4482,9 @@ class JVPEmitter final if (&origBB == origEntry) { assert(diffBB->isEntry()); createEntryArguments(&differential); - auto *lastArg = diffBB->getArguments().back(); - assert(lastArg->getType() == diffStructLoweredType); - differentialStructArguments[&origBB] = lastArg; + auto *mainDifferentialStruct = diffBB->getArguments().back(); + assert(mainDifferentialStruct->getType() == diffStructLoweredType); + differentialStructArguments[&origBB] = mainDifferentialStruct; } LLVM_DEBUG({ @@ -4460,7 +4503,6 @@ class JVPEmitter final // The differential function has type: // (arg0', ..., argn', exit_diffs) -> result'. - auto &diffBuilder = getDifferentialBuilder(); auto diffParamArgs = differential.getArgumentsWithoutIndirectResults().drop_back(); assert(diffParamArgs.size() == attr->getIndices().parameters->getCapacity()); @@ -4484,15 +4526,12 @@ class JVPEmitter final } auto *diffEntry = getDifferential().getEntryBlock(); - auto diffLoc = getDifferential().getLocation(); diffBuilder.setInsertionPoint( diffEntry, getNextDifferentialLocalAllocationInsertionPoint()); for (auto index : *getIndices().parameters) { auto diffParam = diffParamArgs[index]; auto origParam = origParamArgs[index]; - diffBuilder.createRetainValue(diffLoc, diffParam, - diffBuilder.getDefaultAtomicity()); setTangentValue(origEntry, origParam, makeConcreteTangentValue(diffParam)); LLVM_DEBUG(getADDebugStream() @@ -4571,7 +4610,7 @@ class JVPEmitter final auto *dfStruct = linearMapInfo->getLinearMapStruct(origEntry); auto dfStructType = dfStruct->getDeclaredInterfaceType()->getCanonicalType(); - dfParams.push_back({dfStructType, ParameterConvention::Direct_Guaranteed}); + dfParams.push_back({dfStructType, ParameterConvention::Direct_Owned}); auto diffName = original->getASTContext() .getIdentifier("AD__" + original->getName().str() + @@ -4593,7 +4632,6 @@ class JVPEmitter final linkage, diffName, diffType, diffGenericEnv, original->getLocation(), original->isBare(), IsNotTransparent, original->isSerialized(), original->isDynamicallyReplaceable()); - differential->setOwnershipEliminated(); differential->setDebugScope( new (module) SILDebugScope(original->getLocation(), differential)); @@ -4649,6 +4687,14 @@ class JVPEmitter final errorOccurred = true; } + /// Handle `copy_value` instruction. + /// Original: y = copy_value x + /// Adjoint: tan[x] = copy_value tan[y] + void visitCopyValueInst(CopyValueInst *cvi) { + TypeSubstCloner::visitCopyValueInst(cvi); + emitTangentForCopyValueInst(cvi); + } + void visitReturnInst(ReturnInst *ri) { auto loc = ri->getOperand().getLoc(); auto *origExit = ri->getParent(); @@ -4682,6 +4728,39 @@ class JVPEmitter final emitTangentForReturnInst(ri); } + void visitInstructionsInBlock(SILBasicBlock *bb) { + // Destructure the differential struct to get the elements. + if (bb == original->getEntryBlock()) { + auto &diffBuilder = getDifferentialBuilder(); + auto diffLoc = getDifferential().getLocation(); + auto *diffBB = diffBBMap.lookup(bb); + auto *mainDifferentialStruct = diffBB->getArguments().back(); + diffBuilder.setInsertionPoint(diffBB); + auto *dsi = diffBuilder.createDestructureStruct( + diffLoc, mainDifferentialStruct); + initializeDifferentialStructElements(bb, dsi->getResults()); + } + TypeSubstCloner::visitInstructionsInBlock(bb); + } + + void visitDestroyValueInst(DestroyValueInst *dvi) { + TypeSubstCloner::visitDestroyValueInst(dvi); + if (shouldBeDifferentiated(dvi, getIndices())) + emitTangentForDestroyValueInst(dvi); + } + + void visitBeginBorrowInst(BeginBorrowInst *bbi) { + TypeSubstCloner::visitBeginBorrowInst(bbi); + if (shouldBeDifferentiated(bbi, getIndices())) + emitTangentForBeginBorrow(bbi); + } + + void visitEndBorrowInst(EndBorrowInst *ebi) { + TypeSubstCloner::visitEndBorrowInst(ebi); + if (shouldBeDifferentiated(ebi, getIndices())) + emitTangentForEndBorrow(ebi); + } + // If an `apply` has active results or active inout parameters, replace it // with an `apply` of its JVP. void visitApplyInst(ApplyInst *ai) { @@ -4757,7 +4836,6 @@ class JVPEmitter final auto loc = ai->getLoc(); auto &builder = getBuilder(); auto original = getOpValue(ai->getCallee()); - auto functionSource = original; SILValue jvpValue; // If functionSource is a @differentiable function, just extract it. auto originalFnTy = original->getType().castTo(); @@ -4771,9 +4849,11 @@ class JVPEmitter final return; } } + auto borrowedDiffFunc = builder.emitBeginBorrowOperation(loc, original); jvpValue = builder.createAutoDiffFunctionExtract( loc, AutoDiffFunctionExtractInst::Extractee::JVP, - /*differentiationOrder*/ 1, functionSource); + /*differentiationOrder*/ 1, borrowedDiffFunc); + jvpValue = builder.emitCopyValueOperation(loc, jvpValue); } // Check and diagnose non-differentiable arguments. @@ -4822,7 +4902,7 @@ class JVPEmitter final // function operand is specialized with a remapped version of same // substitution map using an argument-less `partial_apply`. if (ai->getSubstitutionMap().empty()) { - builder.createRetainValue(loc, original, builder.getDefaultAtomicity()); + original = builder.emitCopyValueOperation(loc, original); } else { auto substMap = getOpSubstitutionMap(ai->getSubstitutionMap()); auto jvpPartialApply = getBuilder().createPartialApply( @@ -4863,9 +4943,7 @@ class JVPEmitter final LLVM_DEBUG(getADDebugStream() << "Applied jvp function\n" << *jvpCall); // Release the differentiable function. - if (differentiableFunc) - builder.createReleaseValue(loc, differentiableFunc, - builder.getDefaultAtomicity()); + builder.emitDestroyValueOperation(loc, jvpValue); // Get the JVP results (original results and differential). SmallVector jvpDirectResults; @@ -5020,16 +5098,6 @@ class PullbackEmitter final : public SILInstructionVisitor { // Pullback struct mapping //--------------------------------------------------------------------------// - SILArgument *getPullbackBlockPullbackStructArgument(SILBasicBlock *origBB) { -#ifndef NDEBUG - assert(origBB->getParent() == &getOriginal()); - auto *pbStruct = pullbackStructArguments[origBB]->getType() - .getStructOrBoundGenericStruct(); - assert(pbStruct == getPullbackInfo().getLinearMapStruct(origBB)); -#endif - return pullbackStructArguments[origBB]; - } - void initializePullbackStructElements(SILBasicBlock *origBB, SILInstructionResultArray values) { auto *pbStructDecl = getPullbackInfo().getLinearMapStruct(origBB); diff --git a/test/AutoDiff/forward_mode_sil.swift b/test/AutoDiff/forward_mode_sil.swift index 156db84b63cef..12e529a85226f 100644 --- a/test/AutoDiff/forward_mode_sil.swift +++ b/test/AutoDiff/forward_mode_sil.swift @@ -18,40 +18,32 @@ func unary(_ x: Float) -> Float { // CHECK-DATA-STRUCTURES: enum _AD__unary_bb0__Succ__src_0_wrt_0 { // CHECK-DATA-STRUCTURES: } -// CHECK-SIL-LABEL: sil hidden @AD__unary__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { +// CHECK-SIL-LABEL: sil hidden [ossa] @AD__unary__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { // CHECK-SIL: bb0([[X_ARG:%.*]] : $Float): // CHECK-SIL: [[MULT_FUNC_1:%.*]] = function_ref @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float -// CHECK-SIL: retain_value [[MULT_FUNC_1]] : $@convention(method) (Float, Float, @thin Float.Type) -> Float // CHECK-SIL: [[MULT_FUNC_JVP_1:%.*]] = function_ref @AD__$sSf1moiyS2f_SftFZ__jvp_src_0_wrt_0_1 : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float) // CHECK-SIL: [[MULT_FUNC_VJP_1:%.*]] = function_ref @AD__$sSf1moiyS2f_SftFZ__vjp_src_0_wrt_0_1 : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) // CHECK-SIL: [[AUTODIFF_INST_1:%.*]] = autodiff_function [wrt 0 1] [order 1] [[MULT_FUNC_1]] : $@convention(method) (Float, Float, @thin Float.Type) -> Float with {[[MULT_FUNC_JVP_1]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float), [[MULT_FUNC_VJP_1]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))} // CHECK-SIL: [[AUTODIFF_EXTRACT_INST_1:%.*]] = autodiff_function_extract [jvp] [order 1] [[AUTODIFF_INST_1]] : $@differentiable @convention(method) (Float, Float, @nondiff @thin Float.Type) -> Float // CHECK-SIL: [[MULT_JVP_APPLY_TUPLE_1:%.*]] = apply [[AUTODIFF_EXTRACT_INST_1]]([[X_ARG]], [[X_ARG]], %3) : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float) -// CHECK-SIL: release_value [[AUTODIFF_INST_1]] : $@differentiable @convention(method) (Float, Float, @nondiff @thin Float.Type) -> Float -// CHECK-SIL: [[ORIG_RESULT_1:%.*]] = tuple_extract [[MULT_JVP_APPLY_TUPLE_1]] : $(Float, @callee_guaranteed (Float, Float) -> Float), 0 -// CHECK-SIL: [[MULT_DIFF_1:%.*]] = tuple_extract [[MULT_JVP_APPLY_TUPLE_1]] : $(Float, @callee_guaranteed (Float, Float) -> Float), 1 +// CHECK-SIL: ([[ORIG_RESULT_1:%.*]], [[MULT_DIFF_1:%.*]]) = destructure_tuple [[MULT_JVP_APPLY_TUPLE_1]] : $(Float, @callee_guaranteed (Float, Float) -> Float) // CHECK-SIL: [[MULT_FUNC_2:%.*]] = function_ref @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float -// CHECK-SIL: retain_value [[MULT_FUNC_2]] : $@convention(method) (Float, Float, @thin Float.Type) -> Float // CHECK-SIL: [[MULT_FUNC_JVP_2:%.*]] = function_ref @AD__$sSf1moiyS2f_SftFZ__jvp_src_0_wrt_0_1 : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float) // CHECK-SIL: [[MULT_FUNC_VJP_2:%.*]] = function_ref @AD__$sSf1moiyS2f_SftFZ__vjp_src_0_wrt_0_1 : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) // CHECK-SIL: [[AUTODIFF_INST_2:%.*]] = autodiff_function [wrt 0 1] [order 1] [[MULT_FUNC_2]] : $@convention(method) (Float, Float, @thin Float.Type) -> Float with {[[MULT_FUNC_JVP_2]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float), [[MULT_FUNC_VJP_2]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))} // CHECK-SIL: [[AUTODIFF_EXTRACT_INST_1:%.*]] = autodiff_function_extract [jvp] [order 1] [[AUTODIFF_INST_2]] : $@differentiable @convention(method) (Float, Float, @nondiff @thin Float.Type) -> Float // CHECK-SIL: [[MULT_JVP_APPLY_TUPLE_2:%.*]] = apply [[AUTODIFF_EXTRACT_INST_1]]([[ORIG_RESULT_1]], [[X_ARG]], %2) : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float) -// CHECK-SIL: release_value [[AUTODIFF_INST_2]] : $@differentiable @convention(method) (Float, Float, @nondiff @thin Float.Type) -> Float -// CHECK-SIL: [[ORIG_RESULT_2:%.*]] = tuple_extract [[MULT_JVP_APPLY_TUPLE_2]] : $(Float, @callee_guaranteed (Float, Float) -> Float), 0 -// CHECK-SIL: [[MULT_DIFF_2:%.*]] = tuple_extract [[MULT_JVP_APPLY_TUPLE_2]] : $(Float, @callee_guaranteed (Float, Float) -> Float), 1 +// CHECK-SIL: ([[ORIG_RESULT_2:%.*]], [[MULT_DIFF_2:%.*]]) = destructure_tuple [[MULT_JVP_APPLY_TUPLE_2]] : $(Float, @callee_guaranteed (Float, Float) -> Float) // CHECK-SIL: [[DIFF_STRUCT:%.*]] = struct $_AD__unary_bb0__DF__src_0_wrt_0 ([[MULT_DIFF_1]] : $@callee_guaranteed (Float, Float) -> Float, [[MULT_DIFF_2]] : $@callee_guaranteed (Float, Float) -> Float) -// CHECK-SIL: [[UNARY_DIFFERENTIAL:%.*]] = function_ref @AD__unary__differential_src_0_wrt_0 : $@convention(thin) (Float, @guaranteed _AD__unary_bb0__DF__src_0_wrt_0) -> Float -// CHECK-SIL: [[PARTIAL_APP_DIFFERENTIAL:%.*]] = partial_apply [callee_guaranteed] [[UNARY_DIFFERENTIAL]]([[DIFF_STRUCT]]) : $@convention(thin) (Float, @guaranteed _AD__unary_bb0__DF__src_0_wrt_0) -> Float +// CHECK-SIL: [[UNARY_DIFFERENTIAL:%.*]] = function_ref @AD__unary__differential_src_0_wrt_0 : $@convention(thin) (Float, @owned _AD__unary_bb0__DF__src_0_wrt_0) -> Float +// CHECK-SIL: [[PARTIAL_APP_DIFFERENTIAL:%.*]] = partial_apply [callee_guaranteed] [[UNARY_DIFFERENTIAL]]([[DIFF_STRUCT]]) : $@convention(thin) (Float, @owned _AD__unary_bb0__DF__src_0_wrt_0) -> Float // CHECK-SIL: [[RESULT:%.*]] = tuple ([[ORIG_RESULT_2]] : $Float, [[PARTIAL_APP_DIFFERENTIAL]] : $@callee_guaranteed (Float) -> Float) // CHECK-SIL: return [[RESULT]] : $(Float, @callee_guaranteed (Float) -> Float) -// CHECK-SIL-LABEL: sil hidden @AD__unary__differential_src_0_wrt_0 : $@convention(thin) (Float, @guaranteed _AD__unary_bb0__DF__src_0_wrt_0) -> Float { -// CHECK-SIL: bb0([[X_TAN:%.*]] : $Float, [[DIFF_STRUCT:%.*]] : $_AD__unary_bb0__DF__src_0_wrt_0): -// CHECK-SIL: retain_value [[X_TAN]] : $Float -// CHECK-SIL: [[MULT_DIFF_1:%.*]] = struct_extract [[DIFF_STRUCT]] : $_AD__unary_bb0__DF__src_0_wrt_0, #_AD__unary_bb0__DF__src_0_wrt_0.differential_0 +// CHECK-SIL-LABEL: sil hidden [ossa] @AD__unary__differential_src_0_wrt_0 : $@convention(thin) (Float, @owned _AD__unary_bb0__DF__src_0_wrt_0) -> Float { +// CHECK-SIL: bb0([[X_TAN:%.*]] : $Float, [[DIFF_STRUCT:%.*]] : @owned $_AD__unary_bb0__DF__src_0_wrt_0): +// CHECK-SIL: ([[MULT_DIFF_1:%.*]], [[MULT_DIFF_2:%.*]]) = destructure_struct %1 : $_AD__unary_bb0__DF__src_0_wrt_0 // CHECK-SIL: [[TEMP_TAN_1:%.*]] = apply [[MULT_DIFF_1]]([[X_TAN]], [[X_TAN]]) : $@callee_guaranteed (Float, Float) -> Float -// CHECK-SIL: [[MULT_DIFF_2:%.*]] = struct_extract [[DIFF_STRUCT]] : $_AD__unary_bb0__DF__src_0_wrt_0, #_AD__unary_bb0__DF__src_0_wrt_0.differential_1 // CHECK-SIL: [[TAN_RESULT:%.*]] = apply [[MULT_DIFF_2]]([[TEMP_TAN_1]], [[X_TAN]]) : $@callee_guaranteed (Float, Float) -> Float // CHECK-SIL: return [[TAN_RESULT]] : $Float @@ -71,26 +63,23 @@ func binary(x: Float, y: Float) -> Float { // CHECK-DATA-STRUCTURES: enum _AD__binary_bb0__Succ__src_0_wrt_0_1 { // CHECK-DATA-STRUCTURES: } -// CHECK-SIL-LABEL: sil hidden @AD__binary__jvp_src_0_wrt_0_1 : $@convention(thin) (Float, Float) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float) { -// CHECK-SIL: bb0([[X_ARG:%.*]] : $Float, %1 : $Float): +// CHECK-SIL-LABEL: sil hidden [ossa] @AD__binary__jvp_src_0_wrt_0_1 : $@convention(thin) (Float, Float) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float) { +// CHECK-SIL: bb0([[X_ARG:%.*]] : $Float, [[Y_ARG:%.*]] : $Float): // CHECK-SIL: [[MULT_FUNC:%.*]] = function_ref @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float -// CHECK-SIL: retain_value [[MULT_FUNC]] : $@convention(method) (Float, Float, @thin Float.Type) -> Float // CHECK-SIL: [[MULT_FUNC_JVP:%.*]] = function_ref @AD__$sSf1moiyS2f_SftFZ__jvp_src_0_wrt_0_1 : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float) // CHECK-SIL: [[MULT_FUNC_VJP:%.*]] = function_ref @AD__$sSf1moiyS2f_SftFZ__vjp_src_0_wrt_0_1 : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) // CHECK-SIL: [[AUTODIFF_INST:%.*]] = autodiff_function [wrt 0 1] [order 1] [[MULT_FUNC]] : $@convention(method) (Float, Float, @thin Float.Type) -> Float with {[[MULT_FUNC_JVP]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float), [[MULT_FUNC_VJP]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))} // CHECK-SIL: [[AUTODIFF_EXTRACT_INST:%.*]] = autodiff_function_extract [jvp] [order 1] [[AUTODIFF_INST]] : $@differentiable @convention(method) (Float, Float, @nondiff @thin Float.Type) -> Float -// CHECK-SIL: [[MULT_JVP_APPLY_TUPLE:%.*]] = apply [[AUTODIFF_EXTRACT_INST]]([[X_ARG]], %1, %4) : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float) -// CHECK-SIL: release_value [[AUTODIFF_INST]] : $@differentiable @convention(method) (Float, Float, @nondiff @thin Float.Type) -> Float -// CHECK-SIL: [[ORIG_RESULT:%.*]] = tuple_extract [[MULT_JVP_APPLY_TUPLE]] : $(Float, @callee_guaranteed (Float, Float) -> Float), 0 -// CHECK-SIL: [[MULT_DIFF:%.*]] = tuple_extract [[MULT_JVP_APPLY_TUPLE]] : $(Float, @callee_guaranteed (Float, Float) -> Float), 1 +// CHECK-SIL: [[MULT_JVP_APPLY_TUPLE:%.*]] = apply [[AUTODIFF_EXTRACT_INST]]([[X_ARG]], [[Y_ARG]], %4) : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float) +// CHECK-SIL: ([[ORIG_RESULT:%.*]], [[MULT_DIFF:%.*]]) = destructure_tuple [[MULT_JVP_APPLY_TUPLE]] : $(Float, @callee_guaranteed (Float, Float) -> Float) // CHECK-SIL: [[DIFF_STRUCT:%.*]] = struct $_AD__binary_bb0__DF__src_0_wrt_0_1 ([[MULT_DIFF]] : $@callee_guaranteed (Float, Float) -> Float) -// CHECK-SIL: [[BINARY_DIFFERENTIAL:%.*]] = function_ref @AD__binary__differential_src_0_wrt_0_1 : $@convention(thin) (Float, Float, @guaranteed _AD__binary_bb0__DF__src_0_wrt_0_1) -> Float -// CHECK-SIL: [[PARTIAL_APP_DIFFERENTIAL:%.*]] = partial_apply [callee_guaranteed] [[BINARY_DIFFERENTIAL]]([[DIFF_STRUCT]]) : $@convention(thin) (Float, Float, @guaranteed _AD__binary_bb0__DF__src_0_wrt_0_1) -> Float +// CHECK-SIL: [[BINARY_DIFFERENTIAL:%.*]] = function_ref @AD__binary__differential_src_0_wrt_0_1 : $@convention(thin) (Float, Float, @owned _AD__binary_bb0__DF__src_0_wrt_0_1) -> Float +// CHECK-SIL: [[PARTIAL_APP_DIFFERENTIAL:%.*]] = partial_apply [callee_guaranteed] [[BINARY_DIFFERENTIAL]]([[DIFF_STRUCT]]) : $@convention(thin) (Float, Float, @owned _AD__binary_bb0__DF__src_0_wrt_0_1) -> Float // CHECK-SIL: [[RESULT:%.*]] = tuple ([[ORIG_RESULT]] : $Float, [[PARTIAL_APP_DIFFERENTIAL]] : $@callee_guaranteed (Float, Float) -> Float) // CHECK-SIL: return [[RESULT:%.*]] : $(Float, @callee_guaranteed (Float, Float) -> Float) -// CHECK-SIL-LABEL: sil hidden @AD__binary__differential_src_0_wrt_0_1 : $@convention(thin) (Float, Float, @guaranteed _AD__binary_bb0__DF__src_0_wrt_0_1) -> Float { -// CHECK-SIL: bb0([[X_TAN:%.*]] : $Float, [[Y_TAN:%.*]] : $Float, [[DIFF_STRUCT:%.*]] : $_AD__binary_bb0__DF__src_0_wrt_0_1): -// CHECK-SIL: [[MULT_DIFF:%.*]] = struct_extract [[DIFF_STRUCT]] : $_AD__binary_bb0__DF__src_0_wrt_0_1, #_AD__binary_bb0__DF__src_0_wrt_0_1.differential_0 +// CHECK-SIL-LABEL: sil hidden [ossa] @AD__binary__differential_src_0_wrt_0_1 : $@convention(thin) (Float, Float, @owned _AD__binary_bb0__DF__src_0_wrt_0_1) -> Float { +// CHECK-SIL: bb0([[X_TAN:%.*]] : $Float, [[Y_TAN:%.*]] : $Float, [[DIFF_STRUCT:%.*]] : @owned $_AD__binary_bb0__DF__src_0_wrt_0_1): +// CHECK-SIL: [[MULT_DIFF:%.*]] = destructure_struct [[DIFF_STRUCT]] : $_AD__binary_bb0__DF__src_0_wrt_0_1 // CHECK-SIL: [[TAN_RESULT:%.*]] = apply [[MULT_DIFF]]([[X_TAN]], [[Y_TAN]]) : $@callee_guaranteed (Float, Float) -> Float // CHECK-SIL: return [[TAN_RESULT]] : $Float