Skip to content

[AutoDiff upstream] Add SIL differentiability witnesses. #29623

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions docs/SIL.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1316,6 +1316,66 @@ variable cannot be used as l-value, i.e. the reference to the object cannot be
modified. As a consequence the variable cannot be accessed with ``global_addr``
but only with ``global_value``.

Differentiability Witnesses
~~~~~~~~~~~~~~~~~~~~~~~~~~~
::

decl ::= sil-differentiability-witness
sil-differentiability-witness ::=
'sil_differentiability_witness'
sil-linkage?
'[' 'parameters' sil-differentiability-witness-function-index-list ']'
'[' 'results' sil-differentiability-witness-function-index-list ']'
generic-parameter-clause?
sil-function-name ':' sil-type
sil-differentiability-witness-body?

sil-differentiability-witness-body ::=
'{' sil-differentiability-witness-entry?
sil-differentiability-witness-entry? '}'

sil-differentiability-witness-entry ::=
sil-differentiability-witness-entry-kind ':'
sil-entry-name ':' sil-type

sil-differentiability-witness-entry-kind ::= 'jvp' | 'vjp'

SIL encodes function differentiability via differentiability witnesses.

Differentiability witnesses map a "key" (including an "original" SIL function)
to derivative SIL functions.

Differentiability witnesses are keyed by the following:

- An "original" SIL function name.
- Differentiability parameter indices.
- Differentiability result indices.
- A generic parameter clause, representing differentiability generic
requirements.

Differentiability witnesses may have a body, specifying derivative functions for
the key. Verification checks that derivative functions have the expected type
based on the key.

::

sil_differentiability_witness hidden [parameters 0] [results 0] <T where T : Differentiable> @id : $@convention(thin) (T) -> T {
jvp: @id_jvp : $@convention(thin) (T) -> (T, @owned @callee_guaranteed (T.TangentVector) -> T.TangentVector)
vjp: @id_vjp : $@convention(thin) (T) -> (T, @owned @callee_guaranteed (T.TangentVector) -> T.TangentVector)
}

During SILGen, differentiability witnesses are emitted for the following:

- `@differentiable` declaration attributes.
- `@derivative` declaration attributes. Registered derivative functions
become differentiability witness entries.

The SIL differentiation transform canonicalizes differentiability witnesses,
filling in missing entries.

Differentiability witness entries are accessed via the
`differentiability_witness_function` instruction.

Dataflow Errors
---------------

Expand Down
10 changes: 9 additions & 1 deletion include/swift/AST/ASTMangler.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,15 @@ class ASTMangler : public Mangler {
Type FromType, Type ToType,
Type SelfType,
ModuleDecl *Module);


/// Mangle a SIL differentiability witness key:
/// - Mangled original function name.
/// - Parameter indices.
/// - Result indices.
/// - Derivative generic signature (optional).
std::string
mangleSILDifferentiabilityWitnessKey(SILDifferentiabilityWitnessKey key);

std::string mangleKeyPathGetterThunkHelper(const AbstractStorageDecl *property,
GenericSignature signature,
CanType baseType,
Expand Down
5 changes: 5 additions & 0 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,11 @@ class TangentSpace {
NominalTypeDecl *getNominal() const;
};

/// The key type used for uniquing `SILDifferentiabilityWitness` in
/// `SILModule`: original function name, parameter indices, result indices, and
/// derivative generic signature.
using SILDifferentiabilityWitnessKey = std::pair<StringRef, AutoDiffConfig>;

/// Automatic differentiation utility namespace.
namespace autodiff {

Expand Down
26 changes: 26 additions & 0 deletions include/swift/AST/DiagnosticsParse.def
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,18 @@ ERROR(sil_witness_assoc_conf_not_found,none,
ERROR(sil_witness_protocol_conformance_not_found,none,
"sil protocol conformance not found", ())

// SIL differentiability witnesses
ERROR(sil_diff_witness_expected_token,PointsToFirstBadToken,
"expected '%0' in differentiability witness", (StringRef))
ERROR(sil_diff_witness_serialized_declaration,none,
"differentiability witness declaration should not be serialized", ())
ERROR(sil_diff_witness_undefined,PointsToFirstBadToken,
"reference to undefined differentiability witness", ())
ERROR(sil_diff_witness_invalid_generic_signature,PointsToFirstBadToken,
"expected witness generic signature '%0' does not have same generic "
"parameters as original function generic signature '%1'",
(StringRef, StringRef))

// SIL Coverage Map
ERROR(sil_coverage_invalid_hash, none,
"expected coverage hash", ())
Expand Down Expand Up @@ -1577,6 +1589,20 @@ ERROR(diff_params_clause_expected_parameter_unnamed,PointsToFirstBadToken,
ERROR(autodiff_attr_expected_original_decl_name,PointsToFirstBadToken,
"expected an original function name", ())

// SIL autodiff
ERROR(sil_autodiff_expected_lsquare,PointsToFirstBadToken,
"expected '[' to start the %0", (StringRef))
ERROR(sil_autodiff_expected_rsquare,PointsToFirstBadToken,
"expected ']' to complete the %0", (StringRef))
ERROR(sil_autodiff_expected_index_list,PointsToFirstBadToken,
"expected a space-separated list of indices, e.g. '0 1'", ())
ERROR(sil_autodiff_expected_index_list_label,PointsToFirstBadToken,
"expected label '%0' in index list", (StringRef))
ERROR(sil_autodiff_expected_parameter_index,PointsToFirstBadToken,
"expected the index of a parameter to differentiate with respect to", ())
ERROR(sil_autodiff_expected_result_index,PointsToFirstBadToken,
"expected the index of a result to differentiate from", ())

//------------------------------------------------------------------------------
// MARK: Generics parsing diagnostics
//------------------------------------------------------------------------------
Expand Down
1 change: 1 addition & 0 deletions include/swift/Parse/ParseSILSupport.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ namespace swift {
virtual bool parseSILGlobal(Parser &P) = 0;
virtual bool parseSILWitnessTable(Parser &P) = 0;
virtual bool parseSILDefaultWitnessTable(Parser &P) = 0;
virtual bool parseSILDifferentiabilityWitness(Parser &P) = 0;
virtual bool parseSILCoverageMap(Parser &P) = 0;
virtual bool parseSILProperty(Parser &P) = 0;
virtual bool parseSILScope(Parser &P) = 0;
Expand Down
166 changes: 166 additions & 0 deletions include/swift/SIL/SILDifferentiabilityWitness.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
//===--- SILDifferentiabilityWitness.h - Differentiability witnesses ------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// This file defines the SILDifferentiabilityWitness class, which maps an
// original SILFunction and derivative configuration (parameter indices, result
// indices, derivative generic signature) to derivative functions (JVP and VJP).
//
// SIL differentiability witnesses are generated from the `@differentiable`
// and `@derivative` AST declaration attributes.
//
// Differentiability witnesses are canonicalized by the SIL differentiation
// transform, which fills in missing derivative functions.
//
//===----------------------------------------------------------------------===//

#ifndef SWIFT_SIL_SILDIFFERENTIABILITYWITNESS_H
#define SWIFT_SIL_SILDIFFERENTIABILITYWITNESS_H

#include "swift/AST/Attr.h"
#include "swift/AST/AutoDiff.h"
#include "swift/AST/GenericSignature.h"
#include "swift/SIL/SILAllocated.h"
#include "swift/SIL/SILLinkage.h"
#include "llvm/ADT/ilist.h"
#include "llvm/ADT/ilist_node.h"

namespace swift {

class SILPrintContext;

class SILDifferentiabilityWitness
: public llvm::ilist_node<SILDifferentiabilityWitness>,
public SILAllocated<SILDifferentiabilityWitness> {
private:
/// The module which contains the differentiability witness.
SILModule &Module;
/// The linkage of the differentiability witness.
SILLinkage Linkage;
/// The original function.
SILFunction *OriginalFunction;
/// The derivative configuration: parameter indices, result indices, and
/// derivative generic signature (optional). The derivative generic signature
/// may contain same-type requirements such that all generic parameters are
/// bound to concrete types.
AutoDiffConfig Config;
/// The JVP (Jacobian-vector products) derivative function.
SILFunction *JVP;
/// The VJP (vector-Jacobian products) derivative function.
SILFunction *VJP;
/// Whether or not this differentiability witness is a declaration.
bool IsDeclaration;
/// Whether or not this differentiability witness is serialized, which allows
/// devirtualization from another module.
bool IsSerialized;
/// The AST `@differentiable` or `@derivative` attribute from which the
/// differentiability witness is generated. Used for diagnostics.
/// Null if the differentiability witness is parsed from SIL or if it is
/// deserialized.
const DeclAttribute *Attribute = nullptr;

SILDifferentiabilityWitness(
SILModule &module, SILLinkage linkage, SILFunction *originalFunction,
IndexSubset *parameterIndices, IndexSubset *resultIndices,
GenericSignature derivativeGenSig, SILFunction *jvp, SILFunction *vjp,
bool isDeclaration, bool isSerialized, const DeclAttribute *attribute)
: Module(module), Linkage(linkage), OriginalFunction(originalFunction),
Config(parameterIndices, resultIndices, derivativeGenSig.getPointer()),
JVP(jvp), VJP(vjp), IsDeclaration(isDeclaration),
IsSerialized(isSerialized), Attribute(attribute) {}

public:
static SILDifferentiabilityWitness *
createDeclaration(SILModule &module, SILLinkage linkage,
SILFunction *originalFunction,
IndexSubset *parameterIndices, IndexSubset *resultIndices,
GenericSignature derivativeGenSig,
const DeclAttribute *attribute = nullptr);

static SILDifferentiabilityWitness *createDefinition(
SILModule &module, SILLinkage linkage, SILFunction *originalFunction,
IndexSubset *parameterIndices, IndexSubset *resultIndices,
GenericSignature derivativeGenSig, SILFunction *jvp, SILFunction *vjp,
bool isSerialized, const DeclAttribute *attribute = nullptr);

void convertToDefinition(SILFunction *jvp, SILFunction *vjp,
bool isSerialized);

SILDifferentiabilityWitnessKey getKey() const;
SILModule &getModule() const { return Module; }
SILLinkage getLinkage() const { return Linkage; }
SILFunction *getOriginalFunction() const { return OriginalFunction; }
const AutoDiffConfig &getConfig() const { return Config; }
IndexSubset *getParameterIndices() const { return Config.parameterIndices; }
IndexSubset *getResultIndices() const { return Config.resultIndices; }
GenericSignature getDerivativeGenericSignature() const {
return Config.derivativeGenericSignature;
}
SILFunction *getJVP() const { return JVP; }
SILFunction *getVJP() const { return VJP; }
SILFunction *getDerivative(AutoDiffDerivativeFunctionKind kind) const {
switch (kind) {
case AutoDiffDerivativeFunctionKind::JVP:
return JVP;
case AutoDiffDerivativeFunctionKind::VJP:
return VJP;
}
}
void setJVP(SILFunction *jvp) { JVP = jvp; }
void setVJP(SILFunction *vjp) { VJP = vjp; }
void setDerivative(AutoDiffDerivativeFunctionKind kind,
SILFunction *derivative) {
switch (kind) {
case AutoDiffDerivativeFunctionKind::JVP:
JVP = derivative;
break;
case AutoDiffDerivativeFunctionKind::VJP:
VJP = derivative;
break;
}
}
bool isDeclaration() const { return IsDeclaration; }
bool isDefinition() const { return !IsDeclaration; }
bool isSerialized() const { return IsSerialized; }
const DeclAttribute *getAttribute() const { return Attribute; }

/// Verify that the differentiability witness is well-formed.
void verify(const SILModule &module) const;

void print(llvm::raw_ostream &os, bool verbose = false) const;
void dump() const;
};

} // end namespace swift

namespace llvm {

//===----------------------------------------------------------------------===//
// ilist_traits for SILDifferentiabilityWitness
//===----------------------------------------------------------------------===//

template <>
struct ilist_traits<::swift::SILDifferentiabilityWitness>
: public ilist_node_traits<::swift::SILDifferentiabilityWitness> {
using SILDifferentiabilityWitness = ::swift::SILDifferentiabilityWitness;

public:
static void deleteNode(SILDifferentiabilityWitness *DW) {
DW->~SILDifferentiabilityWitness();
}

private:
void createNode(const SILDifferentiabilityWitness &);
};

} // namespace llvm

#endif // SWIFT_SIL_SILDIFFERENTIABILITYWITNESS_H
44 changes: 44 additions & 0 deletions include/swift/SIL/SILModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "swift/SIL/SILCoverageMap.h"
#include "swift/SIL/SILDeclRef.h"
#include "swift/SIL/SILDefaultWitnessTable.h"
#include "swift/SIL/SILDifferentiabilityWitness.h"
#include "swift/SIL/SILFunction.h"
#include "swift/SIL/SILGlobalVariable.h"
#include "swift/SIL/SILPrintContext.h"
Expand Down Expand Up @@ -113,6 +114,8 @@ class SILModule {
using PropertyListType = llvm::ilist<SILProperty>;
using WitnessTableListType = llvm::ilist<SILWitnessTable>;
using DefaultWitnessTableListType = llvm::ilist<SILDefaultWitnessTable>;
using DifferentiabilityWitnessListType =
llvm::ilist<SILDifferentiabilityWitness>;
using CoverageMapCollectionType =
llvm::MapVector<StringRef, SILCoverageMap *>;

Expand All @@ -131,6 +134,7 @@ class SILModule {
friend SILBasicBlock;
friend SILCoverageMap;
friend SILDefaultWitnessTable;
friend SILDifferentiabilityWitness;
friend SILFunction;
friend SILGlobalVariable;
friend SILLayout;
Expand Down Expand Up @@ -194,6 +198,17 @@ class SILModule {
/// The list of SILDefaultWitnessTables in the module.
DefaultWitnessTableListType defaultWitnessTables;

/// Lookup table for SIL differentiability witnesses, keyed by mangled name.
llvm::StringMap<SILDifferentiabilityWitness *> DifferentiabilityWitnessMap;

/// Lookup table for SILDifferentiabilityWitnesses, keyed by original
/// function name.
llvm::StringMap<llvm::SmallVector<SILDifferentiabilityWitness *, 1>>
DifferentiabilityWitnessesByFunction;

/// The list of SILDifferentiabilityWitnesses in the module.
DifferentiabilityWitnessListType differentiabilityWitnesses;

/// Declarations which are externally visible.
///
/// These are method declarations which are referenced from inlinable
Expand Down Expand Up @@ -455,6 +470,24 @@ class SILModule {
return {defaultWitnessTables.begin(), defaultWitnessTables.end()};
}

using differentiability_witness_iterator = DifferentiabilityWitnessListType::iterator;
using differentiability_witness_const_iterator = DifferentiabilityWitnessListType::const_iterator;
DifferentiabilityWitnessListType &getDifferentiabilityWitnessList() { return differentiabilityWitnesses; }
const DifferentiabilityWitnessListType &getDifferentiabilityWitnessList() const { return differentiabilityWitnesses; } differentiability_witness_iterator differentiability_witness_begin() { return differentiabilityWitnesses.begin(); }
differentiability_witness_iterator differentiability_witness_end() { return differentiabilityWitnesses.end(); }
differentiability_witness_const_iterator differentiability_witness_begin() const { return differentiabilityWitnesses.begin(); }
differentiability_witness_const_iterator differentiability_witness_end() const { return differentiabilityWitnesses.end(); }
iterator_range<differentiability_witness_iterator>
getDifferentiabilityWitnesses() {
return {differentiabilityWitnesses.begin(),
differentiabilityWitnesses.end()};
}
iterator_range<differentiability_witness_const_iterator>
getDifferentiabilityWitnesses() const {
return {differentiabilityWitnesses.begin(),
differentiabilityWitnesses.end()};
}

void addExternallyVisibleDecl(ValueDecl *decl) {
externallyVisible.insert(decl);
}
Expand Down Expand Up @@ -591,6 +624,17 @@ class SILModule {
/// hierarchy of \p Class.
SILFunction *lookUpFunctionInVTable(ClassDecl *Class, SILDeclRef Member);

/// Look up the differentiability witness with the given name.
SILDifferentiabilityWitness *lookUpDifferentiabilityWitness(StringRef name);

/// Look up the differentiability witness corresponding to the given key.
SILDifferentiabilityWitness *
lookUpDifferentiabilityWitness(SILDifferentiabilityWitnessKey key);

/// Look up the differentiability witness corresponding to the given function.
llvm::ArrayRef<SILDifferentiabilityWitness *>
lookUpDifferentiabilityWitnessesForFunction(StringRef name);

// Given a protocol, attempt to create a default witness table declaration
// for it.
SILDefaultWitnessTable *
Expand Down
Loading