Skip to content

Commit ceb8ce0

Browse files
committed
[AutoDiff] Serialize derivative function configurations per module.
`@differentiable` and `@derivative` attributes register derivatives for `AbstractFunctionDecl`s for a particular "derivative function configuration": parameter indices and dervative generic signature. To find `@derivative` functions registered in other Swift modules, derivative function configurations must be serialized per module. When configurations for a `AbstractFunctionDecl` are requested, all configurations from imported modules are deserialized. This module serialization technique has precedent: it is used for protocol conformances (e.g. extension declarations for a nominal type) and Obj-C members for a class type. Add `AbstractFunctionDecl::getDerivativeFunctionConfigurations` entry point for accessing derivative function configurations. Use `AbstractFunctionDecl::getDerivativeFunctionConfigurations` to implement `findMinimalDerivativeConfiguration` for canonical derivative function configuration lookup, replacing `getMinimalASTDifferentiableAttr`. Unblocks TF-815: lowering `@derivative` attributes directly to SIL differentiability witnesses without generating implicit `@differentiable` attributes.
1 parent 542c236 commit ceb8ce0

File tree

19 files changed

+531
-72
lines changed

19 files changed

+531
-72
lines changed

include/swift/AST/ASTContext.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,11 @@ namespace swift {
110110
class VarDecl;
111111
class UnifiedStatsReporter;
112112
class IndexSubset;
113+
// SWIFT_ENABLE_TENSORFLOW
114+
struct AutoDiffConfig;
113115
class VectorSpace;
114116
class DifferentiableAttr;
117+
// SWIFT_ENABLE_TENSORFLOW END
115118

116119
enum class KnownProtocolKind : uint8_t;
117120

@@ -702,6 +705,21 @@ class ASTContext final {
702705
unsigned previousGeneration,
703706
llvm::TinyPtrVector<AbstractFunctionDecl *> &methods);
704707

708+
// SWIFT_ENABLE_TENSORFLOW
709+
/// Load derivative function configurations for the given
710+
/// AbstractFunctionDecl.
711+
///
712+
/// \param originalAFD The declaration whose derivative function
713+
/// configurations should be loaded.
714+
///
715+
/// \param previousGeneration The previous generation number. The AST already
716+
/// contains derivative function configurations loaded from any generation up
717+
/// to and including this one.
718+
void loadDerivativeFunctionConfigurations(
719+
AbstractFunctionDecl *originalAFD, unsigned previousGeneration,
720+
llvm::SetVector<AutoDiffConfig> &results);
721+
// SWIFT_ENABLE_TENSORFLOW END
722+
705723
/// Retrieve the Clang module loader for this ASTContext.
706724
///
707725
/// If there is no Clang module loader, returns a null pointer.

include/swift/AST/Decl.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5694,6 +5694,25 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl {
56945694
private:
56955695
ParameterList *Params;
56965696

5697+
// SWIFT_ENABLE_TENSORFLOW
5698+
private:
5699+
/// The generation at which we last loaded derivative function configurations.
5700+
unsigned DerivativeFunctionConfigGeneration = 0;
5701+
/// Prepare to traverse the list of derivative function configurations.
5702+
void prepareDerivativeFunctionConfigurations();
5703+
5704+
/// A uniqued list of derivative function configurations.
5705+
struct DerivativeFunctionConfigurationList;
5706+
DerivativeFunctionConfigurationList *DerivativeFunctionConfigs = nullptr;
5707+
5708+
public:
5709+
/// Get all derivative function configurations.
5710+
ArrayRef<AutoDiffConfig> getDerivativeFunctionConfigurations();
5711+
5712+
/// Add the given derivative function configuration.
5713+
void addDerivativeFunctionConfiguration(AutoDiffConfig config);
5714+
// SWIFT_ENABLE_TENSORFLOW END
5715+
56975716
protected:
56985717
// If a function has a body at all, we have either a parsed body AST node or
56995718
// we have saved the end location of the unparsed body.

include/swift/AST/ModuleLoader.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ class DependencyCollector;
3434

3535
namespace swift {
3636

37+
// SWIFT_ENABLE_TENSORFLOW
38+
struct AutoDiffConfig;
39+
// SWIFT_ENABLE_TENSORFLOW END
3740
class AbstractFunctionDecl;
3841
class ClangImporterOptions;
3942
class ClassDecl;
@@ -151,6 +154,25 @@ class ModuleLoader {
151154
unsigned previousGeneration,
152155
llvm::TinyPtrVector<AbstractFunctionDecl *> &methods) = 0;
153156

157+
// SWIFT_ENABLE_TENSORFLOW
158+
/// Load derivative function configurations for the given
159+
/// AbstractFunctionDecl.
160+
///
161+
/// \param originalAFD The declaration whose derivative function
162+
/// configurations should be loaded.
163+
///
164+
/// \param previousGeneration The previous generation number. The AST already
165+
/// contains derivative function configurations loaded from any generation up
166+
/// to and including this one.
167+
///
168+
/// \param results The result list of derivative function configurations.
169+
/// This list will be extended with any methods found in subsequent
170+
/// generations.
171+
virtual void loadDerivativeFunctionConfigurations(
172+
AbstractFunctionDecl *originalAFD, unsigned previousGeneration,
173+
llvm::SetVector<AutoDiffConfig> &results){};
174+
// SWIFT_ENABLE_TENSORFLOW END
175+
154176
/// Verify all modules loaded by this loader.
155177
virtual void verifyAllModules() { }
156178
};

include/swift/AST/Types.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4359,7 +4359,7 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
43594359
/// is necessary for the differentiation transform to support reabstraction
43604360
/// thunk differentiation because the function argument is opaque and cannot
43614361
/// be differentiated. Instead, the argument is made `@differentiable` and
4362-
/// reabstraction thunk JVP/VJP callers are reponsible for passing a
4362+
/// reabstraction thunk JVP/VJP callers are responsible for passing a
43634363
/// `@differentiable` function.
43644364
/// - TODO(TF-1036): Investigate more efficient reabstraction thunk
43654365
/// derivative approaches. The last argument can simply be a

include/swift/SILOptimizer/Utils/Differentiation/DerivativeLookup.h

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,17 +36,26 @@ getExactDifferentiabilityWitness(SILModule &module, SILFunction *original,
3636
IndexSubset *parameterIndices,
3737
IndexSubset *resultIndices);
3838

39-
/// Finds the "@differentiable" attribute on `original` whose parameter indices
40-
/// are a minimal superset of the specified parameter indices. Returns `nullptr`
41-
/// if no such attribute exists.
39+
/// Finds the derivative configuration (from `@differentiable` and
40+
/// `@derivative` attributes) for `original` whose parameter indices are a
41+
/// minimal superset of the specified AST parameter indices. Returns true if
42+
/// such a configuration is found.
4243
///
4344
/// \param parameterIndices must be lowered to SIL.
44-
/// \param minimalParameterIndices is an output parameter that is set to the SIL
45-
/// indices of the minimal attribute, or to `nullptr` if no attribute exists.
46-
const DifferentiableAttr *
47-
getMinimalASTDifferentiableAttr(AbstractFunctionDecl *original,
48-
IndexSubset *parameterIndices,
49-
IndexSubset *&minimalParameterIndices);
45+
/// \param minimalASTParameterIndices is an output parameter that is set to the
46+
/// AST indices of the minimal configuration, or to `nullptr` if no such
47+
/// configuration exists.
48+
/// \param minimalSILParameterIndices is an output parameter that is set to the
49+
/// SIL indices of the minimal configuration, or to `nullptr` if no such
50+
/// configuration exists.
51+
/// \param derivativeGenericSignature is an output parameter that is set to the
52+
/// derivative generic signature of the minimal configuration, or the `nullptr`
53+
/// if no such configuration exists.
54+
bool findMinimalDerivativeConfiguration(
55+
AbstractFunctionDecl *original, IndexSubset *parameterIndices,
56+
IndexSubset *&minimalASTParameterIndices,
57+
IndexSubset *&minimalSILParameterIndices,
58+
GenericSignature &derivativeGenericSignature);
5059

5160
/// Returns a differentiability witness for `original` whose parameter indices
5261
/// are a minimal superset of the specified parameter indices and whose result

include/swift/Serialization/SerializedModuleLoader.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,12 @@ class SerializedModuleLoaderBase : public ModuleLoader {
166166
unsigned previousGeneration,
167167
llvm::TinyPtrVector<AbstractFunctionDecl *> &methods) override;
168168

169+
// SWIFT_ENABLE_TENSORFLOW
170+
virtual void loadDerivativeFunctionConfigurations(
171+
AbstractFunctionDecl *originalAFD, unsigned previousGeneration,
172+
llvm::SetVector<AutoDiffConfig> &results) override;
173+
// SWIFT_ENABLE_TENSORFLOW END
174+
169175
virtual void verifyAllModules() override;
170176
};
171177

lib/AST/ASTContext.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1611,6 +1611,19 @@ void ASTContext::loadObjCMethods(
16111611
}
16121612
}
16131613

1614+
// SWIFT_ENABLE_TENSORFLOW
1615+
void ASTContext::loadDerivativeFunctionConfigurations(
1616+
AbstractFunctionDecl *originalAFD, unsigned previousGeneration,
1617+
llvm::SetVector<AutoDiffConfig> &results) {
1618+
PrettyStackTraceDecl stackTrace(
1619+
"loading derivative function configurations for", originalAFD);
1620+
for (auto &loader : getImpl().ModuleLoaders) {
1621+
loader->loadDerivativeFunctionConfigurations(originalAFD,
1622+
previousGeneration, results);
1623+
}
1624+
}
1625+
// SWIFT_ENABLE_TENSORFLOW END
1626+
16141627
void ASTContext::verifyAllLoadedModules() const {
16151628
#ifndef NDEBUG
16161629
FrontendStatsTracer tracer(Stats, "verify-all-loaded-modules");

lib/AST/Decl.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6998,6 +6998,49 @@ StringRef AbstractFunctionDecl::getInlinableBodyText(
69986998
return extractInlinableText(getASTContext().SourceMgr, body, scratch);
69996999
}
70007000

7001+
// SWIFT_ENABLE_TENSORFLOW
7002+
/// A uniqued list of derivative function configurations.
7003+
struct AbstractFunctionDecl::DerivativeFunctionConfigurationList
7004+
: public llvm::SetVector<AutoDiffConfig> {
7005+
// Necessary for `ASTContext` allocation.
7006+
void *operator new(
7007+
size_t bytes, ASTContext &ctx,
7008+
unsigned alignment = alignof(DerivativeFunctionConfigurationList)) {
7009+
return ctx.Allocate(bytes, alignment);
7010+
}
7011+
};
7012+
7013+
void AbstractFunctionDecl::prepareDerivativeFunctionConfigurations() {
7014+
if (DerivativeFunctionConfigs)
7015+
return;
7016+
auto &ctx = getASTContext();
7017+
DerivativeFunctionConfigs = new (ctx) DerivativeFunctionConfigurationList();
7018+
// Register an `ASTContext` cleanup calling the list destructor.
7019+
ctx.addCleanup([this]() {
7020+
this->DerivativeFunctionConfigs->~DerivativeFunctionConfigurationList();
7021+
});
7022+
}
7023+
7024+
ArrayRef<AutoDiffConfig>
7025+
AbstractFunctionDecl::getDerivativeFunctionConfigurations() {
7026+
prepareDerivativeFunctionConfigurations();
7027+
auto &ctx = getASTContext();
7028+
if (ctx.getCurrentGeneration() > DerivativeFunctionConfigGeneration) {
7029+
unsigned previousGeneration = DerivativeFunctionConfigGeneration;
7030+
DerivativeFunctionConfigGeneration = ctx.getCurrentGeneration();
7031+
ctx.loadDerivativeFunctionConfigurations(this, previousGeneration,
7032+
*DerivativeFunctionConfigs);
7033+
}
7034+
return DerivativeFunctionConfigs->getArrayRef();
7035+
}
7036+
7037+
void AbstractFunctionDecl::addDerivativeFunctionConfiguration(
7038+
AutoDiffConfig config) {
7039+
prepareDerivativeFunctionConfigurations();
7040+
DerivativeFunctionConfigs->insert(config);
7041+
}
7042+
// SWIFT_ENABLE_TENSORFLOW END
7043+
70017044
FuncDecl *FuncDecl::createImpl(ASTContext &Context,
70027045
SourceLoc StaticLoc,
70037046
StaticSpellingKind StaticSpelling,

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -748,27 +748,29 @@ emitDerivativeFunctionReference(
748748
original, invoker, diag::autodiff_protocol_member_not_differentiable);
749749
return None;
750750
}
751-
// Get the minimal `@differentiable` attribute and parameter index subset.
752-
IndexSubset *minimalParamIndexSet = nullptr;
753-
const auto *minimalAttr = getMinimalASTDifferentiableAttr(
754-
requirementDecl, desiredIndices.parameters, minimalParamIndexSet);
755-
SILAutoDiffIndices minimalIndices(/*source*/ 0, minimalParamIndexSet);
756-
// If minimal `@differentiable` attribute does not exist, then no attribute
757-
// exists with a superset of the desired indices. Produce an error.
758-
if (!minimalAttr) {
751+
// Find the minimal derivative configuration: minimal parameter indices and
752+
// corresponding derivative generic signature. If it does not exist, produce
753+
// an error.
754+
IndexSubset *minimalASTParamIndices = nullptr;
755+
IndexSubset *minimalSILParamIndices = nullptr;
756+
GenericSignature derivativeGenericSignature;
757+
if (!findMinimalDerivativeConfiguration(
758+
requirementDecl, desiredIndices.parameters, minimalASTParamIndices,
759+
minimalSILParamIndices, derivativeGenericSignature)) {
759760
context.emitNondifferentiabilityError(
760761
original, invoker,
761762
diag::autodiff_member_subset_indices_not_differentiable);
762763
return None;
763764
}
765+
SILAutoDiffIndices minimalIndices(/*source*/ 0, minimalSILParamIndices);
764766
// Emit a `witness_method` instruction for the derivative function.
765767
auto originalType = witnessMethod->getType().castTo<SILFunctionType>();
766768
auto assocType = originalType->getAutoDiffDerivativeFunctionType(
767769
minimalIndices.parameters, minimalIndices.source,
768770
kind, context.getTypeConverter(),
769771
LookUpConformanceInModule(builder.getModule().getSwiftModule()));
770772
auto *autoDiffFuncId = AutoDiffDerivativeFunctionIdentifier::get(
771-
kind, minimalAttr->getParameterIndices(), context.getASTContext());
773+
kind, minimalASTParamIndices, context.getASTContext());
772774
auto *ref = builder.createWitnessMethod(
773775
loc, witnessMethod->getLookupType(), witnessMethod->getConformance(),
774776
requirementDeclRef.asAutoDiffDerivativeFunction(autoDiffFuncId),
@@ -792,28 +794,29 @@ emitDerivativeFunctionReference(
792794
original, invoker, diag::autodiff_class_member_not_differentiable);
793795
return None;
794796
}
795-
// Get the minimal `@differentiable` attribute and parameter index subset.
796-
IndexSubset *minimalParamIndexSet = nullptr;
797-
const auto *minimalAttr = getMinimalASTDifferentiableAttr(
798-
methodDecl, desiredIndices.parameters, minimalParamIndexSet);
799-
SILAutoDiffIndices minimalIndices(/*source*/ 0, minimalParamIndexSet);
800-
// If minimal `@differentiable` attribute does not exist, then no attribute
801-
// exists with a superset of the desired indices. Produce an error.
802-
if (!minimalAttr) {
797+
// Find the minimal derivative configuration: minimal parameter indices and
798+
// corresponding derivative generic signature. If it does not exist, produce
799+
// an error.
800+
IndexSubset *minimalASTParamIndices = nullptr;
801+
IndexSubset *minimalSILParamIndices = nullptr;
802+
GenericSignature derivativeGenericSignature;
803+
if (!findMinimalDerivativeConfiguration(
804+
methodDecl, desiredIndices.parameters, minimalASTParamIndices,
805+
minimalSILParamIndices, derivativeGenericSignature)) {
803806
context.emitNondifferentiabilityError(
804807
original, invoker,
805808
diag::autodiff_member_subset_indices_not_differentiable);
806809
return None;
807810
}
811+
SILAutoDiffIndices minimalIndices(/*source*/ 0, minimalSILParamIndices);
808812
// Emit a `class_method` instruction for the derivative function.
809813
auto originalType = classMethodInst->getType().castTo<SILFunctionType>();
810814
auto assocType = originalType->getAutoDiffDerivativeFunctionType(
811-
minimalIndices.parameters, minimalIndices.source,
812-
kind, context.getTypeConverter(),
815+
minimalIndices.parameters, minimalIndices.source, kind,
816+
context.getTypeConverter(),
813817
LookUpConformanceInModule(builder.getModule().getSwiftModule()));
814818
auto *autoDiffFuncId = AutoDiffDerivativeFunctionIdentifier::get(
815-
kind, minimalAttr->getParameterIndices(),
816-
context.getASTContext());
819+
kind, minimalASTParamIndices, context.getASTContext());
817820
auto *ref = builder.createClassMethod(
818821
loc, classMethodInst->getOperand(),
819822
methodDeclRef.asAutoDiffDerivativeFunction(autoDiffFuncId),

lib/SILOptimizer/Utils/Differentiation/DerivativeLookup.cpp

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -45,34 +45,36 @@ getExactDifferentiabilityWitness(SILModule &module, SILFunction *original,
4545
return nullptr;
4646
}
4747

48-
const DifferentiableAttr *
49-
getMinimalASTDifferentiableAttr(AbstractFunctionDecl *original,
50-
IndexSubset *parameterIndices,
51-
IndexSubset *&minimalParameterIndices) {
52-
const DifferentiableAttr *minimalAttr = nullptr;
53-
minimalParameterIndices = nullptr;
54-
for (auto *attr : original->getAttrs().getAttributes<DifferentiableAttr>()) {
55-
auto *attrParameterIndices = autodiff::getLoweredParameterIndices(
56-
attr->getParameterIndices(),
57-
original->getInterfaceType()->castTo<AnyFunctionType>());
58-
// If all indices in `parameterIndices` are in `daParameterIndices`, and it
59-
// has fewer indices than our current candidate and a primitive VJP, then
60-
// `attr` is our new candidate.
48+
bool findMinimalDerivativeConfiguration(
49+
AbstractFunctionDecl *original, IndexSubset *parameterIndices,
50+
IndexSubset *&minimalASTParameterIndices,
51+
IndexSubset *&minimalSILParameterIndices,
52+
GenericSignature &derivativeGenericSignature) {
53+
auto configs = original->getDerivativeFunctionConfigurations();
54+
for (auto config : configs) {
55+
auto *paramIndices = config.parameterIndices;
56+
auto derivativeGenSig = config.derivativeGenericSignature;
57+
auto *silParameterIndices = autodiff::getLoweredParameterIndices(
58+
paramIndices, original->getInterfaceType()->castTo<AnyFunctionType>());
59+
// If all indices in `parameterIndices` are in `daParameterIndices`, and
60+
// it has fewer indices than our current candidate and a primitive VJP,
61+
// then `attr` is our new candidate.
6162
//
6263
// NOTE(TF-642): `attr` may come from a un-partial-applied function and
6364
// have larger capacity than the desired indices. We expect this logic to
6465
// go away when `partial_apply` supports `@differentiable` callees.
65-
if (attrParameterIndices->isSupersetOf(parameterIndices->extendingCapacity(
66-
original->getASTContext(), attrParameterIndices->getCapacity())) &&
66+
if (silParameterIndices->isSupersetOf(parameterIndices->extendingCapacity(
67+
original->getASTContext(), silParameterIndices->getCapacity())) &&
6768
// fewer parameters than before
68-
(!minimalParameterIndices ||
69-
attrParameterIndices->getNumIndices() <
70-
minimalParameterIndices->getNumIndices())) {
71-
minimalAttr = attr;
72-
minimalParameterIndices = attrParameterIndices;
69+
(!minimalSILParameterIndices ||
70+
silParameterIndices->getNumIndices() <
71+
minimalSILParameterIndices->getNumIndices())) {
72+
minimalASTParameterIndices = paramIndices;
73+
minimalSILParameterIndices = silParameterIndices;
74+
derivativeGenericSignature = derivativeGenSig;
7375
}
7476
}
75-
return minimalAttr;
77+
return minimalASTParameterIndices;
7678
}
7779

7880
SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness(
@@ -88,19 +90,17 @@ SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness(
8890
if (!originalAFD)
8991
return nullptr;
9092

91-
IndexSubset *minimalParameterIndices = nullptr;
92-
const auto *minimalAttr = getMinimalASTDifferentiableAttr(
93-
originalAFD, parameterIndices, minimalParameterIndices);
94-
95-
// TODO(TF-835): This will also need to search all `@differentiating`
96-
// attributes after we stop synthesizing `@differentiable` attributes for
97-
// `@differentiating` attributes.
98-
99-
if (!minimalAttr)
93+
IndexSubset *minimalASTParameterIndices = nullptr;
94+
IndexSubset *minimalSILParameterIndices = nullptr;
95+
GenericSignature derivativeGenericSignature;
96+
if (!findMinimalDerivativeConfiguration(
97+
originalAFD, parameterIndices, minimalASTParameterIndices,
98+
minimalSILParameterIndices, derivativeGenericSignature)) {
10099
return nullptr;
100+
}
101101

102-
AutoDiffConfig minimalConfig(minimalParameterIndices, resultIndices,
103-
minimalAttr->getDerivativeGenericSignature());
102+
AutoDiffConfig minimalConfig(minimalSILParameterIndices, resultIndices,
103+
derivativeGenericSignature);
104104

105105
auto *existingWitness = module.lookUpDifferentiabilityWitness(
106106
{original->getName(), minimalConfig});

0 commit comments

Comments
 (0)