Skip to content

Commit 35beeca

Browse files
authored
Merge pull request #79808 from rintaro/astgen-autodiff
[ASTGen] Generate AutoDiff attributes
2 parents cc14548 + 017c0d9 commit 35beeca

File tree

8 files changed

+585
-12
lines changed

8 files changed

+585
-12
lines changed

include/swift/AST/ASTBridging.h

Lines changed: 103 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ template<typename T> class ArrayRef;
3333
}
3434

3535
namespace swift {
36+
enum class AccessorKind;
3637
class AvailabilityDomain;
3738
class Argument;
3839
class ASTContext;
@@ -44,6 +45,7 @@ class DeclNameLoc;
4445
class DeclNameRef;
4546
class DiagnosticArgument;
4647
class DiagnosticEngine;
48+
enum class DifferentiabilityKind : uint8_t;
4749
class Fingerprint;
4850
class Identifier;
4951
class IfConfigClauseRangeInfo;
@@ -55,6 +57,7 @@ enum class MacroRole : uint32_t;
5557
class MacroIntroducedDeclName;
5658
enum class MacroIntroducedDeclNameKind;
5759
enum class ParamSpecifier : uint8_t;
60+
class ParsedAutoDiffParameter;
5861
enum class PlatformKind : uint8_t;
5962
class ProtocolConformanceRef;
6063
class RegexLiteralPatternFeature;
@@ -502,6 +505,13 @@ struct BridgedPatternBindingEntry {
502505
BridgedNullablePatternBindingInitializer initContext;
503506
};
504507

508+
enum ENUM_EXTENSIBILITY_ATTR(closed) BridgedAccessorKind {
509+
#define ACCESSOR(ID) BridgedAccessorKind##ID,
510+
#include "swift/AST/AccessorKinds.def"
511+
};
512+
513+
swift::AccessorKind unbridged(BridgedAccessorKind kind);
514+
505515
//===----------------------------------------------------------------------===//
506516
// MARK: Diagnostic Engine
507517
//===----------------------------------------------------------------------===//
@@ -745,6 +755,59 @@ struct BridgedAvailabilityDomain {
745755
bool isNull() const { return opaque == nullptr; };
746756
};
747757

758+
//===----------------------------------------------------------------------===//
759+
// MARK: AutoDiff
760+
//===----------------------------------------------------------------------===//
761+
762+
enum ENUM_EXTENSIBILITY_ATTR(closed) BridgedDifferentiabilityKind {
763+
BridgedDifferentiabilityKindNonDifferentiable = 0,
764+
BridgedDifferentiabilityKindForward = 1,
765+
BridgedDifferentiabilityKindReverse = 2,
766+
BridgedDifferentiabilityKindNormal = 3,
767+
BridgedDifferentiabilityKindLinear = 4,
768+
};
769+
770+
swift::DifferentiabilityKind unbridged(BridgedDifferentiabilityKind cKind);
771+
772+
class BridgedParsedAutoDiffParameter {
773+
private:
774+
BridgedSourceLoc loc;
775+
enum Kind {
776+
Named,
777+
Ordered,
778+
Self,
779+
} kind;
780+
union Value {
781+
BridgedIdentifier name;
782+
unsigned index;
783+
784+
Value(BridgedIdentifier name) : name(name) {}
785+
Value(unsigned index) : index(index) {}
786+
Value() : name() {}
787+
} value;
788+
789+
BridgedParsedAutoDiffParameter(BridgedSourceLoc loc, Kind kind, Value value)
790+
: loc(loc), kind(kind), value(value) {}
791+
792+
public:
793+
SWIFT_NAME("forNamed(_:loc:)")
794+
static BridgedParsedAutoDiffParameter forNamed(BridgedIdentifier name,
795+
BridgedSourceLoc loc) {
796+
return BridgedParsedAutoDiffParameter(loc, Kind::Named, name);
797+
}
798+
SWIFT_NAME("forOrdered(_:loc:)")
799+
static BridgedParsedAutoDiffParameter forOrdered(size_t index,
800+
BridgedSourceLoc loc) {
801+
return BridgedParsedAutoDiffParameter(loc, Kind::Ordered, index);
802+
}
803+
SWIFT_NAME("forSelf(loc:)")
804+
static BridgedParsedAutoDiffParameter forSelf(BridgedSourceLoc loc) {
805+
return BridgedParsedAutoDiffParameter(loc, Kind::Self, {});
806+
}
807+
808+
swift::ParsedAutoDiffParameter unbridged() const;
809+
};
810+
748811
//===----------------------------------------------------------------------===//
749812
// MARK: DeclAttributes
750813
//===----------------------------------------------------------------------===//
@@ -879,6 +942,30 @@ BridgedCustomAttr BridgedCustomAttr_createParsed(
879942
BridgedNullableCustomAttributeInitializer cInitContext,
880943
BridgedNullableArgumentList cArgumentList);
881944

945+
SWIFT_NAME("BridgedDerivativeAttr.createParsed(_:atLoc:range:baseType:"
946+
"originalName:originalNameLoc:accessorKind:params:)")
947+
BridgedDerivativeAttr BridgedDerivativeAttr_createParsed(
948+
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
949+
BridgedSourceRange cRange, BridgedNullableTypeRepr cBaseType,
950+
BridgedDeclNameRef cOriginalName, BridgedDeclNameLoc cOriginalNameLoc,
951+
BridgedAccessorKind cAccessorKind, BridgedArrayRef cParams);
952+
953+
SWIFT_NAME("BridgedDerivativeAttr.createParsed(_:atLoc:range:baseType:"
954+
"originalName:originalNameLoc:params:)")
955+
BridgedDerivativeAttr BridgedDerivativeAttr_createParsed(
956+
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
957+
BridgedSourceRange cRange, BridgedNullableTypeRepr cBaseType,
958+
BridgedDeclNameRef cOriginalName, BridgedDeclNameLoc cOriginalNameLoc,
959+
BridgedArrayRef cParams);
960+
961+
SWIFT_NAME("BridgedDifferentiableAttr.createParsed(_:atLoc:range:kind:params:"
962+
"genericWhereClause:)")
963+
BridgedDifferentiableAttr BridgedDifferentiableAttr_createParsed(
964+
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
965+
BridgedSourceRange cRange, BridgedDifferentiabilityKind cKind,
966+
BridgedArrayRef cParams,
967+
BridgedNullableTrailingWhereClause cGenericWhereClause);
968+
882969
SWIFT_NAME("BridgedDocumentationAttr.createParsed(_:atLoc:range:metadata:"
883970
"accessLevel:)")
884971
BridgedDocumentationAttr BridgedDocumentationAttr_createParsed(
@@ -1260,6 +1347,15 @@ BridgedSILGenNameAttr BridgedSILGenNameAttr_createParsed(
12601347
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
12611348
BridgedSourceRange cRange, BridgedStringRef cName, bool isRaw);
12621349

1350+
SWIFT_NAME(
1351+
"BridgedTransposeAttr.createParsed(_:atLoc:range:baseType:originalName:"
1352+
"originalNameLoc:params:)")
1353+
BridgedTransposeAttr BridgedTransposeAttr_createParsed(
1354+
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
1355+
BridgedSourceRange cRange, BridgedNullableTypeRepr cBaseType,
1356+
BridgedDeclNameRef cOriginalName, BridgedDeclNameLoc cOriginalNameLoc,
1357+
BridgedArrayRef cParams);
1358+
12631359
SWIFT_NAME(
12641360
"BridgedUnavailableFromAsyncAttr.createParsed(_:atLoc:range:message:)")
12651361
BridgedUnavailableFromAsyncAttr BridgedUnavailableFromAsyncAttr_createParsed(
@@ -1285,11 +1381,6 @@ enum ENUM_EXTENSIBILITY_ATTR(closed) BridgedStaticSpelling {
12851381
BridgedStaticSpellingClass
12861382
};
12871383

1288-
enum ENUM_EXTENSIBILITY_ATTR(closed) BridgedAccessorKind {
1289-
#define ACCESSOR(ID) BridgedAccessorKind##ID,
1290-
#include "swift/AST/AccessorKinds.def"
1291-
};
1292-
12931384
struct BridgedAccessorRecord {
12941385
BridgedSourceLoc lBraceLoc;
12951386
BridgedArrayRef accessors;
@@ -2438,6 +2529,13 @@ enum ENUM_EXTENSIBILITY_ATTR(closed) BridgedExecutionTypeAttrExecutionKind {
24382529
BridgedExecutionTypeAttrExecutionKind_Caller
24392530
};
24402531

2532+
SWIFT_NAME("BridgedDifferentiableTypeAttr.createParsed(_:atLoc:nameLoc:"
2533+
"parensRange:kind:kindLoc:)")
2534+
BridgedDifferentiableTypeAttr BridgedDifferentiableTypeAttr_createParsed(
2535+
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
2536+
BridgedSourceLoc cNameLoc, BridgedSourceRange cParensRange,
2537+
BridgedDifferentiabilityKind cKind, BridgedSourceLoc cKindLoc);
2538+
24412539
SWIFT_NAME("BridgedExecutionTypeAttr.createParsed(_:atLoc:nameLoc:parensRange:"
24422540
"behavior:behaviorLoc:)")
24432541
BridgedExecutionTypeAttr BridgedExecutionTypeAttr_createParsed(

lib/AST/Bridging/DeclAttributeBridging.cpp

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,47 @@
1414

1515
#include "swift/AST/ASTContext.h"
1616
#include "swift/AST/Attr.h"
17+
#include "swift/AST/AutoDiff.h"
1718
#include "swift/AST/Expr.h"
1819
#include "swift/AST/Identifier.h"
1920
#include "swift/Basic/Assertions.h"
2021

2122
using namespace swift;
2223

24+
//===----------------------------------------------------------------------===//
25+
// MARK: AutoDiff
26+
//===----------------------------------------------------------------------===//
27+
28+
DifferentiabilityKind unbridged(BridgedDifferentiabilityKind cKind) {
29+
switch (cKind) {
30+
case BridgedDifferentiabilityKindNonDifferentiable:
31+
return DifferentiabilityKind::NonDifferentiable;
32+
case BridgedDifferentiabilityKindForward:
33+
return DifferentiabilityKind::Forward;
34+
case BridgedDifferentiabilityKindReverse:
35+
return DifferentiabilityKind::Reverse;
36+
case BridgedDifferentiabilityKindNormal:
37+
return DifferentiabilityKind::Normal;
38+
case BridgedDifferentiabilityKindLinear:
39+
return DifferentiabilityKind::Linear;
40+
}
41+
llvm_unreachable("unhandled enum value");
42+
}
43+
44+
ParsedAutoDiffParameter BridgedParsedAutoDiffParameter::unbridged() const {
45+
switch (kind) {
46+
case Kind::Named:
47+
return ParsedAutoDiffParameter::getNamedParameter(loc.unbridged(),
48+
value.name.unbridged());
49+
case Kind::Ordered:
50+
return ParsedAutoDiffParameter::getOrderedParameter(loc.unbridged(),
51+
value.index);
52+
case Kind::Self:
53+
return ParsedAutoDiffParameter::getSelfParameter(loc.unbridged());
54+
}
55+
llvm_unreachable("unhandled enum value");
56+
}
57+
2358
//===----------------------------------------------------------------------===//
2459
// MARK: DeclAttributes
2560
//===----------------------------------------------------------------------===//
@@ -221,6 +256,62 @@ BridgedCustomAttr BridgedCustomAttr_createParsed(
221256
cInitContext.unbridged(), cArgumentList.unbridged());
222257
}
223258

259+
BridgedDerivativeAttr BridgedDerivativeAttr_createParsedImpl(
260+
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
261+
BridgedSourceRange cRange, BridgedNullableTypeRepr cBaseType,
262+
BridgedDeclNameRef cOriginalName, BridgedDeclNameLoc cOriginalNameLoc,
263+
std::optional<BridgedAccessorKind> cAccessorKind, BridgedArrayRef cParams) {
264+
std::optional<AccessorKind> accessorKind;
265+
if (cAccessorKind)
266+
accessorKind = unbridged(*cAccessorKind);
267+
SmallVector<ParsedAutoDiffParameter, 2> params;
268+
for (auto &elem : cParams.unbridged<BridgedParsedAutoDiffParameter>())
269+
params.push_back(elem.unbridged());
270+
271+
return DerivativeAttr::create(cContext.unbridged(),
272+
/*implicit=*/false, cAtLoc.unbridged(),
273+
cRange.unbridged(), cBaseType.unbridged(),
274+
DeclNameRefWithLoc{cOriginalName.unbridged(),
275+
cOriginalNameLoc.unbridged(),
276+
accessorKind},
277+
params);
278+
}
279+
280+
BridgedDerivativeAttr BridgedDerivativeAttr_createParsed(
281+
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
282+
BridgedSourceRange cRange, BridgedNullableTypeRepr cBaseType,
283+
BridgedDeclNameRef cOriginalName, BridgedDeclNameLoc cOriginalNameLoc,
284+
BridgedAccessorKind cAccessorKind, BridgedArrayRef cParams) {
285+
return BridgedDerivativeAttr_createParsedImpl(
286+
cContext, cAtLoc, cRange, cBaseType, cOriginalName, cOriginalNameLoc,
287+
cAccessorKind, cParams);
288+
}
289+
290+
BridgedDerivativeAttr BridgedDerivativeAttr_createParsed(
291+
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
292+
BridgedSourceRange cRange, BridgedNullableTypeRepr cBaseType,
293+
BridgedDeclNameRef cOriginalName, BridgedDeclNameLoc cOriginalNameLoc,
294+
BridgedArrayRef cParams) {
295+
return BridgedDerivativeAttr_createParsedImpl(
296+
cContext, cAtLoc, cRange, cBaseType, cOriginalName, cOriginalNameLoc,
297+
/*cAccessorKind=*/std::nullopt, cParams);
298+
}
299+
300+
BridgedDifferentiableAttr BridgedDifferentiableAttr_createParsed(
301+
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
302+
BridgedSourceRange cRange, BridgedDifferentiabilityKind cKind,
303+
BridgedArrayRef cParams,
304+
BridgedNullableTrailingWhereClause cGenericWhereClause) {
305+
SmallVector<ParsedAutoDiffParameter, 2> params;
306+
for (auto &elem : cParams.unbridged<BridgedParsedAutoDiffParameter>())
307+
params.push_back(elem.unbridged());
308+
309+
return DifferentiableAttr::create(cContext.unbridged(), /*implicit=*/false,
310+
cAtLoc.unbridged(), cRange.unbridged(),
311+
unbridged(cKind), params,
312+
cGenericWhereClause.unbridged());
313+
}
314+
224315
BridgedDynamicReplacementAttr BridgedDynamicReplacementAttr_createParsed(
225316
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
226317
BridgedSourceLoc cAttrNameLoc, BridgedSourceLoc cLParenLoc,
@@ -752,6 +843,24 @@ BridgedSILGenNameAttr BridgedSILGenNameAttr_createParsed(
752843
cRange.unbridged(), /*Implicit=*/false);
753844
}
754845

846+
BridgedTransposeAttr BridgedTransposeAttr_createParsed(
847+
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
848+
BridgedSourceRange cRange, BridgedNullableTypeRepr cBaseType,
849+
BridgedDeclNameRef cOriginalName, BridgedDeclNameLoc cOriginalNameLoc,
850+
BridgedArrayRef cParams) {
851+
SmallVector<ParsedAutoDiffParameter, 2> params;
852+
for (auto &elem : cParams.unbridged<BridgedParsedAutoDiffParameter>())
853+
params.push_back(elem.unbridged());
854+
855+
return TransposeAttr::create(
856+
cContext.unbridged(),
857+
/*implicit=*/false, cAtLoc.unbridged(), cRange.unbridged(),
858+
cBaseType.unbridged(),
859+
DeclNameRefWithLoc{cOriginalName.unbridged(), cOriginalNameLoc.unbridged(),
860+
/*AccessorKind=*/std::nullopt},
861+
params);
862+
}
863+
755864
BridgedUnavailableFromAsyncAttr BridgedUnavailableFromAsyncAttr_createParsed(
756865
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
757866
BridgedSourceRange cRange, BridgedStringRef cMessage) {

lib/AST/Bridging/DeclBridging.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ static StaticSpellingKind unbridged(BridgedStaticSpelling kind) {
111111
return static_cast<StaticSpellingKind>(kind);
112112
}
113113

114-
static AccessorKind unbridged(BridgedAccessorKind kind) {
114+
AccessorKind unbridged(BridgedAccessorKind kind) {
115115
return static_cast<AccessorKind>(kind);
116116
}
117117

lib/AST/Bridging/TypeAttributeBridging.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,15 @@ BridgedConventionTypeAttr BridgedConventionTypeAttr_createParsed(
7878
{cClangType.unbridged(), cClangTypeLoc.unbridged()});
7979
}
8080

81+
BridgedDifferentiableTypeAttr BridgedDifferentiableTypeAttr_createParsed(
82+
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
83+
BridgedSourceLoc cNameLoc, BridgedSourceRange cParensRange,
84+
BridgedDifferentiabilityKind cKind, BridgedSourceLoc cKindLoc) {
85+
return new (cContext.unbridged()) DifferentiableTypeAttr(
86+
cAtLoc.unbridged(), cNameLoc.unbridged(), cParensRange.unbridged(),
87+
{unbridged(cKind), cKindLoc.unbridged()});
88+
}
89+
8190
BridgedExecutionTypeAttr BridgedExecutionTypeAttr_createParsed(
8291
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
8392
BridgedSourceLoc cNameLoc, BridgedSourceRange cParensRange,

0 commit comments

Comments
 (0)