Skip to content

Commit bf47403

Browse files
authored
[AutoDiff] Serialize and print @derivative and @transpose accessor kind. (#32839)
Serialize and print the optional accessor kind in `@derivative` and `@transpose` attributes. Resolves TF-1293.
1 parent bfcf12b commit bf47403

File tree

8 files changed

+99
-18
lines changed

8 files changed

+99
-18
lines changed

include/swift/AST/Attr.h

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1715,13 +1715,6 @@ class OriginallyDefinedInAttr: public DeclAttribute {
17151715
}
17161716
};
17171717

1718-
/// A declaration name with location.
1719-
struct DeclNameRefWithLoc {
1720-
DeclNameRef Name;
1721-
DeclNameLoc Loc;
1722-
Optional<AccessorKind> AccessorKind;
1723-
};
1724-
17251718
/// Attribute that marks a function as differentiable.
17261719
///
17271720
/// Examples:
@@ -1847,6 +1840,18 @@ class DifferentiableAttr final
18471840
}
18481841
};
18491842

1843+
/// A declaration name with location.
1844+
struct DeclNameRefWithLoc {
1845+
/// The declaration name.
1846+
DeclNameRef Name;
1847+
/// The declaration name location.
1848+
DeclNameLoc Loc;
1849+
/// An optional accessor kind.
1850+
Optional<AccessorKind> AccessorKind;
1851+
1852+
void print(ASTPrinter &Printer) const;
1853+
};
1854+
18501855
/// The `@derivative(of:)` attribute registers a function as a derivative of
18511856
/// another function-like declaration: a 'func', 'init', 'subscript', or 'var'
18521857
/// computed property declaration.

lib/AST/Attr.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1052,7 +1052,9 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
10521052
Printer.printAttrName("@derivative");
10531053
Printer << "(of: ";
10541054
auto *attr = cast<DerivativeAttr>(this);
1055-
Printer << attr->getOriginalFunctionName().Name;
1055+
if (auto *baseType = attr->getBaseTypeRepr())
1056+
baseType->print(Printer, Options);
1057+
attr->getOriginalFunctionName().print(Printer);
10561058
auto *derivative = cast<AbstractFunctionDecl>(D);
10571059
auto diffParamsString = getDifferentiationParametersClauseString(
10581060
derivative, attr->getParameterIndices(), attr->getParsedParameters(),
@@ -1067,7 +1069,9 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
10671069
Printer.printAttrName("@transpose");
10681070
Printer << "(of: ";
10691071
auto *attr = cast<TransposeAttr>(this);
1070-
Printer << attr->getOriginalFunctionName().Name;
1072+
if (auto *baseType = attr->getBaseTypeRepr())
1073+
baseType->print(Printer, Options);
1074+
attr->getOriginalFunctionName().print(Printer);
10711075
auto *transpose = cast<AbstractFunctionDecl>(D);
10721076
auto transParamsString = getDifferentiationParametersClauseString(
10731077
transpose, attr->getParameterIndices(), attr->getParsedParameters(),
@@ -1719,6 +1723,12 @@ GenericEnvironment *DifferentiableAttr::getDerivativeGenericEnvironment(
17191723
return original->getGenericEnvironment();
17201724
}
17211725

1726+
void DeclNameRefWithLoc::print(ASTPrinter &Printer) const {
1727+
Printer << Name;
1728+
if (AccessorKind)
1729+
Printer << '.' << getAccessorLabel(*AccessorKind);
1730+
}
1731+
17221732
void DifferentiableAttr::print(llvm::raw_ostream &OS, const Decl *D,
17231733
bool omitWrtClause) const {
17241734
StreamPrinter P(OS);

lib/AST/AutoDiff.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,14 @@ void DerivativeFunctionTypeError::log(raw_ostream &OS) const {
422422
}
423423
}
424424

425+
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
426+
const DeclNameRefWithLoc &name) {
427+
os << name.Name;
428+
if (auto accessorKind = name.AccessorKind)
429+
os << '.' << getAccessorLabel(*accessorKind);
430+
return os;
431+
}
432+
425433
bool swift::operator==(const TangentPropertyInfo::Error &lhs,
426434
const TangentPropertyInfo::Error &rhs) {
427435
if (lhs.kind != rhs.kind)

lib/Serialization/Deserialization.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4371,16 +4371,26 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() {
43714371
case decls_block::Derivative_DECL_ATTR: {
43724372
bool isImplicit;
43734373
uint64_t origNameId;
4374+
bool hasAccessorKind;
4375+
uint64_t rawAccessorKind;
43744376
DeclID origDeclId;
43754377
uint64_t rawDerivativeKind;
43764378
ArrayRef<uint64_t> parameters;
43774379

43784380
serialization::decls_block::DerivativeDeclAttrLayout::readRecord(
4379-
scratch, isImplicit, origNameId, origDeclId, rawDerivativeKind,
4380-
parameters);
4381+
scratch, isImplicit, origNameId, hasAccessorKind, rawAccessorKind,
4382+
origDeclId, rawDerivativeKind, parameters);
4383+
4384+
Optional<AccessorKind> accessorKind = None;
4385+
if (hasAccessorKind) {
4386+
auto maybeAccessorKind = getActualAccessorKind(rawAccessorKind);
4387+
if (!maybeAccessorKind)
4388+
MF.fatal();
4389+
accessorKind = *maybeAccessorKind;
4390+
}
43814391

43824392
DeclNameRefWithLoc origName{DeclNameRef(MF.getDeclBaseName(origNameId)),
4383-
DeclNameLoc(), None};
4393+
DeclNameLoc(), accessorKind};
43844394
auto derivativeKind =
43854395
getActualAutoDiffDerivativeFunctionKind(rawDerivativeKind);
43864396
if (!derivativeKind)

lib/Serialization/ModuleFormat.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ const uint16_t SWIFTMODULE_VERSION_MAJOR = 0;
5555
/// describe what change you made. The content of this comment isn't important;
5656
/// it just ensures a conflict if two people change the module format.
5757
/// Don't worry about adhering to the 80-column limit for this line.
58-
const uint16_t SWIFTMODULE_VERSION_MINOR = 563; // unchecked_value_cast
58+
const uint16_t SWIFTMODULE_VERSION_MINOR = 564; // `@derivative` attribute accessor kind
5959

6060
/// A standard hash seed used for all string hashes in a serialized module.
6161
///
@@ -1848,6 +1848,8 @@ namespace decls_block {
18481848
Derivative_DECL_ATTR,
18491849
BCFixed<1>, // Implicit flag.
18501850
IdentifierIDField, // Original name.
1851+
BCFixed<1>, // Has original accessor kind?
1852+
AccessorKindField, // Original accessor kind.
18511853
DeclIDField, // Original function declaration.
18521854
AutoDiffDerivativeFunctionKindField, // Derivative function kind.
18531855
BCArray<BCFixed<1>> // Differentiation parameter indices' bitvector.

lib/Serialization/Serialization.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2431,19 +2431,25 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
24312431
assert(attr->getOriginalFunction(ctx) &&
24322432
"`@derivative` attribute should have original declaration set "
24332433
"during construction or parsing");
2434-
auto origName = attr->getOriginalFunctionName().Name.getBaseName();
2434+
auto origDeclNameRef = attr->getOriginalFunctionName();
2435+
auto origName = origDeclNameRef.Name.getBaseName();
24352436
IdentifierID origNameId = S.addDeclBaseNameRef(origName);
24362437
DeclID origDeclID = S.addDeclRef(attr->getOriginalFunction(ctx));
24372438
auto derivativeKind =
24382439
getRawStableAutoDiffDerivativeFunctionKind(attr->getDerivativeKind());
2440+
uint8_t rawAccessorKind = 0;
2441+
auto origAccessorKind = origDeclNameRef.AccessorKind;
2442+
if (origAccessorKind)
2443+
rawAccessorKind = uint8_t(getStableAccessorKind(*origAccessorKind));
24392444
auto *parameterIndices = attr->getParameterIndices();
24402445
assert(parameterIndices && "Parameter indices must be resolved");
24412446
SmallVector<bool, 4> paramIndicesVector;
24422447
for (unsigned i : range(parameterIndices->getCapacity()))
24432448
paramIndicesVector.push_back(parameterIndices->contains(i));
24442449
DerivativeDeclAttrLayout::emitRecord(
24452450
S.Out, S.ScratchRecord, abbrCode, attr->isImplicit(), origNameId,
2446-
origDeclID, derivativeKind, paramIndicesVector);
2451+
origAccessorKind.hasValue(), rawAccessorKind, origDeclID,
2452+
derivativeKind, paramIndicesVector);
24472453
return;
24482454
}
24492455

test/AutoDiff/Serialization/derivative_attr.swift

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,11 @@ extension S {
5656
(self, { $0 })
5757
}
5858

59+
// Note: qualified name base types are not yet serialized and are not printed
60+
// when round-tripping.
61+
5962
// CHECK: @derivative(of: instanceMethod, wrt: (self, x))
60-
@derivative(of: instanceMethod, wrt: (self, x))
63+
@derivative(of: S.instanceMethod, wrt: (self, x))
6164
func derivativeInstanceMethodWrtAll(_ x: S) -> (value: S, differential: (S, S) -> S) {
6265
(self, { (dself, dx) in self })
6366
}
@@ -81,19 +84,39 @@ extension S {
8184

8285
extension S {
8386
var computedProperty: S {
84-
self
87+
get { self }
88+
set {}
8589
}
8690

8791
// CHECK: @derivative(of: computedProperty, wrt: self)
8892
@derivative(of: computedProperty, wrt: self)
8993
func derivativeProperty() -> (value: S, differential: (S) -> S) {
9094
(self, { $0 })
9195
}
96+
97+
// CHECK: @derivative(of: computedProperty.get, wrt: self)
98+
@derivative(of: computedProperty.get, wrt: self)
99+
func derivativePropertyGetter() -> (value: S, pullback: (S) -> S) {
100+
fatalError()
101+
}
102+
103+
// CHECK: @derivative(of: computedProperty.set, wrt: (self, newValue))
104+
@derivative(of: computedProperty.set, wrt: (self, newValue))
105+
mutating func derivativePropertySetter(_ newValue: S) -> (
106+
value: (), pullback: (inout S) -> S
107+
) {
108+
fatalError()
109+
}
92110
}
93111

94112
// Test subscripts.
95113

96114
extension S {
115+
subscript() -> S {
116+
get { self }
117+
set {}
118+
}
119+
97120
subscript<T: Differentiable>(x: T) -> S {
98121
self
99122
}
@@ -103,4 +126,18 @@ extension S {
103126
func derivativeSubscript<T: Differentiable>(x: T) -> (value: S, differential: (S) -> S) {
104127
(self, { $0 })
105128
}
129+
130+
// CHECK: @derivative(of: subscript.get, wrt: self)
131+
@derivative(of: subscript.get, wrt: self)
132+
func derivativeSubscriptGetter() -> (value: S, pullback: (S) -> S) {
133+
fatalError()
134+
}
135+
136+
// CHECK: @derivative(of: subscript.set, wrt: (self, newValue))
137+
@derivative(of: subscript.set, wrt: (self, newValue))
138+
mutating func derivativeSubscriptSetter(_ newValue: S) -> (
139+
value: (), pullback: (inout S) -> S
140+
) {
141+
fatalError()
142+
}
106143
}

test/AutoDiff/Serialization/transpose_attr.swift

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,11 @@ extension S {
5050
self + t
5151
}
5252

53+
// Note: qualified name base types are not yet serialized and are not printed
54+
// when round-tripping.
55+
5356
// CHECK: @transpose(of: instanceMethod, wrt: self)
54-
@transpose(of: instanceMethod, wrt: self)
57+
@transpose(of: S.instanceMethod, wrt: self)
5558
static func transposeInstanceMethodWrtSelf(_ other: S, t: S) -> S {
5659
other + t
5760
}

0 commit comments

Comments
 (0)