[AutoDiff] Implement cross-file lookup of derivatives #58644

Merged
merged 3 commits into from
May 11, 2022
10 changes: 10 additions & 0 deletions include/swift/AST/Attr.h
@@ -1938,6 +1938,10 @@ class DerivativeAttr final
   friend TrailingObjects;
   friend class DerivativeAttrOriginalDeclRequest;

+  /// The declaration on which the `@derivative` attribute is declared.
+  /// May not be a valid declaration for `@derivative` attributes.
+  /// Resolved during parsing and deserialization.
+  Decl *OriginalDeclaration = nullptr;
   /// The base type for the referenced original declaration. This field is
   /// non-null only for parsed attributes that reference a qualified original
   /// declaration. This field is not serialized; type-checking uses it to
@@ -1991,6 +1995,12 @@ class DerivativeAttr final
                                 DeclNameRefWithLoc original,
                                 IndexSubset *parameterIndices);

+  Decl *getOriginalDeclaration() const { return OriginalDeclaration; }
+
+  /// Sets the original declaration on which this attribute is declared.
+  /// Should only be used by parsing and deserialization.
+  void setOriginalDeclaration(Decl *originalDeclaration);
+
   TypeRepr *getBaseTypeRepr() const { return BaseTypeRepr; }
   DeclNameRefWithLoc getOriginalFunctionName() const {
     return OriginalFunctionName;
7 changes: 5 additions & 2 deletions include/swift/AST/Decl.h
@@ -6265,8 +6265,11 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl {
   DerivativeFunctionConfigurationList *DerivativeFunctionConfigs = nullptr;

 public:
-  /// Get all derivative function configurations.
-  ArrayRef<AutoDiffConfig> getDerivativeFunctionConfigurations();
+  /// Get all derivative function configurations. If `lookInNonPrimarySources`
+  /// is true then lookup is done in non-primary sources as well. Note that
+  /// such lookup might end in cycles if done during sema stages.
+  ArrayRef<AutoDiffConfig>
+  getDerivativeFunctionConfigurations(bool lookInNonPrimarySources = true);

   /// Add the given derivative function configuration.
   void addDerivativeFunctionConfiguration(const AutoDiffConfig &config);
7 changes: 7 additions & 0 deletions lib/AST/Attr.cpp
@@ -2107,6 +2107,13 @@ void DerivativeAttr::setOriginalFunctionResolver(
   ResolverContextData = resolverContextData;
 }

+void DerivativeAttr::setOriginalDeclaration(Decl *originalDeclaration) {
+  assert(originalDeclaration && "Original declaration must be non-null");
+  assert(!OriginalDeclaration &&
+         "Original declaration cannot have already been set");
+  OriginalDeclaration = originalDeclaration;
+}
+
 TransposeAttr::TransposeAttr(bool implicit, SourceLoc atLoc,
                              SourceRange baseRange, TypeRepr *baseTypeRepr,
                              DeclNameRefWithLoc originalName,
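The set-once contract that the two assertions in `setOriginalDeclaration` enforce can be tried in isolation. The following is a minimal standalone sketch, not compiler code: `ToyDecl`, `ToyDerivativeAttr`, and `attachTo` are hypothetical stand-ins invented for illustration.

```cpp
#include <cassert>

// Hypothetical stand-in for the compiler's Decl class.
struct ToyDecl {
  const char *name;
};

// Hypothetical stand-in for DerivativeAttr, reduced to the one field at issue.
class ToyDerivativeAttr {
  // The declaration the attribute is attached to; null until set.
  ToyDecl *OriginalDeclaration = nullptr;

public:
  ToyDecl *getOriginalDeclaration() const { return OriginalDeclaration; }

  // Set-once setter: rejects null and rejects a second assignment
  // (in builds where assertions are enabled).
  void setOriginalDeclaration(ToyDecl *originalDeclaration) {
    assert(originalDeclaration && "Original declaration must be non-null");
    assert(!OriginalDeclaration &&
           "Original declaration cannot have already been set");
    OriginalDeclaration = originalDeclaration;
  }
};

// Mimics what a parser or deserializer might do right after creating the
// attribute: attach it to its owning declaration exactly once.
inline ToyDerivativeAttr attachTo(ToyDecl &decl) {
  ToyDerivativeAttr attr;
  attr.setOriginalDeclaration(&decl);
  return attr;
}
```

In this sketch a freshly constructed attribute reports a null declaration, and `attachTo(decl)` returns an attribute whose getter yields `&decl`. Calling the setter a second time would trip the second assertion.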
34 changes: 33 additions & 1 deletion lib/AST/Decl.cpp
@@ -20,6 +20,7 @@
 #include "swift/AST/ASTContext.h"
 #include "swift/AST/ASTWalker.h"
 #include "swift/AST/ASTMangler.h"
+#include "swift/AST/Attr.h"
 #include "swift/AST/CaptureInfo.h"
 #include "swift/AST/DiagnosticEngine.h"
 #include "swift/AST/DiagnosticsSema.h"
@@ -8310,7 +8311,7 @@ void AbstractFunctionDecl::prepareDerivativeFunctionConfigurations() {
 }

 ArrayRef<AutoDiffConfig>
-AbstractFunctionDecl::getDerivativeFunctionConfigurations() {
+AbstractFunctionDecl::getDerivativeFunctionConfigurations(bool lookInNonPrimarySources) {
   prepareDerivativeFunctionConfigurations();

   // Resolve derivative function configurations from `@differentiable`
@@ -8333,6 +8334,37 @@ AbstractFunctionDecl::getDerivativeFunctionConfigurations() {
     ctx.loadDerivativeFunctionConfigurations(this, previousGeneration,
                                              *DerivativeFunctionConfigs);
   }
+
+  class DerivativeFinder : public ASTWalker {
Contributor

Drive-by question: does @derivative attribute finding have an impact on incremental compilation?


When a file is updated, it and all files that depend on it need to be recompiled. Incremental compilation is fast when there are few dependencies between files.

I wonder if @derivative finding can cause incremental compilation to become asymptotically worse. Thinking through an example:

  1. a.swift defines func foo. b.swift exists in the same module, defining a func vjpFoo function with no attributes.
  2. Many files in the same module ask for the derivative of foo, e.g. by calling gradient(of: foo).

Initially, foo has no derivative functions registered via @derivative, so gradient(of: foo) triggers the compiler to automatically generate derivative functions.

  3. func vjpFoo in b.swift is now marked with @derivative(of: foo).

Since all files (a.swift, b.swift, gradient(of: foo) files) are in the same module, the addition of @derivative(of: foo) in b.swift means that foo now has a registered derivative, and all functions that call gradient(of: foo) need to be recompiled†.

Is the † recompilation behavior asymptotically worse than existing incremental compilation? It might not be, if the same amount of work (e.g. recompiling the entire module) was already necessary before correct @derivative attribute finding was implemented.

Contributor

(It may be that the impact of @derivative attribute finding on incremental compilation was introduced prior to this PR.)

Contributor Author

@dan-zheng The complete recompilation is certainly necessary when optimizations are enabled, as derivatives might be inlined and subsequently optimized. So, effectively, we'd need to throw out all the code and start afresh.

I do not like the existing getDerivativeFunctionConfigurations implementation, as it has lots of side effects (even before this PR), though lots of things are computed lazily here and there... Maybe cross-module lookup should be a last-resort thing, as it might be quite expensive on large modules. Maybe we'd just iterate once and explicitly register all explicit derivatives.

Contributor

> The complete recompilation is certainly necessary when optimizations are enabled.

When optimizations are enabled you're building in whole-module mode anyways. Swift will refuse to perform the incremental build if you ask for -wmo as well.

An even simpler failure mode occurs here since this analysis is not going through name lookup. From the incremental build's perspective, there is no link between the derivative and the original.
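
To illustrate the missing link described above, here is a toy model in C++. It is a sketch under a stated assumption, namely that dependency edges are recorded as a side effect of name lookup; it is not Swift's actual dependency tracker, and every name in it (`ToyModule`, `ToyDependencyTracker`, `lookupByName`, `findByWalking`) is hypothetical.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

// Toy model: a module maps file names to the function names they declare.
struct ToyModule {
  std::map<std::string, std::vector<std::string>> declsByFile;
};

// Toy dependency tracker: remembers which files a given file looked names up
// in. An incremental build could use such edges to decide what to recompile.
struct ToyDependencyTracker {
  std::map<std::string, std::set<std::string>> edges;

  void record(const std::string &user, const std::string &provider) {
    assert(!user.empty() && "user file must be named");
    if (user != provider)
      edges[user].insert(provider);
  }
};

// Name lookup: finds the file declaring `name` and records a dependency edge
// from the requesting file to the declaring file.
inline bool lookupByName(const ToyModule &module, ToyDependencyTracker &tracker,
                         const std::string &user, const std::string &name) {
  for (const auto &entry : module.declsByFile)
    for (const auto &decl : entry.second)
      if (decl == name) {
        tracker.record(user, entry.first);
        return true;
      }
  return false;
}

// Direct walk: visits every declaration but never informs the tracker, so
// whatever it finds leaves no dependency edge behind.
inline bool findByWalking(const ToyModule &module, const std::string &name) {
  for (const auto &entry : module.declsByFile)
    for (const auto &decl : entry.second)
      if (decl == name)
        return true;
  return false;
}
```

In this model, finding `vjpFoo` with `findByWalking` succeeds while `tracker.edges` stays empty, whereas `lookupByName` for `foo` on behalf of a third file records an edge to `a.swift`.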

Contributor Author

> When optimizations are enabled you're building in whole-module mode anyways.

I'm confused. Will you please elaborate? The original testcase from #55170 fails regardless of whether -O is provided to swiftc.

+    const AbstractFunctionDecl *AFD;
+  public:
+    DerivativeFinder(const AbstractFunctionDecl *afd) : AFD(afd) {}
+
+    bool walkToDeclPre(Decl *D) override {
+      if (auto *afd = dyn_cast<AbstractFunctionDecl>(D)) {
+        for (auto *derAttr : afd->getAttrs().getAttributes<DerivativeAttr>()) {
+          // Resolve derivative function configurations from `@derivative`
+          // attributes by type-checking them.
+          if (AFD->getName().matchesRef(
+                  derAttr->getOriginalFunctionName().Name.getFullName())) {
+            (void)derAttr->getOriginalFunction(afd->getASTContext());
+            return false;
+          }
+        }
+      }
+
+      return true;
+    }
+  };
+
+  // Load derivative configurations from @derivative attributes defined in
+  // non-primary sources. Note that it might trigger lookup cycles if called
Contributor

This analysis needs to go through name lookup. If that's triggering cycles then we need to figure out why getDerivativeFunctionConfigurations is on the lookup path.

Contributor Author

It cannot, by design. See the full description in #58644 (comment). There is no way you could discover custom derivatives via name lookup, as the derivative's name is not known and the original function does not know anything about its custom derivatives.
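
The argument above can be made concrete with a small standalone C++ sketch. It is a toy model, not the compiler's `DerivativeFinder`: `ToyFunction` and `findCustomDerivatives` are hypothetical names, and the attribute is reduced to a single string naming the original function.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy declaration: a function that may carry a `@derivative(of: ...)`-style
// back-reference to the original function's name.
struct ToyFunction {
  std::string name;
  std::string derivativeOf; // empty when there is no such attribute
};

// In this model the original function stores nothing about its derivatives
// and derivative names are arbitrary, so the only way to find them is to scan
// every declaration and test its back-reference.
inline std::vector<const ToyFunction *>
findCustomDerivatives(const std::vector<ToyFunction> &moduleDecls,
                      const std::string &originalName) {
  assert(!originalName.empty() && "original name must be non-empty");
  std::vector<const ToyFunction *> result;
  for (const ToyFunction &fn : moduleDecls)
    if (fn.derivativeOf == originalName)
      result.push_back(&fn);
  return result;
}
```

Given declarations `foo`, `vjpFoo` (derivative of `foo`), and `bar`, the scan for `foo` returns only `vjpFoo`. Nothing in the name `vjpFoo` itself would have let a name-based query find it.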

+  // from inside Sema stages.
+  if (lookInNonPrimarySources) {
+    DerivativeFinder finder(this);
+    getParent()->walkContext(finder);
+  }
+
   return DerivativeFunctionConfigs->getArrayRef();
 }

2 changes: 2 additions & 0 deletions lib/Parse/ParseDecl.cpp
@@ -4455,6 +4455,8 @@ setOriginalDeclarationForDifferentiableAttributes(DeclAttributes attrs,
                                                   Decl *D) {
   for (auto *attr : attrs.getAttributes<DifferentiableAttr>())
     const_cast<DifferentiableAttr *>(attr)->setOriginalDeclaration(D);
+  for (auto *attr : attrs.getAttributes<DerivativeAttr>())
+    const_cast<DerivativeAttr *>(attr)->setOriginalDeclaration(D);
 }

 /// Parse a single syntactic declaration and return a list of decl
12 changes: 9 additions & 3 deletions lib/Sema/TypeCheckAttr.cpp
@@ -4949,10 +4949,11 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
 /// - Stores the attribute in `ASTContext::DerivativeAttrs`.
 ///
 /// \returns true on error, false on success.
-static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
-                                    DerivativeAttr *attr) {
+static bool typeCheckDerivativeAttr(DerivativeAttr *attr) {
   // Note: Implementation must be idempotent because it may be called multiple
   // times for the same attribute.
+  Decl *D = attr->getOriginalDeclaration();
+  auto &Ctx = D->getASTContext();
   auto &diags = Ctx.Diags;
   // `@derivative` attribute requires experimental differentiable programming
   // to be enabled.
@@ -5365,13 +5366,18 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
 }

 void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
-  if (typeCheckDerivativeAttr(Ctx, D, attr))
+  if (typeCheckDerivativeAttr(attr))
     attr->setInvalid();
 }

 AbstractFunctionDecl *
 DerivativeAttrOriginalDeclRequest::evaluate(Evaluator &evaluator,
                                             DerivativeAttr *attr) const {
+  // Try to resolve the original function.
+  if (attr->isValid() && attr->OriginalFunction.isNull())
+    if (typeCheckDerivativeAttr(attr))
+      attr->setInvalid();
+
   // If the typechecker has resolved the original function, return it.
   if (auto *FD = attr->OriginalFunction.dyn_cast<AbstractFunctionDecl *>())
     return FD;
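The resolve-on-demand pattern added to `DerivativeAttrOriginalDeclRequest::evaluate` can be sketched outside the compiler. The following toy is a sketch under simplifying assumptions: `ToyAttr`, `toyTypeCheck`, and `getOriginal` are hypothetical names, and "type-checking" is reduced to comparing a candidate name with the name written in the attribute.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for an attribute whose original function is resolved
// on demand.
struct ToyAttr {
  std::string originalName;              // name written in the attribute
  const std::string *resolved = nullptr; // set once resolution succeeds
  bool invalid = false;                  // set once resolution fails
  int typeCheckRuns = 0;                 // counts how often the checker ran
};

// Toy type checker: returns true on error, false on success (the same
// convention as in the diff).
inline bool toyTypeCheck(ToyAttr &attr, const std::string *candidate) {
  assert(!attr.originalName.empty() && "attribute must name an original");
  ++attr.typeCheckRuns;
  if (!candidate || *candidate != attr.originalName)
    return true;
  attr.resolved = candidate;
  return false;
}

// Lazy getter: type-checks only while the attribute is still valid and
// unresolved, then returns whatever has been resolved (possibly null).
inline const std::string *getOriginal(ToyAttr &attr,
                                      const std::string *candidate) {
  if (!attr.invalid && !attr.resolved)
    if (toyTypeCheck(attr, candidate))
      attr.invalid = true;
  return attr.resolved;
}
```

Calling `getOriginal` repeatedly runs the toy checker at most once per attribute: a success caches the result, and a failure marks the attribute invalid so later calls return null without re-checking. This mirrors the guard on `isValid()` and `OriginalFunction.isNull()` in the hunk above.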
3 changes: 2 additions & 1 deletion lib/Sema/TypeCheckProtocol.cpp
@@ -379,7 +379,8 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
     bool foundExactConfig = false;
     Optional<AutoDiffConfig> supersetConfig = None;
     for (auto witnessConfig :
-         witnessAFD->getDerivativeFunctionConfigurations()) {
+         witnessAFD->getDerivativeFunctionConfigurations(
+             /*lookInNonPrimarySources*/ false)) {
       // All the witness's derivative generic requirements must be satisfied
       // by the requirement's derivative generic requirements OR by the
       // conditional conformance requirements.
5 changes: 5 additions & 0 deletions lib/Serialization/Deserialization.cpp
@@ -15,6 +15,7 @@
 #include "ModuleFile.h"
 #include "ModuleFormat.h"
 #include "swift/AST/ASTContext.h"
+#include "swift/AST/Attr.h"
 #include "swift/AST/AutoDiff.h"
 #include "swift/AST/DiagnosticsSema.h"
 #include "swift/AST/Expr.h"
@@ -2590,6 +2591,10 @@ static void setOriginalDeclarationAndParameterIndicesInDifferentiableAttributes(
     diffAttr->setOriginalDeclaration(decl);
     diffAttr->setParameterIndices(diffAttrParamIndicesMap[diffAttr]);
   }
+  for (auto *attr : tempAttrs.getAttributes<DerivativeAttr>()) {
+    auto *derAttr = const_cast<DerivativeAttr *>(attr);
+    derAttr->setOriginalDeclaration(decl);
+  }
 }

 Decl *ModuleFile::getDecl(DeclID DID) {
2 changes: 1 addition & 1 deletion lib/Serialization/Serialization.cpp
@@ -2780,7 +2780,7 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
       auto abbrCode = S.DeclTypeAbbrCodes[DerivativeDeclAttrLayout::Code];
       auto *attr = cast<DerivativeAttr>(DA);
       auto &ctx = S.getASTContext();
-      assert(attr->getOriginalFunction(ctx) &&
+      assert(attr->getOriginalFunction(ctx) && attr->getOriginalDeclaration() &&
              "`@derivative` attribute should have original declaration set "
              "during construction or parsing");
       auto origDeclNameRef = attr->getOriginalFunctionName();
@@ -13,14 +13,11 @@ func crossFileDifferentiableAttr<T: Protocol>(
 }

 // TF-1272: Test original function with registered derivatives in other files.
-// FIXME(TF-1272): Find a way to type-check `@derivative` attributes in other
-// files.
 @differentiable(reverse)
 func crossFileDerivativeAttr<T: Protocol>(
   _ input: T
 ) -> T {
-  // expected-error @+2 {{expression is not differentiable}}
-  // expected-note @+1 {{cannot differentiate functions that have not been marked '@differentiable' and that are defined in other files}}
+  // No error expected
   return input.identityDerivativeAttr()
 }
