Skip to content

Commit 15da94f

Browse files
authored
Merge pull request #36015 from rxwei/74380324-mangle-diff-witness-keys
2 parents 9acf214 + e494df2 commit 15da94f

File tree

53 files changed

+618
-290
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+618
-290
lines changed

docs/ABI/Mangling.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,8 @@ Globals
167167
global ::= protocol-conformance protocol 'Wb' // base protocol witness table accessor
168168
global ::= type protocol-conformance 'Wl' // lazy protocol witness table accessor
169169

170+
global ::= global generic-signature? 'WJ' DIFFERENTIABILITY-KIND INDEX-SUBSET 'p' INDEX-SUBSET 'r' // differentiability witness
171+
170172
global ::= type 'WV' // value witness table
171173
global ::= entity 'Wvd' // field offset
172174
global ::= entity 'WC' // resilient enum tag index

docs/SIL.rst

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1590,12 +1590,15 @@ Differentiability Witnesses
15901590
sil-differentiability-witness ::=
15911591
'sil_differentiability_witness'
15921592
sil-linkage?
1593+
'[' differentiability-kind ']'
15931594
'[' 'parameters' sil-differentiability-witness-function-index-list ']'
15941595
'[' 'results' sil-differentiability-witness-function-index-list ']'
15951596
generic-parameter-clause?
15961597
sil-function-name ':' sil-type
15971598
sil-differentiability-witness-body?
15981599

1600+
differentiability-kind ::= 'forward' | 'reverse' | 'normal' | 'linear'
1601+
15991602
sil-differentiability-witness-body ::=
16001603
'{' sil-differentiability-witness-entry?
16011604
sil-differentiability-witness-entry? '}'
@@ -1625,7 +1628,7 @@ based on the key.
16251628

16261629
::
16271630

1628-
sil_differentiability_witness hidden [parameters 0] [results 0] <T where T : Differentiable> @id : $@convention(thin) (T) -> T {
1631+
sil_differentiability_witness hidden [normal] [parameters 0] [results 0] <T where T : Differentiable> @id : $@convention(thin) (T) -> T {
16291632
jvp: @id_jvp : $@convention(thin) (T) -> (T, @owned @callee_guaranteed (T.TangentVector) -> T.TangentVector)
16301633
vjp: @id_vjp : $@convention(thin) (T) -> (T, @owned @callee_guaranteed (T.TangentVector) -> T.TangentVector)
16311634
}
@@ -7066,6 +7069,7 @@ differentiability_witness_function
70667069
sil-instruction ::=
70677070
'differentiability_witness_function'
70687071
'[' sil-differentiability-witness-function-kind ']'
7072+
'[' differentiability-kind ']'
70697073
'[' 'parameters' sil-differentiability-witness-function-index-list ']'
70707074
'[' 'results' sil-differentiability-witness-function-index-list ']'
70717075
generic-parameter-clause?
@@ -7074,7 +7078,7 @@ differentiability_witness_function
70747078
sil-differentiability-witness-function-kind ::= 'jvp' | 'vjp' | 'transpose'
70757079
sil-differentiability-witness-function-index-list ::= [0-9]+ (' ' [0-9]+)*
70767080

7077-
differentiability_witness_function [jvp] [parameters 0] [results 0] \
7081+
differentiability_witness_function [vjp] [reverse] [parameters 0] [results 0] \
70787082
<T where T: Differentiable> @foo : $(T) -> T
70797083

70807084
Looks up a differentiability witness function (JVP, VJP, or transpose) for
@@ -7086,6 +7090,7 @@ look up: ``[jvp]``, ``[vjp]``, or ``[transpose]``.
70867090
The remaining components identify the SIL differentiability witness:
70877091

70887092
- Original function name.
7093+
- Differentiability kind.
70897094
- Parameter indices.
70907095
- Result indices.
70917096
- Witness generic parameter clause (optional). When parsing SIL, the parsed

include/swift/AST/ASTMangler.h

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,11 @@ class ASTMangler : public Mangler {
204204
CanType fromType, CanType toType, GenericSignature signature,
205205
AutoDiffLinearMapKind linearMapKind);
206206

207+
/// Mangle a SIL differentiability witness.
208+
std::string mangleSILDifferentiabilityWitness(StringRef originalName,
209+
DifferentiabilityKind kind,
210+
AutoDiffConfig config);
211+
207212
/// Mangle the AutoDiff generated declaration for the given:
208213
/// - Generated declaration kind: linear map struct or branching trace enum.
209214
/// - Mangled original function name.
@@ -217,14 +222,6 @@ class ASTMangler : public Mangler {
217222
AutoDiffLinearMapKind linearMapKind,
218223
AutoDiffConfig config);
219224

220-
/// Mangle a SIL differentiability witness key:
221-
/// - Mangled original function name.
222-
/// - Parameter indices.
223-
/// - Result indices.
224-
/// - Derivative generic signature (optional).
225-
std::string
226-
mangleSILDifferentiabilityWitnessKey(SILDifferentiabilityWitnessKey key);
227-
228225
std::string mangleKeyPathGetterThunkHelper(const AbstractStorageDecl *property,
229226
GenericSignature signature,
230227
CanType baseType,

include/swift/AST/AutoDiff.h

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ struct AutoDiffConfig {
195195
IndexSubset *resultIndices;
196196
GenericSignature derivativeGenericSignature;
197197

198+
/*implicit*/ AutoDiffConfig() = default;
198199
/*implicit*/ AutoDiffConfig(
199200
IndexSubset *parameterIndices, IndexSubset *resultIndices,
200201
GenericSignature derivativeGenericSignature = GenericSignature())
@@ -545,10 +546,20 @@ struct TangentPropertyInfo {
545546

546547
void simple_display(llvm::raw_ostream &OS, TangentPropertyInfo info);
547548

548-
/// The key type used for uniquing `SILDifferentiabilityWitness` in
549-
/// `SILModule`: original function name, parameter indices, result indices, and
550-
/// derivative generic signature.
551-
using SILDifferentiabilityWitnessKey = std::pair<StringRef, AutoDiffConfig>;
549+
/// The key type used for uniquing `SILDifferentiabilityWitness` in `SILModule`.
550+
struct SILDifferentiabilityWitnessKey {
551+
StringRef originalFunctionName;
552+
DifferentiabilityKind kind;
553+
AutoDiffConfig config;
554+
555+
void print(llvm::raw_ostream &s = llvm::outs()) const;
556+
};
557+
558+
inline llvm::raw_ostream &operator<<(
559+
llvm::raw_ostream &s, const SILDifferentiabilityWitnessKey &key) {
560+
key.print(s);
561+
return s;
562+
}
552563

553564
/// Returns `true` iff differentiable programming is enabled.
554565
bool isDifferentiableProgrammingEnabled(SourceFile &SF);
@@ -676,6 +687,9 @@ getAutoDiffFunctionKind(AutoDiffDerivativeFunctionKind kind);
676687

677688
AutoDiffFunctionKind getAutoDiffFunctionKind(AutoDiffLinearMapKind kind);
678689

690+
MangledDifferentiabilityKind
691+
getMangledDifferentiabilityKind(DifferentiabilityKind kind);
692+
679693
} // end namespace autodiff
680694
} // end namespace swift
681695

@@ -688,6 +702,8 @@ using swift::GenericSignature;
688702
using swift::IndexSubset;
689703
using swift::SILAutoDiffDerivativeFunctionKey;
690704
using swift::SILFunctionType;
705+
using swift::DifferentiabilityKind;
706+
using swift::SILDifferentiabilityWitnessKey;
691707

692708
template <typename T> struct DenseMapInfo;
693709

@@ -760,8 +776,8 @@ template <> struct DenseMapInfo<AutoDiffDerivativeFunctionKind> {
760776
};
761777

762778
template <> struct DenseMapInfo<SILAutoDiffDerivativeFunctionKey> {
763-
static bool isEqual(const SILAutoDiffDerivativeFunctionKey lhs,
764-
const SILAutoDiffDerivativeFunctionKey rhs) {
779+
static bool isEqual(const SILAutoDiffDerivativeFunctionKey &lhs,
780+
const SILAutoDiffDerivativeFunctionKey &rhs) {
765781
return lhs.originalType == rhs.originalType &&
766782
lhs.parameterIndices == rhs.parameterIndices &&
767783
lhs.resultIndices == rhs.resultIndices &&
@@ -803,6 +819,36 @@ template <> struct DenseMapInfo<SILAutoDiffDerivativeFunctionKey> {
803819
}
804820
};
805821

822+
template <> struct DenseMapInfo<SILDifferentiabilityWitnessKey> {
823+
static bool isEqual(const SILDifferentiabilityWitnessKey &lhs,
824+
const SILDifferentiabilityWitnessKey &rhs) {
825+
return DenseMapInfo<StringRef>::isEqual(
826+
lhs.originalFunctionName, rhs.originalFunctionName) &&
827+
DenseMapInfo<unsigned>::isEqual(
828+
(unsigned)lhs.kind, (unsigned)rhs.kind) &&
829+
DenseMapInfo<AutoDiffConfig>::isEqual(lhs.config, rhs.config);
830+
}
831+
832+
static inline SILDifferentiabilityWitnessKey getEmptyKey() {
833+
return {DenseMapInfo<StringRef>::getEmptyKey(),
834+
(DifferentiabilityKind)DenseMapInfo<unsigned>::getEmptyKey(),
835+
DenseMapInfo<AutoDiffConfig>::getEmptyKey()};
836+
}
837+
838+
static inline SILDifferentiabilityWitnessKey getTombstoneKey() {
839+
return {DenseMapInfo<StringRef>::getTombstoneKey(),
840+
(DifferentiabilityKind)DenseMapInfo<unsigned>::getTombstoneKey(),
841+
DenseMapInfo<AutoDiffConfig>::getTombstoneKey()};
842+
}
843+
844+
static unsigned getHashValue(const SILDifferentiabilityWitnessKey &val) {
845+
return hash_combine(
846+
DenseMapInfo<StringRef>::getHashValue(val.originalFunctionName),
847+
DenseMapInfo<unsigned>::getHashValue((unsigned)val.kind),
848+
DenseMapInfo<AutoDiffConfig>::getHashValue(val.config));
849+
}
850+
};
851+
806852
} // end namespace llvm
807853

808854
#endif // SWIFT_AST_AUTODIFF_H

include/swift/AST/DiagnosticsParse.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,9 @@ ERROR(sil_witness_protocol_conformance_not_found,none,
676676
// SIL differentiability witnesses
677677
ERROR(sil_diff_witness_expected_token,PointsToFirstBadToken,
678678
"expected '%0' in differentiability witness", (StringRef))
679+
ERROR(sil_diff_witness_unknown_kind,PointsToFirstBadToken,
680+
"unknonwn differentiability kind '%0'; expected 'forward', 'reverse', "
681+
"'normal', or 'linear'", (StringRef))
679682
ERROR(sil_diff_witness_serialized_declaration,none,
680683
"differentiability witness declaration should not be serialized", ())
681684
ERROR(sil_diff_witness_undefined,PointsToFirstBadToken,

include/swift/Demangling/DemangleNodes.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,7 @@ NODE(AutoDiffFunctionKind)
313313
NODE(AutoDiffSelfReorderingReabstractionThunk)
314314
NODE(AutoDiffSubsetParametersThunk)
315315
NODE(AutoDiffDerivativeVTableThunk)
316+
NODE(DifferentiabilityWitness)
316317
NODE(IndexSubset)
317318

318319
#undef CONTEXT_NODE

include/swift/Demangling/Demangler.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,7 @@ class Demangler : public NodeFactory {
573573
NodePointer demangleAutoDiffFunctionKind();
574574
NodePointer demangleAutoDiffSubsetParametersThunk();
575575
NodePointer demangleAutoDiffSelfReorderingReabstractionThunk();
576+
NodePointer demangleDifferentiabilityWitness();
576577
NodePointer demangleIndexSubset();
577578

578579
bool demangleBoundGenerics(Vector<NodePointer> &TypeListList,

include/swift/IRGen/Linking.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,8 @@ class LinkEntity {
380380
/// ProtocolConformance*.
381381
ProtocolWitnessTableLazyCacheVariable,
382382

383-
/// A SIL differentiability witness.
383+
/// A SIL differentiability witness. The pointer is a
384+
/// SILDifferentiabilityWitness*.
384385
DifferentiabilityWitness,
385386

386387
// Everything following this is a type kind.

include/swift/SIL/SILDifferentiabilityWitness.h

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ class SILDifferentiabilityWitness
4747
SILLinkage Linkage;
4848
/// The original function.
4949
SILFunction *OriginalFunction;
50+
/// The differentiability kind.
51+
DifferentiabilityKind Kind;
5052
/// The derivative configuration: parameter indices, result indices, and
5153
/// derivative generic signature (optional). The derivative generic signature
5254
/// may contain same-type requirements such that all generic parameters are
@@ -69,27 +71,32 @@ class SILDifferentiabilityWitness
6971

7072
SILDifferentiabilityWitness(
7173
SILModule &module, SILLinkage linkage, SILFunction *originalFunction,
72-
IndexSubset *parameterIndices, IndexSubset *resultIndices,
73-
GenericSignature derivativeGenSig, SILFunction *jvp, SILFunction *vjp,
74-
bool isDeclaration, bool isSerialized, const DeclAttribute *attribute)
74+
DifferentiabilityKind kind, IndexSubset *parameterIndices,
75+
IndexSubset *resultIndices, GenericSignature derivativeGenSig,
76+
SILFunction *jvp, SILFunction *vjp, bool isDeclaration, bool isSerialized,
77+
const DeclAttribute *attribute)
7578
: Module(module), Linkage(linkage), OriginalFunction(originalFunction),
79+
Kind(kind),
7680
Config(parameterIndices, resultIndices, derivativeGenSig.getPointer()),
7781
JVP(jvp), VJP(vjp), IsDeclaration(isDeclaration),
78-
IsSerialized(isSerialized), Attribute(attribute) {}
82+
IsSerialized(isSerialized), Attribute(attribute) {
83+
assert(kind != DifferentiabilityKind::NonDifferentiable);
84+
}
7985

8086
public:
8187
static SILDifferentiabilityWitness *
8288
createDeclaration(SILModule &module, SILLinkage linkage,
83-
SILFunction *originalFunction,
89+
SILFunction *originalFunction, DifferentiabilityKind kind,
8490
IndexSubset *parameterIndices, IndexSubset *resultIndices,
8591
GenericSignature derivativeGenSig,
8692
const DeclAttribute *attribute = nullptr);
8793

8894
static SILDifferentiabilityWitness *createDefinition(
8995
SILModule &module, SILLinkage linkage, SILFunction *originalFunction,
90-
IndexSubset *parameterIndices, IndexSubset *resultIndices,
91-
GenericSignature derivativeGenSig, SILFunction *jvp, SILFunction *vjp,
92-
bool isSerialized, const DeclAttribute *attribute = nullptr);
96+
DifferentiabilityKind kind, IndexSubset *parameterIndices,
97+
IndexSubset *resultIndices, GenericSignature derivativeGenSig,
98+
SILFunction *jvp, SILFunction *vjp, bool isSerialized,
99+
const DeclAttribute *attribute = nullptr);
93100

94101
void convertToDefinition(SILFunction *jvp, SILFunction *vjp,
95102
bool isSerialized);
@@ -98,6 +105,7 @@ class SILDifferentiabilityWitness
98105
SILModule &getModule() const { return Module; }
99106
SILLinkage getLinkage() const { return Linkage; }
100107
SILFunction *getOriginalFunction() const { return OriginalFunction; }
108+
DifferentiabilityKind getKind() const { return Kind; }
101109
const AutoDiffConfig &getConfig() const { return Config; }
102110
IndexSubset *getParameterIndices() const { return Config.parameterIndices; }
103111
IndexSubset *getResultIndices() const { return Config.resultIndices; }

include/swift/SILOptimizer/Differentiation/Common.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,8 +250,8 @@ findMinimalDerivativeConfiguration(AbstractFunctionDecl *original,
250250
/// \param parameterIndices must be lowered to SIL.
251251
/// \param resultIndices must be lowered to SIL.
252252
SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness(
253-
SILModule &module, SILFunction *original, IndexSubset *parameterIndices,
254-
IndexSubset *resultIndices);
253+
SILModule &module, SILFunction *original, DifferentiabilityKind kind,
254+
IndexSubset *parameterIndices, IndexSubset *resultIndices);
255255

256256
} // end namespace autodiff
257257

0 commit comments

Comments
 (0)