Skip to content

Commit

Permalink
Make changes for ownership change.
Browse files Browse the repository at this point in the history
  • Loading branch information
bartchr808 committed Aug 19, 2019
1 parent e4025d3 commit 1883ac1
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 85 deletions.
182 changes: 125 additions & 57 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4005,6 +4005,12 @@ class JVPEmitter final
/// Mapping from differential basic blocks to differential struct arguments.
DenseMap<SILBasicBlock *, SILArgument *> differentialStructArguments;

/// Mapping from differential struct field declarations to differential struct
/// elements destructured from the linear map basic block argument. In the
/// beginning of each differential basic block, the block's differential struct is
/// destructured into individual elements stored here.
DenseMap<VarDecl *, SILValue> differentialStructElements;

/// Mapping from original basic blocks and original values to corresponding
/// tangent values.
DenseMap<SILValue, AdjointValue> tangentValueMap;
Expand Down Expand Up @@ -4079,6 +4085,35 @@ class JVPEmitter final
return SILBuilder(*differential);
}

//--------------------------------------------------------------------------//
// Differential struct mapping
//--------------------------------------------------------------------------//

void initializeDifferentialStructElements(SILBasicBlock *origBB,
SILInstructionResultArray values) {
auto *diffStructDecl = differentialInfo.getLinearMapStruct(origBB);
assert(diffStructDecl->getStoredProperties().size() == values.size() &&
"The number of differential struct fields must equal the number of "
"differential struct element values");
for (auto pair : llvm::zip(diffStructDecl->getStoredProperties(), values)) {
assert(
std::get<1>(pair).getOwnershipKind() != ValueOwnershipKind::Guaranteed
&& "Differential struct elements must be @owned");
auto insertion = differentialStructElements.insert({std::get<0>(pair),
std::get<1>(pair)});
(void)insertion;
assert(insertion.second && "A differential struct element already exists!");
}
}

SILValue getDifferentialStructElement(SILBasicBlock *origBB, VarDecl *field) {
assert(differentialInfo.getLinearMapStruct(origBB) ==
cast<StructDecl>(field->getDeclContext()));
assert(differentialStructElements.count(field) &&
"Differential struct element for this field does not exist!");
return differentialStructElements.lookup(field);
}

//--------------------------------------------------------------------------//
// General utilities
//--------------------------------------------------------------------------//
Expand Down Expand Up @@ -4209,8 +4244,8 @@ class JVPEmitter final
type, ResilienceExpansion::Minimal);
auto *buffer = diffBuilder.createAllocStack(loc, silType);
emitZeroIndirect(type, buffer, loc);
auto *loaded = diffBuilder.createLoad(
loc, buffer, LoadOwnershipQualifier::Unqualified);
auto loaded = diffBuilder.emitLoadValueOperation(
loc, buffer, LoadOwnershipQualifier::Take);
diffBuilder.createDeallocStack(loc, buffer);
return loaded;
}
Expand Down Expand Up @@ -4256,38 +4291,12 @@ class JVPEmitter final
}

SILValue &getTangentBuffer(SILBasicBlock *origBB, SILValue originalBuffer) {
auto diffBuilder = getDifferentialBuilder();
assert(originalBuffer->getType().isAddress());
assert(originalBuffer->getFunction() == original);
auto insertion = bufferMap.try_emplace({origBB, originalBuffer},
SILValue());
if (!insertion.second) // not inserted.
return insertion.first->getSecond();

// Set insertion point for local allocation builder: before the last local
// allocation, or at the start of the tangent function's entry if no local
// allocations exist yet.
diffLocalAllocBuilder.setInsertionPoint(
getDifferential().getEntryBlock(),
getNextDifferentialLocalAllocationInsertionPoint());

// Allocate local buffer and initialize to zero.
auto bufObjectType = getRemappedTangentType(originalBuffer->getType());
auto *newBuf = diffLocalAllocBuilder.createAllocStack(
originalBuffer.getLoc(), bufObjectType);

// Temporarily change global builder insertion point and emit zero into the
// local buffer.
auto insertionPoint = diffLocalAllocBuilder.getInsertionBB();
diffBuilder.setInsertionPoint(
diffLocalAllocBuilder.getInsertionBB(),
diffLocalAllocBuilder.getInsertionPoint());
emitZeroIndirect(bufObjectType.getASTType(), newBuf, newBuf->getLoc());
diffBuilder.setInsertionPoint(insertionPoint);

// Create cleanup for local buffer.
differentialLocalAllocations.push_back(newBuf);
return (insertion.first->getSecond() = newBuf);
assert(!insertion.second && "tangent buffer should already exist");
return insertion.first->getSecond();
}

//--------------------------------------------------------------------------//
Expand Down Expand Up @@ -4345,6 +4354,38 @@ class JVPEmitter final
// Tangent emission helpers
//--------------------------------------------------------------------------//

void emitTangentForDestroyValueInst(DestroyValueInst *dvi) {
auto &diffBuilder = getDifferentialBuilder();
auto loc = dvi->getLoc();
auto tanVal = materializeTangent(getTangentValue(dvi->getOperand()), loc);
diffBuilder.emitDestroyValue(loc, tanVal);
}

void emitTangentForBeginBorrow(BeginBorrowInst *bbi) {
auto &diffBuilder = getDifferentialBuilder();
auto loc = bbi->getLoc();
auto tanVal = materializeTangent(getTangentValue(bbi->getOperand()), loc);
auto tanValBorrow = diffBuilder.emitBeginBorrowOperation(loc, tanVal);
setTangentValue(bbi->getParent(), bbi,
makeConcreteTangentValue(tanValBorrow));
}

void emitTangentForEndBorrow(EndBorrowInst *ebi) {
auto &diffBuilder = getDifferentialBuilder();
auto loc = ebi->getLoc();
auto tanVal = materializeTangent(getTangentValue(ebi->getOperand()), loc);
diffBuilder.emitEndBorrowOperation(loc, tanVal);
}

void emitTangentForCopyValueInst(CopyValueInst *cvi) {
auto &diffBuilder = getDifferentialBuilder();
auto tan = getTangentValue(cvi->getOperand());
auto tanVal = materializeTangent(tan, cvi->getLoc());
auto tanValCopy = diffBuilder.emitCopyValueOperation(cvi->getLoc(), tanVal);
setTangentValue(cvi->getParent(), cvi,
makeConcreteTangentValue(tanValCopy));
}

void emitTangentForReturnInst(ReturnInst *ri) {
auto loc = ri->getOperand().getLoc();
auto diffBuilder = getDifferentialBuilder();
Expand All @@ -4366,8 +4407,7 @@ class JVPEmitter final
// Get the differential.
auto *field = differentialInfo.lookUpLinearMapDecl(ai);
assert(field);
SILValue differential = diffBuilder.createStructExtract(
loc, getDifferentialStructArgument(ai->getParent()), field);
SILValue differential = getDifferentialStructElement(bb, field);

SmallVector<SILValue, 8> diffArgs;
for (auto origArg : ai->getArguments()) {
Expand Down Expand Up @@ -4400,6 +4440,7 @@ class JVPEmitter final
auto *differentialCall = diffBuilder.createApply(
loc, differential, SubstitutionMap(), diffArgs,
/*isNonThrowing*/ false);
diffBuilder.emitDestroyValueOperation(loc, differential);
assert(differentialCall->getNumResults() == 1 &&
"Expected differential to return one result");

Expand All @@ -4423,6 +4464,8 @@ class JVPEmitter final
}

void startDifferentialGeneration() {
auto &diffBuilder = getDifferentialBuilder();

// Create differential blocks and arguments.
// TODO: Consider visiting original blocks in pre-order (dominance) order.
SmallVector<SILBasicBlock *, 8> preOrderDomOrder;
Expand All @@ -4439,9 +4482,9 @@ class JVPEmitter final
if (&origBB == origEntry) {
assert(diffBB->isEntry());
createEntryArguments(&differential);
auto *lastArg = diffBB->getArguments().back();
assert(lastArg->getType() == diffStructLoweredType);
differentialStructArguments[&origBB] = lastArg;
auto *mainDifferentialStruct = diffBB->getArguments().back();
assert(mainDifferentialStruct->getType() == diffStructLoweredType);
differentialStructArguments[&origBB] = mainDifferentialStruct;
}

LLVM_DEBUG({
Expand All @@ -4460,7 +4503,6 @@ class JVPEmitter final

// The differential function has type:
// (arg0', ..., argn', exit_diffs) -> result'.
auto &diffBuilder = getDifferentialBuilder();
auto diffParamArgs =
differential.getArgumentsWithoutIndirectResults().drop_back();
assert(diffParamArgs.size() == attr->getIndices().parameters->getCapacity());
Expand All @@ -4484,15 +4526,12 @@ class JVPEmitter final
}

auto *diffEntry = getDifferential().getEntryBlock();
auto diffLoc = getDifferential().getLocation();
diffBuilder.setInsertionPoint(
diffEntry, getNextDifferentialLocalAllocationInsertionPoint());

for (auto index : *getIndices().parameters) {
auto diffParam = diffParamArgs[index];
auto origParam = origParamArgs[index];
diffBuilder.createRetainValue(diffLoc, diffParam,
diffBuilder.getDefaultAtomicity());
setTangentValue(origEntry, origParam,
makeConcreteTangentValue(diffParam));
LLVM_DEBUG(getADDebugStream()
Expand Down Expand Up @@ -4571,7 +4610,7 @@ class JVPEmitter final
auto *dfStruct = linearMapInfo->getLinearMapStruct(origEntry);
auto dfStructType =
dfStruct->getDeclaredInterfaceType()->getCanonicalType();
dfParams.push_back({dfStructType, ParameterConvention::Direct_Guaranteed});
dfParams.push_back({dfStructType, ParameterConvention::Direct_Owned});

auto diffName = original->getASTContext()
.getIdentifier("AD__" + original->getName().str() +
Expand All @@ -4593,7 +4632,6 @@ class JVPEmitter final
linkage, diffName, diffType, diffGenericEnv, original->getLocation(),
original->isBare(), IsNotTransparent, original->isSerialized(),
original->isDynamicallyReplaceable());
differential->setOwnershipEliminated();
differential->setDebugScope(
new (module) SILDebugScope(original->getLocation(), differential));

Expand Down Expand Up @@ -4649,6 +4687,14 @@ class JVPEmitter final
errorOccurred = true;
}

/// Handle `copy_value` instruction.
/// Original: y = copy_value x
/// Adjoint: tan[x] = copy_value tan[y]
void visitCopyValueInst(CopyValueInst *cvi) {
TypeSubstCloner::visitCopyValueInst(cvi);
emitTangentForCopyValueInst(cvi);
}

void visitReturnInst(ReturnInst *ri) {
auto loc = ri->getOperand().getLoc();
auto *origExit = ri->getParent();
Expand Down Expand Up @@ -4682,6 +4728,39 @@ class JVPEmitter final
emitTangentForReturnInst(ri);
}

void visitInstructionsInBlock(SILBasicBlock *bb) {
// Destructure the differential struct to get the elements.
if (bb == original->getEntryBlock()) {
auto &diffBuilder = getDifferentialBuilder();
auto diffLoc = getDifferential().getLocation();
auto *diffBB = diffBBMap.lookup(bb);
auto *mainDifferentialStruct = diffBB->getArguments().back();
diffBuilder.setInsertionPoint(diffBB);
auto *dsi = diffBuilder.createDestructureStruct(
diffLoc, mainDifferentialStruct);
initializeDifferentialStructElements(bb, dsi->getResults());
}
TypeSubstCloner::visitInstructionsInBlock(bb);
}

void visitDestroyValueInst(DestroyValueInst *dvi) {
TypeSubstCloner::visitDestroyValueInst(dvi);
if (shouldBeDifferentiated(dvi, getIndices()))
emitTangentForDestroyValueInst(dvi);
}

void visitBeginBorrowInst(BeginBorrowInst *bbi) {
TypeSubstCloner::visitBeginBorrowInst(bbi);
if (shouldBeDifferentiated(bbi, getIndices()))
emitTangentForBeginBorrow(bbi);
}

void visitEndBorrowInst(EndBorrowInst *ebi) {
TypeSubstCloner::visitEndBorrowInst(ebi);
if (shouldBeDifferentiated(ebi, getIndices()))
emitTangentForEndBorrow(ebi);
}

// If an `apply` has active results or active inout parameters, replace it
// with an `apply` of its JVP.
void visitApplyInst(ApplyInst *ai) {
Expand Down Expand Up @@ -4757,7 +4836,6 @@ class JVPEmitter final
auto loc = ai->getLoc();
auto &builder = getBuilder();
auto original = getOpValue(ai->getCallee());
auto functionSource = original;
SILValue jvpValue;
// If functionSource is a @differentiable function, just extract it.
auto originalFnTy = original->getType().castTo<SILFunctionType>();
Expand All @@ -4771,9 +4849,11 @@ class JVPEmitter final
return;
}
}
auto borrowedDiffFunc = builder.emitBeginBorrowOperation(loc, original);
jvpValue = builder.createAutoDiffFunctionExtract(
loc, AutoDiffFunctionExtractInst::Extractee::JVP,
/*differentiationOrder*/ 1, functionSource);
/*differentiationOrder*/ 1, borrowedDiffFunc);
jvpValue = builder.emitCopyValueOperation(loc, jvpValue);
}

// Check and diagnose non-differentiable arguments.
Expand Down Expand Up @@ -4822,7 +4902,7 @@ class JVPEmitter final
// function operand is specialized with a remapped version of same
// substitution map using an argument-less `partial_apply`.
if (ai->getSubstitutionMap().empty()) {
builder.createRetainValue(loc, original, builder.getDefaultAtomicity());
original = builder.emitCopyValueOperation(loc, original);
} else {
auto substMap = getOpSubstitutionMap(ai->getSubstitutionMap());
auto jvpPartialApply = getBuilder().createPartialApply(
Expand Down Expand Up @@ -4863,9 +4943,7 @@ class JVPEmitter final
LLVM_DEBUG(getADDebugStream() << "Applied jvp function\n" << *jvpCall);

// Release the differentiable function.
if (differentiableFunc)
builder.createReleaseValue(loc, differentiableFunc,
builder.getDefaultAtomicity());
builder.emitDestroyValueOperation(loc, jvpValue);

// Get the JVP results (original results and differential).
SmallVector<SILValue, 8> jvpDirectResults;
Expand Down Expand Up @@ -5020,16 +5098,6 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
// Pullback struct mapping
//--------------------------------------------------------------------------//

SILArgument *getPullbackBlockPullbackStructArgument(SILBasicBlock *origBB) {
#ifndef NDEBUG
assert(origBB->getParent() == &getOriginal());
auto *pbStruct = pullbackStructArguments[origBB]->getType()
.getStructOrBoundGenericStruct();
assert(pbStruct == getPullbackInfo().getLinearMapStruct(origBB));
#endif
return pullbackStructArguments[origBB];
}

void initializePullbackStructElements(SILBasicBlock *origBB,
SILInstructionResultArray values) {
auto *pbStructDecl = getPullbackInfo().getLinearMapStruct(origBB);
Expand Down
Loading

0 comments on commit 1883ac1

Please sign in to comment.