Skip to content

[ASTGen] Generate AutoDiff attributes #79808

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 103 additions & 5 deletions include/swift/AST/ASTBridging.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ template<typename T> class ArrayRef;
}

namespace swift {
enum class AccessorKind;
class AvailabilityDomain;
class Argument;
class ASTContext;
Expand All @@ -44,6 +45,7 @@ class DeclNameLoc;
class DeclNameRef;
class DiagnosticArgument;
class DiagnosticEngine;
enum class DifferentiabilityKind : uint8_t;
class Fingerprint;
class Identifier;
class IfConfigClauseRangeInfo;
Expand All @@ -55,6 +57,7 @@ enum class MacroRole : uint32_t;
class MacroIntroducedDeclName;
enum class MacroIntroducedDeclNameKind;
enum class ParamSpecifier : uint8_t;
class ParsedAutoDiffParameter;
enum class PlatformKind : uint8_t;
class ProtocolConformanceRef;
class RegexLiteralPatternFeature;
Expand Down Expand Up @@ -502,6 +505,13 @@ struct BridgedPatternBindingEntry {
BridgedNullablePatternBindingInitializer initContext;
};

enum ENUM_EXTENSIBILITY_ATTR(closed) BridgedAccessorKind {
#define ACCESSOR(ID) BridgedAccessorKind##ID,
#include "swift/AST/AccessorKinds.def"
};

swift::AccessorKind unbridged(BridgedAccessorKind kind);

//===----------------------------------------------------------------------===//
// MARK: Diagnostic Engine
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -745,6 +755,59 @@ struct BridgedAvailabilityDomain {
bool isNull() const { return opaque == nullptr; };
};

//===----------------------------------------------------------------------===//
// MARK: AutoDiff
//===----------------------------------------------------------------------===//

enum ENUM_EXTENSIBILITY_ATTR(closed) BridgedDifferentiabilityKind {
BridgedDifferentiabilityKindNonDifferentiable = 0,
BridgedDifferentiabilityKindForward = 1,
BridgedDifferentiabilityKindReverse = 2,
BridgedDifferentiabilityKindNormal = 3,
BridgedDifferentiabilityKindLinear = 4,
};

swift::DifferentiabilityKind unbridged(BridgedDifferentiabilityKind cKind);

class BridgedParsedAutoDiffParameter {
private:
BridgedSourceLoc loc;
enum Kind {
Named,
Ordered,
Self,
} kind;
union Value {
BridgedIdentifier name;
unsigned index;

Value(BridgedIdentifier name) : name(name) {}
Value(unsigned index) : index(index) {}
Value() : name() {}
} value;

BridgedParsedAutoDiffParameter(BridgedSourceLoc loc, Kind kind, Value value)
: loc(loc), kind(kind), value(value) {}

public:
SWIFT_NAME("forNamed(_:loc:)")
static BridgedParsedAutoDiffParameter forNamed(BridgedIdentifier name,
BridgedSourceLoc loc) {
return BridgedParsedAutoDiffParameter(loc, Kind::Named, name);
}
SWIFT_NAME("forOrdered(_:loc:)")
static BridgedParsedAutoDiffParameter forOrdered(size_t index,
BridgedSourceLoc loc) {
return BridgedParsedAutoDiffParameter(loc, Kind::Ordered, index);
}
SWIFT_NAME("forSelf(loc:)")
static BridgedParsedAutoDiffParameter forSelf(BridgedSourceLoc loc) {
return BridgedParsedAutoDiffParameter(loc, Kind::Self, {});
}

swift::ParsedAutoDiffParameter unbridged() const;
};

//===----------------------------------------------------------------------===//
// MARK: DeclAttributes
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -879,6 +942,30 @@ BridgedCustomAttr BridgedCustomAttr_createParsed(
BridgedNullableCustomAttributeInitializer cInitContext,
BridgedNullableArgumentList cArgumentList);

SWIFT_NAME("BridgedDerivativeAttr.createParsed(_:atLoc:range:baseType:"
"originalName:originalNameLoc:accessorKind:params:)")
BridgedDerivativeAttr BridgedDerivativeAttr_createParsed(
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
BridgedSourceRange cRange, BridgedNullableTypeRepr cBaseType,
BridgedDeclNameRef cOriginalName, BridgedDeclNameLoc cOriginalNameLoc,
BridgedAccessorKind cAccessorKind, BridgedArrayRef cParams);

SWIFT_NAME("BridgedDerivativeAttr.createParsed(_:atLoc:range:baseType:"
"originalName:originalNameLoc:params:)")
BridgedDerivativeAttr BridgedDerivativeAttr_createParsed(
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
BridgedSourceRange cRange, BridgedNullableTypeRepr cBaseType,
BridgedDeclNameRef cOriginalName, BridgedDeclNameLoc cOriginalNameLoc,
BridgedArrayRef cParams);

SWIFT_NAME("BridgedDifferentiableAttr.createParsed(_:atLoc:range:kind:params:"
"genericWhereClause:)")
BridgedDifferentiableAttr BridgedDifferentiableAttr_createParsed(
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
BridgedSourceRange cRange, BridgedDifferentiabilityKind cKind,
BridgedArrayRef cParams,
BridgedNullableTrailingWhereClause cGenericWhereClause);

SWIFT_NAME("BridgedDocumentationAttr.createParsed(_:atLoc:range:metadata:"
"accessLevel:)")
BridgedDocumentationAttr BridgedDocumentationAttr_createParsed(
Expand Down Expand Up @@ -1260,6 +1347,15 @@ BridgedSILGenNameAttr BridgedSILGenNameAttr_createParsed(
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
BridgedSourceRange cRange, BridgedStringRef cName, bool isRaw);

SWIFT_NAME(
"BridgedTransposeAttr.createParsed(_:atLoc:range:baseType:originalName:"
"originalNameLoc:params:)")
BridgedTransposeAttr BridgedTransposeAttr_createParsed(
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
BridgedSourceRange cRange, BridgedNullableTypeRepr cBaseType,
BridgedDeclNameRef cOriginalName, BridgedDeclNameLoc cOriginalNameLoc,
BridgedArrayRef cParams);

SWIFT_NAME(
"BridgedUnavailableFromAsyncAttr.createParsed(_:atLoc:range:message:)")
BridgedUnavailableFromAsyncAttr BridgedUnavailableFromAsyncAttr_createParsed(
Expand All @@ -1285,11 +1381,6 @@ enum ENUM_EXTENSIBILITY_ATTR(closed) BridgedStaticSpelling {
BridgedStaticSpellingClass
};

enum ENUM_EXTENSIBILITY_ATTR(closed) BridgedAccessorKind {
#define ACCESSOR(ID) BridgedAccessorKind##ID,
#include "swift/AST/AccessorKinds.def"
};

struct BridgedAccessorRecord {
BridgedSourceLoc lBraceLoc;
BridgedArrayRef accessors;
Expand Down Expand Up @@ -2438,6 +2529,13 @@ enum ENUM_EXTENSIBILITY_ATTR(closed) BridgedExecutionTypeAttrExecutionKind {
BridgedExecutionTypeAttrExecutionKind_Caller
};

SWIFT_NAME("BridgedDifferentiableTypeAttr.createParsed(_:atLoc:nameLoc:"
"parensRange:kind:kindLoc:)")
BridgedDifferentiableTypeAttr BridgedDifferentiableTypeAttr_createParsed(
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
BridgedSourceLoc cNameLoc, BridgedSourceRange cParensRange,
BridgedDifferentiabilityKind cKind, BridgedSourceLoc cKindLoc);

SWIFT_NAME("BridgedExecutionTypeAttr.createParsed(_:atLoc:nameLoc:parensRange:"
"behavior:behaviorLoc:)")
BridgedExecutionTypeAttr BridgedExecutionTypeAttr_createParsed(
Expand Down
109 changes: 109 additions & 0 deletions lib/AST/Bridging/DeclAttributeBridging.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,47 @@

#include "swift/AST/ASTContext.h"
#include "swift/AST/Attr.h"
#include "swift/AST/AutoDiff.h"
#include "swift/AST/Expr.h"
#include "swift/AST/Identifier.h"
#include "swift/Basic/Assertions.h"

using namespace swift;

//===----------------------------------------------------------------------===//
// MARK: AutoDiff
//===----------------------------------------------------------------------===//

DifferentiabilityKind unbridged(BridgedDifferentiabilityKind cKind) {
switch (cKind) {
case BridgedDifferentiabilityKindNonDifferentiable:
return DifferentiabilityKind::NonDifferentiable;
case BridgedDifferentiabilityKindForward:
return DifferentiabilityKind::Forward;
case BridgedDifferentiabilityKindReverse:
return DifferentiabilityKind::Reverse;
case BridgedDifferentiabilityKindNormal:
return DifferentiabilityKind::Normal;
case BridgedDifferentiabilityKindLinear:
return DifferentiabilityKind::Linear;
}
llvm_unreachable("unhandled enum value");
}

ParsedAutoDiffParameter BridgedParsedAutoDiffParameter::unbridged() const {
switch (kind) {
case Kind::Named:
return ParsedAutoDiffParameter::getNamedParameter(loc.unbridged(),
value.name.unbridged());
case Kind::Ordered:
return ParsedAutoDiffParameter::getOrderedParameter(loc.unbridged(),
value.index);
case Kind::Self:
return ParsedAutoDiffParameter::getSelfParameter(loc.unbridged());
}
llvm_unreachable("unhandled enum value");
}

//===----------------------------------------------------------------------===//
// MARK: DeclAttributes
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -221,6 +256,62 @@ BridgedCustomAttr BridgedCustomAttr_createParsed(
cInitContext.unbridged(), cArgumentList.unbridged());
}

BridgedDerivativeAttr BridgedDerivativeAttr_createParsedImpl(
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
BridgedSourceRange cRange, BridgedNullableTypeRepr cBaseType,
BridgedDeclNameRef cOriginalName, BridgedDeclNameLoc cOriginalNameLoc,
std::optional<BridgedAccessorKind> cAccessorKind, BridgedArrayRef cParams) {
std::optional<AccessorKind> accessorKind;
if (cAccessorKind)
accessorKind = unbridged(*cAccessorKind);
SmallVector<ParsedAutoDiffParameter, 2> params;
for (auto &elem : cParams.unbridged<BridgedParsedAutoDiffParameter>())
params.push_back(elem.unbridged());

return DerivativeAttr::create(cContext.unbridged(),
/*implicit=*/false, cAtLoc.unbridged(),
cRange.unbridged(), cBaseType.unbridged(),
DeclNameRefWithLoc{cOriginalName.unbridged(),
cOriginalNameLoc.unbridged(),
accessorKind},
params);
}

BridgedDerivativeAttr BridgedDerivativeAttr_createParsed(
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
BridgedSourceRange cRange, BridgedNullableTypeRepr cBaseType,
BridgedDeclNameRef cOriginalName, BridgedDeclNameLoc cOriginalNameLoc,
BridgedAccessorKind cAccessorKind, BridgedArrayRef cParams) {
return BridgedDerivativeAttr_createParsedImpl(
cContext, cAtLoc, cRange, cBaseType, cOriginalName, cOriginalNameLoc,
cAccessorKind, cParams);
}

BridgedDerivativeAttr BridgedDerivativeAttr_createParsed(
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
BridgedSourceRange cRange, BridgedNullableTypeRepr cBaseType,
BridgedDeclNameRef cOriginalName, BridgedDeclNameLoc cOriginalNameLoc,
BridgedArrayRef cParams) {
return BridgedDerivativeAttr_createParsedImpl(
cContext, cAtLoc, cRange, cBaseType, cOriginalName, cOriginalNameLoc,
/*cAccessorKind=*/std::nullopt, cParams);
}

BridgedDifferentiableAttr BridgedDifferentiableAttr_createParsed(
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
BridgedSourceRange cRange, BridgedDifferentiabilityKind cKind,
BridgedArrayRef cParams,
BridgedNullableTrailingWhereClause cGenericWhereClause) {
SmallVector<ParsedAutoDiffParameter, 2> params;
for (auto &elem : cParams.unbridged<BridgedParsedAutoDiffParameter>())
params.push_back(elem.unbridged());

return DifferentiableAttr::create(cContext.unbridged(), /*implicit=*/false,
cAtLoc.unbridged(), cRange.unbridged(),
unbridged(cKind), params,
cGenericWhereClause.unbridged());
}

BridgedDynamicReplacementAttr BridgedDynamicReplacementAttr_createParsed(
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
BridgedSourceLoc cAttrNameLoc, BridgedSourceLoc cLParenLoc,
Expand Down Expand Up @@ -752,6 +843,24 @@ BridgedSILGenNameAttr BridgedSILGenNameAttr_createParsed(
cRange.unbridged(), /*Implicit=*/false);
}

BridgedTransposeAttr BridgedTransposeAttr_createParsed(
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
BridgedSourceRange cRange, BridgedNullableTypeRepr cBaseType,
BridgedDeclNameRef cOriginalName, BridgedDeclNameLoc cOriginalNameLoc,
BridgedArrayRef cParams) {
SmallVector<ParsedAutoDiffParameter, 2> params;
for (auto &elem : cParams.unbridged<BridgedParsedAutoDiffParameter>())
params.push_back(elem.unbridged());

return TransposeAttr::create(
cContext.unbridged(),
/*implicit=*/false, cAtLoc.unbridged(), cRange.unbridged(),
cBaseType.unbridged(),
DeclNameRefWithLoc{cOriginalName.unbridged(), cOriginalNameLoc.unbridged(),
/*AccessorKind=*/std::nullopt},
params);
}

BridgedUnavailableFromAsyncAttr BridgedUnavailableFromAsyncAttr_createParsed(
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
BridgedSourceRange cRange, BridgedStringRef cMessage) {
Expand Down
2 changes: 1 addition & 1 deletion lib/AST/Bridging/DeclBridging.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ static StaticSpellingKind unbridged(BridgedStaticSpelling kind) {
return static_cast<StaticSpellingKind>(kind);
}

static AccessorKind unbridged(BridgedAccessorKind kind) {
AccessorKind unbridged(BridgedAccessorKind kind) {
return static_cast<AccessorKind>(kind);
}

Expand Down
9 changes: 9 additions & 0 deletions lib/AST/Bridging/TypeAttributeBridging.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,15 @@ BridgedConventionTypeAttr BridgedConventionTypeAttr_createParsed(
{cClangType.unbridged(), cClangTypeLoc.unbridged()});
}

BridgedDifferentiableTypeAttr BridgedDifferentiableTypeAttr_createParsed(
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
BridgedSourceLoc cNameLoc, BridgedSourceRange cParensRange,
BridgedDifferentiabilityKind cKind, BridgedSourceLoc cKindLoc) {
return new (cContext.unbridged()) DifferentiableTypeAttr(
cAtLoc.unbridged(), cNameLoc.unbridged(), cParensRange.unbridged(),
{unbridged(cKind), cKindLoc.unbridged()});
}

BridgedExecutionTypeAttr BridgedExecutionTypeAttr_createParsed(
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
BridgedSourceLoc cNameLoc, BridgedSourceRange cParensRange,
Expand Down
Loading