Skip to content

[AutoDiff] Directly SILGen @derivative attributes to diff witnesses. #28621

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Dec 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions include/swift/AST/ASTContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,10 @@ namespace swift {
class IndexSubset;
// SWIFT_ENABLE_TENSORFLOW
struct AutoDiffConfig;
class VectorSpace;
struct AutoDiffDerivativeFunctionKind;
class DerivativeAttr;
class DifferentiableAttr;
class VectorSpace;
// SWIFT_ENABLE_TENSORFLOW END

enum class KnownProtocolKind : uint8_t;
Expand Down Expand Up @@ -290,11 +292,26 @@ class ASTContext final {
/// Cache of autodiff-associated vector spaces.
llvm::DenseMap<Type, Optional<VectorSpace>> AutoDiffVectorSpaces;

/// Cache of `@differentiable` attributes keyed by parameter indices. This
/// helps us diagnose multiple `@differentiable`s that are with respect to the
/// same set of parameters.
/// Cache of `@differentiable` attributes keyed by parameter indices. Used to
/// diagnose duplicate `@differentiable` attributes for the same key.
// NOTE(TF-680): relaxing the uniqueness condition to use derivative generic
// signature as a key is possible. It requires derivative generic signature
// mangling to avoid name collisions for SIL derivative functions with the
// same parameter indices but different derivative generic signatures.
llvm::DenseMap<std::pair<Decl *, IndexSubset *>, DifferentiableAttr *>
DifferentiableAttrs;

/// Cache of `@derivative` attributes keyed by parameter indices and
/// derivative function kind. Used to diagnose duplicate `@derivative`
/// attributes for the same key.
// NOTE(TF-680): relaxing the uniqueness condition to use derivative generic
// signature as a key is possible. It requires derivative generic signature
// mangling to avoid name collisions for SIL derivative functions with the
// same parameter indices but different derivative generic signatures.
llvm::DenseMap<
std::tuple<Decl *, IndexSubset *, AutoDiffDerivativeFunctionKind>,
DerivativeAttr *>
DerivativeAttrs;
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be possible to use the derivative config list that you created in the previous serialization PR, so that we do not have to maintain another list?

Could the derivative config list actually also replace DifferentiableAttrs?

Copy link
Contributor Author

@dan-zheng dan-zheng Dec 6, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wondered the same. The derivative function configuration doesn't currently store whether the configuration came from (a @differentiable attribute or @derivative attribute), and that information is important for users of DifferentiableAttrs and DerivativeAttrs, which check duplicate attributes of a specific kind, not just configurations.

I haven't thought deeply about this. I'll file an issue tracking this question if it isn't resolved by the time this PR is merged.


Edit: filed TF-1042 to track Investigate removing/moving ASTContext::{Differentiable,Derivative}Attrs.

Using AbstractFunctionDecl::getDerivativeFunctionConfigurations to detect duplicate @differentiable and @derivative attributes may be significant for cross-file duplicate derivative registration (TF-1021).

If you import a derivative for func foo, you shouldn't be able to register a new derivative for func foo with the same configuration.

This requires changing AbstractFunctionDecl::getDerivativeFunctionConfigurations to return more information than ArrayRef<AutoDiffConfig>.

  • Minimally, it needs to return an OptionSet per AutoDiffConfig, specifying where the AutoDiffConfig came from:
    • @differentiable attribute
    • @derivative JVP
    • @derivative VJP
    • Any combination of the above (three bits)
  • For good "duplicate attribute" diagnostics, it also needs to store sth from which we can get an @differentiable/@derivative attribute SourceLoc.
dup.swift:1:2: error: duplicate '@differentiable' attribute with same parameters
@differentiable
~^~~~~~~~~~~~~~
dup.swift:2:2: note: other attribute declared here << need SourceLoc to generate this note
@differentiable
 ^

// SWIFT_ENABLE_TENSORFLOW END

private:
Expand Down
8 changes: 4 additions & 4 deletions include/swift/AST/Attr.def
Original file line number Diff line number Diff line change
Expand Up @@ -513,8 +513,8 @@ DECL_ATTR(differentiable, Differentiable,
91)
DECL_ATTR(derivative, Derivative,
OnFunc | LongAttribute | AllowMultipleAttributes |
ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove |
NotSerialized, 92)
ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove,
92)
SIMPLE_DECL_ATTR(compilerEvaluable, CompilerEvaluable,
OnAccessor | OnFunc | OnConstructor | OnSubscript |
ABIStableToAdd | ABIStableToRemove | APIStableToAdd | APIStableToRemove |
Expand Down Expand Up @@ -542,8 +542,8 @@ DECL_ATTR(quoted, Quoted,
// TODO(TF-999): Remove deprecated `@differentiating` attribute.
DECL_ATTR(differentiating, Differentiating,
OnFunc | LongAttribute | AllowMultipleAttributes |
ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove |
NotSerialized, 98)
ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove,
98)
// SWIFT_ENABLE_TENSORFLOW END

#undef TYPE_ATTR
Expand Down
17 changes: 13 additions & 4 deletions include/swift/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1841,7 +1841,7 @@ class DifferentiableAttr final

explicit DifferentiableAttr(Decl *original, bool implicit, SourceLoc atLoc,
SourceRange baseRange, bool linear,
IndexSubset *indices,
IndexSubset *parameterIndices,
Optional<DeclNameWithLoc> jvp,
Optional<DeclNameWithLoc> vjp,
GenericSignature derivativeGenericSignature);
Expand All @@ -1855,9 +1855,10 @@ class DifferentiableAttr final
Optional<DeclNameWithLoc> vjp,
TrailingWhereClause *clause);

static DifferentiableAttr *create(Decl *original, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
bool linear, IndexSubset *indices,
static DifferentiableAttr *create(AbstractFunctionDecl *original,
bool implicit, SourceLoc atLoc,
SourceRange baseRange, bool linear,
IndexSubset *parameterIndices,
Optional<DeclNameWithLoc> jvp,
Optional<DeclNameWithLoc> vjp,
GenericSignature derivativeGenSig);
Expand Down Expand Up @@ -1947,6 +1948,8 @@ class DerivativeAttr final
unsigned NumParsedParameters = 0;
/// The differentiation parameters' indices, resolved by the type checker.
IndexSubset *ParameterIndices = nullptr;
/// The derivative function kind (JVP or VJP), resolved by the type checker.
Optional<AutoDiffDerivativeFunctionKind> Kind = None;

explicit DerivativeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
DeclNameWithLoc original,
Expand Down Expand Up @@ -1975,6 +1978,12 @@ class DerivativeAttr final
OriginalFunction = decl;
}

AutoDiffDerivativeFunctionKind getDerivativeKind() const {
assert(Kind && "Derivative function kind has not yet been resolved");
return *Kind;
}
void setDerivativeKind(AutoDiffDerivativeFunctionKind kind) { Kind = kind; }

/// The parsed differentiation parameters, i.e. the list of parameters
/// specified in 'wrt:'.
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
Expand Down
22 changes: 22 additions & 0 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,7 @@ struct AutoDiffConfig {
class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode {
const AutoDiffDerivativeFunctionKind kind;
IndexSubset *const parameterIndices;
// TODO(TF-680): Mangle derivative generic signature requirements as well.

AutoDiffDerivativeFunctionIdentifier(
AutoDiffDerivativeFunctionKind kind, IndexSubset *parameterIndices) :
Expand Down Expand Up @@ -508,6 +509,27 @@ template<> struct DenseMapInfo<AutoDiffConfig> {
}
};

template<> struct DenseMapInfo<AutoDiffDerivativeFunctionKind> {
static AutoDiffDerivativeFunctionKind getEmptyKey() {
return static_cast<AutoDiffDerivativeFunctionKind::innerty>(
DenseMapInfo<unsigned>::getEmptyKey());
}

static AutoDiffDerivativeFunctionKind getTombstoneKey() {
return static_cast<AutoDiffDerivativeFunctionKind::innerty>(
DenseMapInfo<unsigned>::getTombstoneKey());
}

static unsigned getHashValue(const AutoDiffDerivativeFunctionKind &Val) {
return DenseMapInfo<unsigned>::getHashValue(Val);
}

static bool isEqual(const AutoDiffDerivativeFunctionKind &LHS,
const AutoDiffDerivativeFunctionKind &RHS) {
return LHS == RHS;
}
};

template<> struct DenseMapInfo<SILAutoDiffIndices> {
static SILAutoDiffIndices getEmptyKey() {
return { DenseMapInfo<unsigned>::getEmptyKey(), nullptr };
Expand Down
20 changes: 14 additions & 6 deletions lib/AST/Attr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -948,8 +948,8 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
Printer.printAttrName("@derivative");
Printer << "(of: ";
auto *attr = cast<DerivativeAttr>(this);
auto *derivative = cast<AbstractFunctionDecl>(D);
Printer << attr->getOriginalFunctionName().Name;
auto *derivative = cast<AbstractFunctionDecl>(D);
auto diffParamsString = getDifferentiationParametersClauseString(
derivative, attr->getParameterIndices(), attr->getParsedParameters());
if (!diffParamsString.empty())
Expand All @@ -963,8 +963,8 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
Printer.printAttrName("@transpose");
Printer << '(';
auto *attr = cast<TransposeAttr>(this);
auto *transpose = cast<AbstractFunctionDecl>(D);
Printer << attr->getOriginalFunctionName().Name;
auto *transpose = cast<AbstractFunctionDecl>(D);
auto transParamsString = getTransposedParametersClauseString(
transpose, attr->getParameterIndices(), attr->getParsedParameters());
if (!transParamsString.empty())
Expand Down Expand Up @@ -1492,16 +1492,24 @@ DifferentiableAttr::create(ASTContext &context, bool implicit,
}

DifferentiableAttr *
DifferentiableAttr::create(Decl *original, bool implicit, SourceLoc atLoc,
SourceRange baseRange, bool linear,
IndexSubset *indices, Optional<DeclNameWithLoc> jvp,
DifferentiableAttr::create(AbstractFunctionDecl *original, bool implicit,
SourceLoc atLoc, SourceRange baseRange, bool linear,
IndexSubset *parameterIndices,
Optional<DeclNameWithLoc> jvp,
Optional<DeclNameWithLoc> vjp,
GenericSignature derivativeGenSig) {
auto &ctx = original->getASTContext();
void *mem = ctx.Allocate(sizeof(DifferentiableAttr),
alignof(DifferentiableAttr));
// Register derivative function configuration for the given original
// declaration.
// NOTE(TF-1038): `@differentiable` attributes currently always have
// effective result indices `{0}` (the first and only result index).
auto *resultIndices = IndexSubset::get(ctx, 1, {0});
original->addDerivativeFunctionConfiguration(
{parameterIndices, resultIndices, derivativeGenSig});
return new (mem) DifferentiableAttr(original, implicit, atLoc, baseRange,
linear, indices, std::move(jvp),
linear, parameterIndices, std::move(jvp),
std::move(vjp), derivativeGenSig);
}

Expand Down
55 changes: 42 additions & 13 deletions lib/SILGen/SILGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,26 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
diffAttr->getDerivativeGenericSignature());
emitDifferentiabilityWitness(AFD, F, config, jvp, vjp, diffAttr);
}
for (auto *derivAttr : Attrs.getAttributes<DerivativeAttr>()) {
SILFunction *jvp = nullptr;
SILFunction *vjp = nullptr;
switch (derivAttr->getDerivativeKind()) {
case AutoDiffDerivativeFunctionKind::JVP:
jvp = F;
break;
case AutoDiffDerivativeFunctionKind::VJP:
vjp = F;
break;
}
auto *origAFD = derivAttr->getOriginalFunction();
auto *origFn = getFunction(SILDeclRef(origAFD), NotForDefinition);
auto derivativeGenSig = AFD->getGenericSignature();
auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0});
AutoDiffConfig config(derivAttr->getParameterIndices(), resultIndices,
derivativeGenSig);
emitDifferentiabilityWitness(origAFD, origFn, config, jvp, vjp,
derivAttr);
}
};
if (auto *accessor = dyn_cast<AccessorDecl>(AFD))
if (accessor->isGetter())
Expand All @@ -790,21 +810,22 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
void SILGenModule::emitDifferentiabilityWitness(
AbstractFunctionDecl *originalAFD, SILFunction *originalFunction,
const AutoDiffConfig &config, SILFunction *jvp, SILFunction *vjp,
const DeclAttribute *diffAttr) {
const DeclAttribute *attr) {
assert(isa<DifferentiableAttr>(attr) || isa<DerivativeAttr>(attr));
auto *origFnType = originalAFD->getInterfaceType()->castTo<AnyFunctionType>();
auto origSilFnType = originalFunction->getLoweredFunctionType();
auto *loweredParamIndices = autodiff::getLoweredParameterIndices(
config.parameterIndices, origFnType);
auto *silParamIndices =
autodiff::getLoweredParameterIndices(config.parameterIndices, origFnType);
// NOTE(TF-893): Extending capacity is necessary when `origSilFnType` has
// parameters corresponding to captured variables. These parameters do not
// appear in the type of `origFnType`.
// TODO: If posssible, change `autodiff::getLoweredParameterIndices` to
// take `CaptureInfo` into account.
if (origSilFnType->getNumParameters() > loweredParamIndices->getCapacity())
loweredParamIndices = loweredParamIndices->extendingCapacity(
if (origSilFnType->getNumParameters() > silParamIndices->getCapacity())
silParamIndices = silParamIndices->extendingCapacity(
getASTContext(), origSilFnType->getNumParameters());
// TODO(TF-913): Replace usages of `SILAutoDiffIndices` with `AutoDiffConfig`.
SILAutoDiffIndices indices(/*source*/ 0, loweredParamIndices);
SILAutoDiffIndices indices(/*source*/ 0, silParamIndices);

// Self reordering thunk is necessary if wrt at least two parameters,
// including self.
Expand All @@ -818,14 +839,22 @@ void SILGenModule::emitDifferentiabilityWitness(
};
bool reorderSelf = shouldReorderSelf();

// Create new SIL differentiability witness.
// Get or create new SIL differentiability witness.
// Witness already exists when there are two `@derivative` attributes (JVP and
// VJP) for the same derivative function configuration.
// Witness JVP and VJP are set below.
auto *diffWitness = SILDifferentiabilityWitness::createDefinition(
M, originalFunction->getLinkage(), originalFunction, loweredParamIndices,
config.resultIndices, config.derivativeGenericSignature,
/*jvp*/ nullptr, /*vjp*/ nullptr,
/*isSerialized*/ hasPublicVisibility(originalFunction->getLinkage()),
diffAttr);
AutoDiffConfig silConfig(silParamIndices, config.resultIndices,
config.derivativeGenericSignature);
SILDifferentiabilityWitnessKey key{originalFunction->getName(), silConfig};
auto *diffWitness = M.lookUpDifferentiabilityWitness(key);
if (!diffWitness) {
diffWitness = SILDifferentiabilityWitness::createDefinition(
M, originalFunction->getLinkage(), originalFunction,
silConfig.parameterIndices, silConfig.resultIndices,
config.derivativeGenericSignature, /*jvp*/ nullptr, /*vjp*/ nullptr,
/*isSerialized*/ hasPublicVisibility(originalFunction->getLinkage()),
attr);
}

// Set derivative function in differentiability witness.
auto setDerivativeInDifferentiabilityWitness =
Expand Down
6 changes: 2 additions & 4 deletions lib/Sema/DerivedConformanceDifferentiable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,6 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
auto memberAssocContextualType =
parentDC->mapTypeIntoContext(memberAssocInterfaceType);
newMember->setInterfaceType(memberAssocInterfaceType);
// newMember->setType(memberAssocContextualType);
Pattern *memberPattern =
new (C) NamedPattern(newMember, /*implicit*/ true);
memberPattern->setType(memberAssocContextualType);
Expand Down Expand Up @@ -623,10 +622,9 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
derivativeGenSig = extDecl->getGenericSignature();
auto *diffableAttr = DifferentiableAttr::create(
getter, /*implicit*/ true, SourceLoc(), SourceLoc(),
/*linear*/ false, {}, None, None, derivativeGenSig);
/*linear*/ false, /*parameterIndices*/ IndexSubset::get(C, 1, {0}),
/*jvp*/ None, /*vjp*/ None, derivativeGenSig);
member->getAttrs().add(diffableAttr);
// Set getter `@differentiable` attribute parameter indices.
diffableAttr->setParameterIndices(IndexSubset::get(C, 1, {0}));
}
}

Expand Down
Loading