Skip to content

[AutoDiff] Serialize derivative function configurations per module. #28608

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 6, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions include/swift/AST/ASTContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,11 @@ namespace swift {
class VarDecl;
class UnifiedStatsReporter;
class IndexSubset;
// SWIFT_ENABLE_TENSORFLOW
struct AutoDiffConfig;
class VectorSpace;
class DifferentiableAttr;
// SWIFT_ENABLE_TENSORFLOW END

enum class KnownProtocolKind : uint8_t;

Expand Down Expand Up @@ -702,6 +705,21 @@ class ASTContext final {
unsigned previousGeneration,
llvm::TinyPtrVector<AbstractFunctionDecl *> &methods);

// SWIFT_ENABLE_TENSORFLOW
/// Load derivative function configurations for the given
/// AbstractFunctionDecl.
///
/// \param originalAFD The declaration whose derivative function
/// configurations should be loaded.
///
/// \param previousGeneration The previous generation number. The AST already
/// contains derivative function configurations loaded from any generation up
/// to and including this one.
void loadDerivativeFunctionConfigurations(
AbstractFunctionDecl *originalAFD, unsigned previousGeneration,
llvm::SetVector<AutoDiffConfig> &results);
// SWIFT_ENABLE_TENSORFLOW END

/// Retrieve the Clang module loader for this ASTContext.
///
/// If there is no Clang module loader, returns a null pointer.
Expand Down
24 changes: 24 additions & 0 deletions include/swift/AST/Decl.h
Original file line number Diff line number Diff line change
Expand Up @@ -5694,6 +5694,30 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl {
private:
ParameterList *Params;

// SWIFT_ENABLE_TENSORFLOW
private:
/// The generation at which we last loaded derivative function configurations.
unsigned DerivativeFunctionConfigGeneration = 0;
/// Prepare to traverse the list of derivative function configurations.
void prepareDerivativeFunctionConfigurations();

/// A uniqued list of derivative function configurations.
/// - `@differentiable` and `@derivative` attribute type-checking is
/// responsible for populating derivative function configurations specified
/// in the current module.
/// - Module loading is responsible for populating derivative function
/// configurations from imported modules.
struct DerivativeFunctionConfigurationList;
DerivativeFunctionConfigurationList *DerivativeFunctionConfigs = nullptr;

public:
/// Get all derivative function configurations.
ArrayRef<AutoDiffConfig> getDerivativeFunctionConfigurations();

/// Add the given derivative function configuration.
void addDerivativeFunctionConfiguration(AutoDiffConfig config);
// SWIFT_ENABLE_TENSORFLOW END

protected:
// If a function has a body at all, we have either a parsed body AST node or
// we have saved the end location of the unparsed body.
Expand Down
22 changes: 22 additions & 0 deletions include/swift/AST/ModuleLoader.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ class DependencyCollector;

namespace swift {

// SWIFT_ENABLE_TENSORFLOW
struct AutoDiffConfig;
// SWIFT_ENABLE_TENSORFLOW END
class AbstractFunctionDecl;
class ClangImporterOptions;
class ClassDecl;
Expand Down Expand Up @@ -151,6 +154,25 @@ class ModuleLoader {
unsigned previousGeneration,
llvm::TinyPtrVector<AbstractFunctionDecl *> &methods) = 0;

// SWIFT_ENABLE_TENSORFLOW
/// Load derivative function configurations for the given
/// AbstractFunctionDecl.
///
/// \param originalAFD The declaration whose derivative function
/// configurations should be loaded.
///
/// \param previousGeneration The previous generation number. The AST already
/// contains derivative function configurations loaded from any generation up
/// to and including this one.
///
/// \param results The result list of derivative function configurations.
/// This list will be extended with any methods found in subsequent
/// generations.
virtual void loadDerivativeFunctionConfigurations(
AbstractFunctionDecl *originalAFD, unsigned previousGeneration,
llvm::SetVector<AutoDiffConfig> &results) {};
// SWIFT_ENABLE_TENSORFLOW END

/// Verify all modules loaded by this loader.
virtual void verifyAllModules() { }
};
Expand Down
2 changes: 1 addition & 1 deletion include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -4359,7 +4359,7 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
/// is necessary for the differentiation transform to support reabstraction
/// thunk differentiation because the function argument is opaque and cannot
/// be differentiated. Instead, the argument is made `@differentiable` and
/// reabstraction thunk JVP/VJP callers are reponsible for passing a
/// reabstraction thunk JVP/VJP callers are responsible for passing a
/// `@differentiable` function.
/// - TODO(TF-1036): Investigate more efficient reabstraction thunk
/// derivative approaches. The last argument can simply be a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,19 @@ getExactDifferentiabilityWitness(SILModule &module, SILFunction *original,
IndexSubset *parameterIndices,
IndexSubset *resultIndices);

/// Finds the "@differentiable" attribute on `original` whose parameter indices
/// are a minimal superset of the specified parameter indices. Returns `nullptr`
/// if no such attribute exists.
/// Finds the derivative configuration (from `@differentiable` and
/// `@derivative` attributes) for `original` whose parameter indices are a
/// minimal superset of the specified AST parameter indices. Returns `None` if
/// no such configuration is found.
///
/// \param parameterIndices must be lowered to SIL.
/// \param minimalParameterIndices is an output parameter that is set to the SIL
/// indices of the minimal attribute, or to `nullptr` if no attribute exists.
const DifferentiableAttr *
getMinimalASTDifferentiableAttr(AbstractFunctionDecl *original,
IndexSubset *parameterIndices,
IndexSubset *&minimalParameterIndices);
/// \param minimalASTParameterIndices is an output parameter that is set to the
/// AST indices of the minimal configuration, or to `nullptr` if no such
/// configuration exists.
Optional<AutoDiffConfig>
findMinimalDerivativeConfiguration(AbstractFunctionDecl *original,
IndexSubset *parameterIndices,
IndexSubset *&minimalASTParameterIndices);

/// Returns a differentiability witness for `original` whose parameter indices
/// are a minimal superset of the specified parameter indices and whose result
Expand Down
6 changes: 6 additions & 0 deletions include/swift/Serialization/SerializedModuleLoader.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,12 @@ class SerializedModuleLoaderBase : public ModuleLoader {
unsigned previousGeneration,
llvm::TinyPtrVector<AbstractFunctionDecl *> &methods) override;

// SWIFT_ENABLE_TENSORFLOW
virtual void loadDerivativeFunctionConfigurations(
AbstractFunctionDecl *originalAFD, unsigned previousGeneration,
llvm::SetVector<AutoDiffConfig> &results) override;
// SWIFT_ENABLE_TENSORFLOW END

virtual void verifyAllModules() override;
};

Expand Down
13 changes: 13 additions & 0 deletions lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1611,6 +1611,19 @@ void ASTContext::loadObjCMethods(
}
}

// SWIFT_ENABLE_TENSORFLOW
void ASTContext::loadDerivativeFunctionConfigurations(
AbstractFunctionDecl *originalAFD, unsigned previousGeneration,
llvm::SetVector<AutoDiffConfig> &results) {
PrettyStackTraceDecl stackTrace(
"loading derivative function configurations for", originalAFD);
for (auto &loader : getImpl().ModuleLoaders) {
loader->loadDerivativeFunctionConfigurations(originalAFD,
previousGeneration, results);
}
}
// SWIFT_ENABLE_TENSORFLOW END

void ASTContext::verifyAllLoadedModules() const {
#ifndef NDEBUG
FrontendStatsTracer tracer(Stats, "verify-all-loaded-modules");
Expand Down
43 changes: 43 additions & 0 deletions lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6998,6 +6998,49 @@ StringRef AbstractFunctionDecl::getInlinableBodyText(
return extractInlinableText(getASTContext().SourceMgr, body, scratch);
}

// SWIFT_ENABLE_TENSORFLOW
/// A uniqued list of derivative function configurations.
struct AbstractFunctionDecl::DerivativeFunctionConfigurationList
: public llvm::SetVector<AutoDiffConfig> {
// Necessary for `ASTContext` allocation.
void *operator new(
size_t bytes, ASTContext &ctx,
unsigned alignment = alignof(DerivativeFunctionConfigurationList)) {
return ctx.Allocate(bytes, alignment);
}
};

void AbstractFunctionDecl::prepareDerivativeFunctionConfigurations() {
if (DerivativeFunctionConfigs)
return;
auto &ctx = getASTContext();
DerivativeFunctionConfigs = new (ctx) DerivativeFunctionConfigurationList();
// Register an `ASTContext` cleanup calling the list destructor.
ctx.addCleanup([this]() {
this->DerivativeFunctionConfigs->~DerivativeFunctionConfigurationList();
});
}

ArrayRef<AutoDiffConfig>
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: even though @differentiable and @derivative attributes do not currently store result indices, I decided to use AutoDiffConfig (which contains result indices) to represent AST derivative function configurations. I also tried creating a new "ASTAutoDiffConfig" struct containing just parameter indices and derivative generic signature, but decided against the code duplication.

I think the direction is to eventually plumb result indices through the differentiation system (TF-1038), so preemptively using result indices (AutoDiffConfig) seems good.

AbstractFunctionDecl::getDerivativeFunctionConfigurations() {
prepareDerivativeFunctionConfigurations();
auto &ctx = getASTContext();
if (ctx.getCurrentGeneration() > DerivativeFunctionConfigGeneration) {
unsigned previousGeneration = DerivativeFunctionConfigGeneration;
DerivativeFunctionConfigGeneration = ctx.getCurrentGeneration();
ctx.loadDerivativeFunctionConfigurations(this, previousGeneration,
*DerivativeFunctionConfigs);
}
return DerivativeFunctionConfigs->getArrayRef();
}

void AbstractFunctionDecl::addDerivativeFunctionConfiguration(
AutoDiffConfig config) {
prepareDerivativeFunctionConfigurations();
DerivativeFunctionConfigs->insert(config);
}
// SWIFT_ENABLE_TENSORFLOW END

FuncDecl *FuncDecl::createImpl(ASTContext &Context,
SourceLoc StaticLoc,
StaticSpellingKind StaticSpelling,
Expand Down
41 changes: 20 additions & 21 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -748,27 +748,27 @@ emitDerivativeFunctionReference(
original, invoker, diag::autodiff_protocol_member_not_differentiable);
return None;
}
// Get the minimal `@differentiable` attribute and parameter index subset.
IndexSubset *minimalParamIndexSet = nullptr;
const auto *minimalAttr = getMinimalASTDifferentiableAttr(
requirementDecl, desiredIndices.parameters, minimalParamIndexSet);
SILAutoDiffIndices minimalIndices(/*source*/ 0, minimalParamIndexSet);
// If minimal `@differentiable` attribute does not exist, then no attribute
// exists with a superset of the desired indices. Produce an error.
if (!minimalAttr) {
// Find the minimal derivative configuration: minimal parameter indices and
// corresponding derivative generic signature. If it does not exist, produce
// an error.
IndexSubset *minimalASTParamIndices = nullptr;
auto minimalConfig = findMinimalDerivativeConfiguration(
requirementDecl, desiredIndices.parameters, minimalASTParamIndices);
if (!minimalConfig) {
context.emitNondifferentiabilityError(
original, invoker,
diag::autodiff_member_subset_indices_not_differentiable);
return None;
}
auto minimalIndices = minimalConfig->getSILAutoDiffIndices();
// Emit a `witness_method` instruction for the derivative function.
auto originalType = witnessMethod->getType().castTo<SILFunctionType>();
auto assocType = originalType->getAutoDiffDerivativeFunctionType(
minimalIndices.parameters, minimalIndices.source,
kind, context.getTypeConverter(),
LookUpConformanceInModule(builder.getModule().getSwiftModule()));
auto *autoDiffFuncId = AutoDiffDerivativeFunctionIdentifier::get(
kind, minimalAttr->getParameterIndices(), context.getASTContext());
kind, minimalASTParamIndices, context.getASTContext());
auto *ref = builder.createWitnessMethod(
loc, witnessMethod->getLookupType(), witnessMethod->getConformance(),
requirementDeclRef.asAutoDiffDerivativeFunction(autoDiffFuncId),
Expand All @@ -792,28 +792,27 @@ emitDerivativeFunctionReference(
original, invoker, diag::autodiff_class_member_not_differentiable);
return None;
}
// Get the minimal `@differentiable` attribute and parameter index subset.
IndexSubset *minimalParamIndexSet = nullptr;
const auto *minimalAttr = getMinimalASTDifferentiableAttr(
methodDecl, desiredIndices.parameters, minimalParamIndexSet);
SILAutoDiffIndices minimalIndices(/*source*/ 0, minimalParamIndexSet);
// If minimal `@differentiable` attribute does not exist, then no attribute
// exists with a superset of the desired indices. Produce an error.
if (!minimalAttr) {
// Find the minimal derivative configuration: minimal parameter indices and
// corresponding derivative generic signature. If it does not exist, produce
// an error.
IndexSubset *minimalASTParamIndices = nullptr;
auto minimalConfig = findMinimalDerivativeConfiguration(
methodDecl, desiredIndices.parameters, minimalASTParamIndices);
if (!minimalConfig) {
context.emitNondifferentiabilityError(
original, invoker,
diag::autodiff_member_subset_indices_not_differentiable);
return None;
}
auto minimalIndices = minimalConfig->getSILAutoDiffIndices();
// Emit a `class_method` instruction for the derivative function.
auto originalType = classMethodInst->getType().castTo<SILFunctionType>();
auto assocType = originalType->getAutoDiffDerivativeFunctionType(
minimalIndices.parameters, minimalIndices.source,
kind, context.getTypeConverter(),
minimalIndices.parameters, minimalIndices.source, kind,
context.getTypeConverter(),
LookUpConformanceInModule(builder.getModule().getSwiftModule()));
auto *autoDiffFuncId = AutoDiffDerivativeFunctionIdentifier::get(
kind, minimalAttr->getParameterIndices(),
context.getASTContext());
kind, minimalASTParamIndices, context.getASTContext());
auto *ref = builder.createClassMethod(
loc, classMethodInst->getOperand(),
methodDeclRef.asAutoDiffDerivativeFunction(autoDiffFuncId),
Expand Down
63 changes: 28 additions & 35 deletions lib/SILOptimizer/Utils/Differentiation/DerivativeLookup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,34 +45,35 @@ getExactDifferentiabilityWitness(SILModule &module, SILFunction *original,
return nullptr;
}

const DifferentiableAttr *
getMinimalASTDifferentiableAttr(AbstractFunctionDecl *original,
IndexSubset *parameterIndices,
IndexSubset *&minimalParameterIndices) {
const DifferentiableAttr *minimalAttr = nullptr;
minimalParameterIndices = nullptr;
for (auto *attr : original->getAttrs().getAttributes<DifferentiableAttr>()) {
auto *attrParameterIndices = autodiff::getLoweredParameterIndices(
attr->getParameterIndices(),
Optional<AutoDiffConfig>
findMinimalDerivativeConfiguration(AbstractFunctionDecl *original,
IndexSubset *parameterIndices,
IndexSubset *&minimalASTParameterIndices) {
Optional<AutoDiffConfig> minimalConfig = None;
auto configs = original->getDerivativeFunctionConfigurations();
for (auto config : configs) {
auto *silParameterIndices = autodiff::getLoweredParameterIndices(
config.parameterIndices,
original->getInterfaceType()->castTo<AnyFunctionType>());
// If all indices in `parameterIndices` are in `daParameterIndices`, and it
// has fewer indices than our current candidate and a primitive VJP, then
// `attr` is our new candidate.
// If all indices in `parameterIndices` are in `daParameterIndices`, and
// it has fewer indices than our current candidate and a primitive VJP,
// then `attr` is our new candidate.
//
// NOTE(TF-642): `attr` may come from a un-partial-applied function and
// have larger capacity than the desired indices. We expect this logic to
// go away when `partial_apply` supports `@differentiable` callees.
if (attrParameterIndices->isSupersetOf(parameterIndices->extendingCapacity(
original->getASTContext(), attrParameterIndices->getCapacity())) &&
if (silParameterIndices->isSupersetOf(parameterIndices->extendingCapacity(
original->getASTContext(), silParameterIndices->getCapacity())) &&
// fewer parameters than before
(!minimalParameterIndices ||
attrParameterIndices->getNumIndices() <
minimalParameterIndices->getNumIndices())) {
minimalAttr = attr;
minimalParameterIndices = attrParameterIndices;
(!minimalConfig ||
silParameterIndices->getNumIndices() <
minimalConfig->parameterIndices->getNumIndices())) {
minimalASTParameterIndices = config.parameterIndices;
minimalConfig = AutoDiffConfig(silParameterIndices, config.resultIndices,
config.derivativeGenericSignature);
}
}
return minimalAttr;
return minimalConfig;
}

SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness(
Expand All @@ -88,22 +89,14 @@ SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness(
if (!originalAFD)
return nullptr;

IndexSubset *minimalParameterIndices = nullptr;
const auto *minimalAttr = getMinimalASTDifferentiableAttr(
originalAFD, parameterIndices, minimalParameterIndices);

// TODO(TF-835): This will also need to search all `@differentiating`
// attributes after we stop synthesizing `@differentiable` attributes for
// `@differentiating` attributes.

if (!minimalAttr)
IndexSubset *minimalASTParameterIndices = nullptr;
auto minimalConfig = findMinimalDerivativeConfiguration(
originalAFD, parameterIndices, minimalASTParameterIndices);
if (!minimalConfig)
return nullptr;

AutoDiffConfig minimalConfig(minimalParameterIndices, resultIndices,
minimalAttr->getDerivativeGenericSignature());

auto *existingWitness = module.lookUpDifferentiabilityWitness(
{original->getName(), minimalConfig});
{original->getName(), *minimalConfig});
if (existingWitness)
return existingWitness;

Expand All @@ -113,8 +106,8 @@ SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness(

return SILDifferentiabilityWitness::createDeclaration(
module, SILLinkage::PublicExternal, original,
minimalConfig.parameterIndices, minimalConfig.resultIndices,
minimalConfig.derivativeGenericSignature);
minimalConfig->parameterIndices, minimalConfig->resultIndices,
minimalConfig->derivativeGenericSignature);
}

} // end namespace swift
Loading