Skip to content

Commit 8c17687

Browse files
authored
Merge pull request #29405 from dan-zheng/autodiff-upstream-sil-diff-param
[AutoDiff upstream] Add `@noDerivative` flag to `SILParameterInfo`.
2 parents 2d08a3f + a56e77a commit 8c17687

File tree

10 files changed

+212
-18
lines changed

10 files changed

+212
-18
lines changed

include/swift/AST/Types.h

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3637,13 +3637,35 @@ inline bool isGuaranteedParameter(ParameterConvention conv) {
36373637
llvm_unreachable("bad convention kind");
36383638
}
36393639

3640+
/// The differentiability of a SIL function type parameter.
3641+
enum class SILParameterDifferentiability : unsigned {
3642+
/// Either differentiable or not applicable.
3643+
///
3644+
/// - If the function type is not `@differentiable`, parameter
3645+
/// differentiability is not applicable. This case is the default value.
3646+
/// - If the function type is `@differentiable`, the function is
3647+
/// differentiable with respect to this parameter.
3648+
DifferentiableOrNotApplicable,
3649+
3650+
/// Not differentiable: a `@noDerivative` parameter.
3651+
///
3652+
/// May be applied only to parameters of `@differentiable` function types.
3653+
/// The function type is not differentiable with respect to this parameter.
3654+
NotDifferentiable,
3655+
};
3656+
36403657
/// A parameter type and the rules for passing it.
36413658
class SILParameterInfo {
36423659
llvm::PointerIntPair<CanType, 3, ParameterConvention> TypeAndConvention;
3660+
SILParameterDifferentiability Differentiability : 1;
3661+
36433662
public:
36443663
SILParameterInfo() = default;//: Ty(), Convention((ParameterConvention)0) {}
3645-
SILParameterInfo(CanType type, ParameterConvention conv)
3646-
: TypeAndConvention(type, conv) {
3664+
SILParameterInfo(
3665+
CanType type, ParameterConvention conv,
3666+
SILParameterDifferentiability differentiability =
3667+
SILParameterDifferentiability::DifferentiableOrNotApplicable)
3668+
: TypeAndConvention(type, conv), Differentiability(differentiability) {
36473669
assert(type->isLegalSILType() && "SILParameterInfo has illegal SIL type");
36483670
}
36493671

@@ -3698,6 +3720,16 @@ class SILParameterInfo {
36983720
return isGuaranteedParameter(getConvention());
36993721
}
37003722

3723+
SILParameterDifferentiability getDifferentiability() const {
3724+
return Differentiability;
3725+
}
3726+
3727+
SILParameterInfo getWithDifferentiability(
3728+
SILParameterDifferentiability differentiability) const {
3729+
return SILParameterInfo(getInterfaceType(), getConvention(),
3730+
differentiability);
3731+
}
3732+
37013733
/// The SIL storage type determines the ABI for arguments based purely on the
37023734
/// formal parameter conventions. The actual SIL type for the argument values
37033735
/// may differ in canonical SIL. In particular, opaque values require indirect
@@ -3726,6 +3758,7 @@ class SILParameterInfo {
37263758
void profile(llvm::FoldingSetNodeID &id) {
37273759
id.AddPointer(getInterfaceType().getPointer());
37283760
id.AddInteger((unsigned)getConvention());
3761+
id.AddInteger((unsigned)getDifferentiability());
37293762
}
37303763

37313764
SWIFT_DEBUG_DUMP;
@@ -3739,8 +3772,9 @@ class SILParameterInfo {
37393772
}
37403773

37413774
bool operator==(SILParameterInfo rhs) const {
3742-
return getInterfaceType() == rhs.getInterfaceType()
3743-
&& getConvention() == rhs.getConvention();
3775+
return getInterfaceType() == rhs.getInterfaceType() &&
3776+
getConvention() == rhs.getConvention() &&
3777+
getDifferentiability() == rhs.getDifferentiability();
37443778
}
37453779
bool operator!=(SILParameterInfo rhs) const {
37463780
return !(*this == rhs);
@@ -4093,6 +4127,13 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
40934127
return ExtInfo(NoEscape ? (Bits | NoEscapeMask) : (Bits & ~NoEscapeMask),
40944128
Other);
40954129
}
4130+
ExtInfo
4131+
withDifferentiabilityKind(DifferentiabilityKind differentiability) const {
4132+
return ExtInfo(
4133+
(Bits & ~DifferentiabilityMask) |
4134+
((unsigned)differentiability << DifferentiabilityMaskOffset),
4135+
Other);
4136+
}
40964137

40974138
std::pair<unsigned, const void *> getFuncAttrKey() const {
40984139
return std::make_pair(Bits, Other.ClangFunctionType);

lib/AST/ASTContext.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3327,6 +3327,15 @@ SILFunctionType::SILFunctionType(
33273327
"Cannot return an @noescape function type");
33283328
}
33293329
}
3330+
3331+
// Check that `@noDerivative` parameters only exist on `@differentiable`
3332+
// functions.
3333+
if (!ext.isDifferentiable())
3334+
for (auto param : getParameters())
3335+
assert(param.getDifferentiability() ==
3336+
SILParameterDifferentiability::DifferentiableOrNotApplicable &&
3337+
"non-`@differentiable` function should not have NotDifferentiable "
3338+
"parameter");
33303339
#endif
33313340
}
33323341

lib/AST/ASTPrinter.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4649,6 +4649,13 @@ void SILParameterInfo::print(raw_ostream &OS, const PrintOptions &Opts) const {
46494649
}
46504650
void SILParameterInfo::print(ASTPrinter &Printer,
46514651
const PrintOptions &Opts) const {
4652+
switch (getDifferentiability()) {
4653+
case SILParameterDifferentiability::NotDifferentiable:
4654+
Printer << "@noDerivative ";
4655+
break;
4656+
default:
4657+
break;
4658+
}
46524659
Printer << getStringForParameterConvention(getConvention());
46534660
getInterfaceType().print(Printer, Opts);
46544661
}

lib/SIL/SILFunctionType.cpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -945,8 +945,8 @@ class DestructureInputs {
945945
auto eltPattern = origType.getFunctionParamType(i);
946946
auto flags = params[i].getParameterFlags();
947947

948-
visit(flags.getValueOwnership(), /*forSelf=*/false,
949-
eltPattern, ty, silRepresentation);
948+
visit(flags.getValueOwnership(), /*forSelf=*/false, eltPattern, ty,
949+
silRepresentation, flags.isNoDerivative());
950950
}
951951

952952
// Process the self parameter. Note that we implicitly drop self
@@ -967,7 +967,8 @@ class DestructureInputs {
967967

968968
void visit(ValueOwnership ownership, bool forSelf,
969969
AbstractionPattern origType, CanType substType,
970-
SILFunctionTypeRepresentation rep) {
970+
SILFunctionTypeRepresentation rep,
971+
bool isNonDifferentiable = false) {
971972
assert(!isa<InOutType>(substType));
972973

973974
// Tuples get handled specially, in some cases:
@@ -1020,9 +1021,12 @@ class DestructureInputs {
10201021
substTLConv);
10211022
assert(!isIndirectFormalParameter(convention));
10221023
}
1023-
1024-
Inputs.push_back(SILParameterInfo(
1025-
substTL.getLoweredType().getASTType(), convention));
1024+
1025+
SILParameterInfo param(substTL.getLoweredType().getASTType(), convention);
1026+
if (isNonDifferentiable)
1027+
param = param.getWithDifferentiability(
1028+
SILParameterDifferentiability::NotDifferentiable);
1029+
Inputs.push_back(param);
10261030

10271031
maybeAddForeignParameters();
10281032
}
@@ -1460,7 +1464,8 @@ static CanSILFunctionType getSILFunctionType(
14601464
auto silExtInfo = SILFunctionType::ExtInfo()
14611465
.withRepresentation(extInfo.getSILRepresentation())
14621466
.withIsPseudogeneric(pseudogeneric)
1463-
.withNoEscape(extInfo.isNoEscape());
1467+
.withNoEscape(extInfo.isNoEscape())
1468+
.withDifferentiabilityKind(extInfo.getDifferentiabilityKind());
14641469

14651470
// Build the substituted generic signature we extracted.
14661471
bool impliedSignature = false;
@@ -2925,7 +2930,7 @@ class SILTypeSubstituter :
29252930

29262931
SILParameterInfo substInterface(SILParameterInfo orig) {
29272932
return SILParameterInfo(visit(orig.getInterfaceType()),
2928-
orig.getConvention());
2933+
orig.getConvention(), orig.getDifferentiability());
29292934
}
29302935

29312936
/// Tuples need to have their component types substituted by these

lib/Sema/TypeCheckType.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2948,6 +2948,8 @@ SILParameterInfo TypeResolver::resolveSILParameter(
29482948
auto convention = DefaultParameterConvention;
29492949
Type type;
29502950
bool hadError = false;
2951+
auto differentiability =
2952+
SILParameterDifferentiability::DifferentiableOrNotApplicable;
29512953

29522954
if (auto attrRepr = dyn_cast<AttributedTypeRepr>(repr)) {
29532955
auto attrs = attrRepr->getAttrs();
@@ -2973,6 +2975,10 @@ SILParameterInfo TypeResolver::resolveSILParameter(
29732975
checkFor(TypeAttrKind::TAK_owned, ParameterConvention::Direct_Owned);
29742976
checkFor(TypeAttrKind::TAK_guaranteed,
29752977
ParameterConvention::Direct_Guaranteed);
2978+
if (attrs.has(TAK_noDerivative)) {
2979+
attrs.clearAttribute(TAK_noDerivative);
2980+
differentiability = SILParameterDifferentiability::NotDifferentiable;
2981+
}
29762982

29772983
type = resolveAttributedType(attrs, attrRepr->getTypeRepr(), options);
29782984
} else {
@@ -2989,7 +2995,8 @@ SILParameterInfo TypeResolver::resolveSILParameter(
29892995
}
29902996

29912997
if (hadError) type = ErrorType::get(Context);
2992-
return SILParameterInfo(type->getCanonicalType(), convention);
2998+
return SILParameterInfo(type->getCanonicalType(), convention,
2999+
differentiability);
29933000
}
29943001

29953002
bool TypeResolver::resolveSingleSILResult(TypeRepr *repr,

lib/Serialization/Deserialization.cpp

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4503,6 +4503,21 @@ Optional<swift::ParameterConvention> getActualParameterConvention(uint8_t raw) {
45034503
return None;
45044504
}
45054505

4506+
/// Translate from the serialization SILParameterDifferentiability enumerators,
4507+
/// which are guaranteed to be stable, to the AST ones.
4508+
static Optional<swift::SILParameterDifferentiability>
4509+
getActualSILParameterDifferentiability(uint8_t raw) {
4510+
switch (serialization::SILParameterDifferentiability(raw)) {
4511+
#define CASE(ID) \
4512+
case serialization::SILParameterDifferentiability::ID: \
4513+
return swift::SILParameterDifferentiability::ID;
4514+
CASE(DifferentiableOrNotApplicable)
4515+
CASE(NotDifferentiable)
4516+
#undef CASE
4517+
}
4518+
return None;
4519+
}
4520+
45064521
/// Translate from the serialization ResultConvention enumerators,
45074522
/// which are guaranteed to be stable, to the AST ones.
45084523
static
@@ -5144,15 +5159,26 @@ class TypeDeserializer {
51445159
if (!calleeConvention.hasValue())
51455160
MF.fatal();
51465161

5147-
auto processParameter = [&](TypeID typeID, uint64_t rawConvention)
5148-
-> llvm::Expected<SILParameterInfo> {
5162+
auto processParameter =
5163+
[&](TypeID typeID, uint64_t rawConvention,
5164+
uint64_t ramDifferentiability) -> llvm::Expected<SILParameterInfo> {
51495165
auto convention = getActualParameterConvention(rawConvention);
51505166
if (!convention)
51515167
MF.fatal();
51525168
auto type = MF.getTypeChecked(typeID);
51535169
if (!type)
51545170
return type.takeError();
5155-
return SILParameterInfo(type.get()->getCanonicalType(), *convention);
5171+
auto differentiability =
5172+
swift::SILParameterDifferentiability::DifferentiableOrNotApplicable;
5173+
if (diffKind != DifferentiabilityKind::NonDifferentiable) {
5174+
auto differentiabilityOpt =
5175+
getActualSILParameterDifferentiability(ramDifferentiability);
5176+
if (!differentiabilityOpt)
5177+
MF.fatal();
5178+
differentiability = *differentiabilityOpt;
5179+
}
5180+
return SILParameterInfo(type.get()->getCanonicalType(), *convention,
5181+
differentiability);
51565182
};
51575183

51585184
auto processYield = [&](TypeID typeID, uint64_t rawConvention)
@@ -5191,7 +5217,10 @@ class TypeDeserializer {
51915217
for (unsigned i = 0; i != numParams; ++i) {
51925218
auto typeID = variableData[nextVariableDataIndex++];
51935219
auto rawConvention = variableData[nextVariableDataIndex++];
5194-
auto param = processParameter(typeID, rawConvention);
5220+
uint64_t differentiability = 0;
5221+
if (diffKind != DifferentiabilityKind::NonDifferentiable)
5222+
differentiability = variableData[nextVariableDataIndex++];
5223+
auto param = processParameter(typeID, rawConvention, differentiability);
51955224
if (!param)
51965225
return param.takeError();
51975226
allParams.push_back(param.get());

lib/Serialization/ModuleFormat.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ const uint16_t SWIFTMODULE_VERSION_MAJOR = 0;
5555
/// describe what change you made. The content of this comment isn't important;
5656
/// it just ensures a conflict if two people change the module format.
5757
/// Don't worry about adhering to the 80-column limit for this line.
58-
const uint16_t SWIFTMODULE_VERSION_MINOR = 533; // removed @_implicitly_synthesizes_nested_requirement
58+
const uint16_t SWIFTMODULE_VERSION_MINOR = 534; // add SIL parameter differentiability
5959

6060
/// A standard hash seed used for all string hashes in a serialized module.
6161
///
@@ -347,6 +347,13 @@ enum class ParameterConvention : uint8_t {
347347
};
348348
using ParameterConventionField = BCFixed<4>;
349349

350+
// These IDs must \em not be renumbered or reordered without incrementing
351+
// the module version.
352+
enum class SILParameterDifferentiability : uint8_t {
353+
DifferentiableOrNotApplicable,
354+
NotDifferentiable,
355+
};
356+
350357
// These IDs must \em not be renumbered or reordered without incrementing
351358
// the module version.
352359
enum class ResultConvention : uint8_t {

lib/Serialization/Serialization.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3766,6 +3766,17 @@ static uint8_t getRawStableParameterConvention(swift::ParameterConvention pc) {
37663766
llvm_unreachable("bad parameter convention kind");
37673767
}
37683768

3769+
/// Translate from AST SILParameterDifferentiability enum to the Serialization
3770+
/// enum values, which are guaranteed to be stable.
3771+
static uint8_t
3772+
getRawSILParameterDifferentiability(swift::SILParameterDifferentiability pd) {
3773+
switch (pd) {
3774+
SIMPLE_CASE(SILParameterDifferentiability, DifferentiableOrNotApplicable)
3775+
SIMPLE_CASE(SILParameterDifferentiability, NotDifferentiable)
3776+
}
3777+
llvm_unreachable("bad parameter differentiability kind");
3778+
}
3779+
37693780
/// Translate from the AST ResultConvention enum to the
37703781
/// Serialization enum values, which are guaranteed to be stable.
37713782
static uint8_t getRawStableResultConvention(swift::ResultConvention rc) {
@@ -4075,6 +4086,9 @@ class Serializer::TypeSerializer : public TypeVisitor<TypeSerializer> {
40754086
variableData.push_back(S.addTypeRef(param.getInterfaceType()));
40764087
unsigned conv = getRawStableParameterConvention(param.getConvention());
40774088
variableData.push_back(TypeID(conv));
4089+
if (fnTy->isDifferentiable())
4090+
variableData.push_back(TypeID(
4091+
getRawSILParameterDifferentiability(param.getDifferentiability())));
40784092
}
40794093
for (auto yield : fnTy->getYields()) {
40804094
variableData.push_back(S.addTypeRef(yield.getInterfaceType()));

test/AutoDiff/SIL/Serialization/differentiation.swift

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,23 @@ bb0(%0 : $@differentiable(linear) (Float) -> Float):
2626
// CHECK: bb0([[ARG:%.*]] : $@differentiable(linear) (Float) -> Float):
2727
// CHECK: return [[ARG]] : $@differentiable(linear) (Float) -> Float
2828
// CHECK: }
29+
30+
sil @c : $@convention(thin) (@differentiable (Float, @noDerivative Float) -> Float) -> @differentiable (Float, @noDerivative Float) -> Float {
31+
bb0(%0 : $@differentiable (Float, @noDerivative Float) -> Float):
32+
return %0 : $@differentiable (Float, @noDerivative Float) -> Float
33+
}
34+
35+
// CHECK-LABEL: sil @c : $@convention(thin) (@differentiable (Float, @noDerivative Float) -> Float) -> @differentiable (Float, @noDerivative Float) -> Float {
36+
// CHECK: bb0(%0 : $@differentiable (Float, @noDerivative Float) -> Float):
37+
// CHECK: return %0 : $@differentiable (Float, @noDerivative Float) -> Float
38+
// CHECK: }
39+
40+
sil @d : $@convention(thin) (@differentiable(linear) (Float, @noDerivative Float) -> Float) -> @differentiable(linear) (Float, @noDerivative Float) -> Float {
41+
bb0(%0 : $@differentiable(linear) (Float, @noDerivative Float) -> Float):
42+
return %0 : $@differentiable(linear) (Float, @noDerivative Float) -> Float
43+
}
44+
45+
// CHECK-LABEL: sil @d : $@convention(thin) (@differentiable(linear) (Float, @noDerivative Float) -> Float) -> @differentiable(linear) (Float, @noDerivative Float) -> Float {
46+
// CHECK: bb0(%0 : $@differentiable(linear) (Float, @noDerivative Float) -> Float):
47+
// CHECK: return %0 : $@differentiable(linear) (Float, @noDerivative Float) -> Float
48+
// CHECK: }
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// RUN: %target-swift-frontend -emit-silgen -enable-experimental-differentiable-programming %s | %FileCheck %s
2+
3+
// Test SILGen for `@differentiable` function typed values.
4+
5+
import _Differentiation
6+
7+
@_silgen_name("differentiable")
8+
func differentiable(_ fn: @escaping @differentiable (Float) -> Float)
9+
-> @differentiable (Float) -> Float {
10+
return fn
11+
}
12+
13+
@_silgen_name("linear")
14+
func linear(_ fn: @escaping @differentiable(linear) (Float) -> Float)
15+
-> @differentiable(linear) (Float) -> Float {
16+
return fn
17+
}
18+
19+
@_silgen_name("differentiable_noDerivative")
20+
func differentiable_noDerivative(
21+
_ fn: @escaping @differentiable (Float, @noDerivative Float) -> Float
22+
) -> @differentiable (Float, @noDerivative Float) -> Float {
23+
return fn
24+
}
25+
26+
@_silgen_name("linear_noDerivative")
27+
func linear_noDerivative(
28+
_ fn: @escaping @differentiable(linear) (Float, @noDerivative Float) -> Float
29+
) -> @differentiable(linear) (Float, @noDerivative Float) -> Float {
30+
return fn
31+
}
32+
33+
// CHECK-LABEL: sil hidden [ossa] @differentiable : $@convention(thin) (@guaranteed @differentiable @callee_guaranteed (Float) -> Float) -> @owned @differentiable @callee_guaranteed (Float) -> Float {
34+
// CHECK: bb0([[FN:%.*]] : @guaranteed $@differentiable @callee_guaranteed (Float) -> Float):
35+
// CHECK: [[COPIED_FN:%.*]] = copy_value [[FN]] : $@differentiable @callee_guaranteed (Float) -> Float
36+
// CHECK: return [[COPIED_FN]] : $@differentiable @callee_guaranteed (Float) -> Float
37+
// CHECK: }
38+
39+
// CHECK-LABEL: sil hidden [ossa] @linear : $@convention(thin) (@guaranteed @differentiable(linear) @callee_guaranteed (Float) -> Float) -> @owned @differentiable(linear) @callee_guaranteed (Float) -> Float {
40+
// CHECK: bb0([[FN:%.*]] : @guaranteed $@differentiable(linear) @callee_guaranteed (Float) -> Float):
41+
// CHECK: [[COPIED_FN:%.*]] = copy_value [[FN]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float
42+
// CHECK: return [[COPIED_FN]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float
43+
// CHECK: }
44+
45+
// CHECK-LABEL: sil hidden [ossa] @differentiable_noDerivative : $@convention(thin) (@guaranteed @differentiable @callee_guaranteed (Float, @noDerivative Float) -> Float) -> @owned @differentiable @callee_guaranteed (Float, @noDerivative Float) -> Float {
46+
// CHECK: bb0([[FN:%.*]] : @guaranteed $@differentiable @callee_guaranteed (Float, @noDerivative Float) -> Float):
47+
// CHECK: [[COPIED_FN:%.*]] = copy_value [[FN]] : $@differentiable @callee_guaranteed (Float, @noDerivative Float) -> Float
48+
// CHECK: return [[COPIED_FN]] : $@differentiable @callee_guaranteed (Float, @noDerivative Float) -> Float
49+
// CHECK: }
50+
51+
// CHECK-LABEL: sil hidden [ossa] @linear_noDerivative : $@convention(thin) (@guaranteed @differentiable(linear) @callee_guaranteed (Float, @noDerivative Float) -> Float) -> @owned @differentiable(linear) @callee_guaranteed (Float, @noDerivative Float) -> Float {
52+
// CHECK: bb0([[FN:%.*]] : @guaranteed $@differentiable(linear) @callee_guaranteed (Float, @noDerivative Float) -> Float):
53+
// CHECK: [[COPIED_FN:%.*]] = copy_value [[FN]] : $@differentiable(linear) @callee_guaranteed (Float, @noDerivative Float) -> Float
54+
// CHECK: return [[COPIED_FN]] : $@differentiable(linear) @callee_guaranteed (Float, @noDerivative Float) -> Float
55+
// CHECK: }

0 commit comments

Comments
 (0)