Skip to content

Commit a56e77a

Browse files
committed
Lower AST @noDerivative attribute to SIL.
Add SILGen test.
1 parent ac48feb commit a56e77a

File tree

3 files changed

+75
-8
lines changed

3 files changed

+75
-8
lines changed

include/swift/AST/Types.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4127,6 +4127,13 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
41274127
return ExtInfo(NoEscape ? (Bits | NoEscapeMask) : (Bits & ~NoEscapeMask),
41284128
Other);
41294129
}
4130+
ExtInfo
4131+
withDifferentiabilityKind(DifferentiabilityKind differentiability) const {
4132+
return ExtInfo(
4133+
(Bits & ~DifferentiabilityMask) |
4134+
((unsigned)differentiability << DifferentiabilityMaskOffset),
4135+
Other);
4136+
}
41304137

41314138
std::pair<unsigned, const void *> getFuncAttrKey() const {
41324139
return std::make_pair(Bits, Other.ClangFunctionType);

lib/SIL/SILFunctionType.cpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -754,8 +754,8 @@ class DestructureInputs {
754754
auto eltPattern = origType.getFunctionParamType(i);
755755
auto flags = params[i].getParameterFlags();
756756

757-
visit(flags.getValueOwnership(), /*forSelf=*/false,
758-
eltPattern, ty, silRepresentation);
757+
visit(flags.getValueOwnership(), /*forSelf=*/false, eltPattern, ty,
758+
silRepresentation, flags.isNoDerivative());
759759
}
760760

761761
// Process the self parameter. Note that we implicitly drop self
@@ -776,7 +776,8 @@ class DestructureInputs {
776776

777777
void visit(ValueOwnership ownership, bool forSelf,
778778
AbstractionPattern origType, CanType substType,
779-
SILFunctionTypeRepresentation rep) {
779+
SILFunctionTypeRepresentation rep,
780+
bool isNonDifferentiable = false) {
780781
assert(!isa<InOutType>(substType));
781782

782783
// Tuples get handled specially, in some cases:
@@ -829,9 +830,12 @@ class DestructureInputs {
829830
substTLConv);
830831
assert(!isIndirectFormalParameter(convention));
831832
}
832-
833-
Inputs.push_back(SILParameterInfo(
834-
substTL.getLoweredType().getASTType(), convention));
833+
834+
SILParameterInfo param(substTL.getLoweredType().getASTType(), convention);
835+
if (isNonDifferentiable)
836+
param = param.getWithDifferentiability(
837+
SILParameterDifferentiability::NotDifferentiable);
838+
Inputs.push_back(param);
835839

836840
maybeAddForeignParameters();
837841
}
@@ -1269,7 +1273,8 @@ static CanSILFunctionType getSILFunctionType(
12691273
auto silExtInfo = SILFunctionType::ExtInfo()
12701274
.withRepresentation(extInfo.getSILRepresentation())
12711275
.withIsPseudogeneric(pseudogeneric)
1272-
.withNoEscape(extInfo.isNoEscape());
1276+
.withNoEscape(extInfo.isNoEscape())
1277+
.withDifferentiabilityKind(extInfo.getDifferentiabilityKind());
12731278

12741279
// Build the substituted generic signature we extracted.
12751280
bool impliedSignature = false;
@@ -2734,7 +2739,7 @@ class SILTypeSubstituter :
27342739

27352740
SILParameterInfo substInterface(SILParameterInfo orig) {
27362741
return SILParameterInfo(visit(orig.getInterfaceType()),
2737-
orig.getConvention());
2742+
orig.getConvention(), orig.getDifferentiability());
27382743
}
27392744

27402745
/// Tuples need to have their component types substituted by these
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// RUN: %target-swift-frontend -emit-silgen -enable-experimental-differentiable-programming %s | %FileCheck %s
2+
3+
// Test SILGen for `@differentiable` function typed values.
4+
5+
import _Differentiation
6+
7+
@_silgen_name("differentiable")
8+
func differentiable(_ fn: @escaping @differentiable (Float) -> Float)
9+
-> @differentiable (Float) -> Float {
10+
return fn
11+
}
12+
13+
@_silgen_name("linear")
14+
func linear(_ fn: @escaping @differentiable(linear) (Float) -> Float)
15+
-> @differentiable(linear) (Float) -> Float {
16+
return fn
17+
}
18+
19+
@_silgen_name("differentiable_noDerivative")
20+
func differentiable_noDerivative(
21+
_ fn: @escaping @differentiable (Float, @noDerivative Float) -> Float
22+
) -> @differentiable (Float, @noDerivative Float) -> Float {
23+
return fn
24+
}
25+
26+
@_silgen_name("linear_noDerivative")
27+
func linear_noDerivative(
28+
_ fn: @escaping @differentiable(linear) (Float, @noDerivative Float) -> Float
29+
) -> @differentiable(linear) (Float, @noDerivative Float) -> Float {
30+
return fn
31+
}
32+
33+
// CHECK-LABEL: sil hidden [ossa] @differentiable : $@convention(thin) (@guaranteed @differentiable @callee_guaranteed (Float) -> Float) -> @owned @differentiable @callee_guaranteed (Float) -> Float {
34+
// CHECK: bb0([[FN:%.*]] : @guaranteed $@differentiable @callee_guaranteed (Float) -> Float):
35+
// CHECK: [[COPIED_FN:%.*]] = copy_value [[FN]] : $@differentiable @callee_guaranteed (Float) -> Float
36+
// CHECK: return [[COPIED_FN]] : $@differentiable @callee_guaranteed (Float) -> Float
37+
// CHECK: }
38+
39+
// CHECK-LABEL: sil hidden [ossa] @linear : $@convention(thin) (@guaranteed @differentiable(linear) @callee_guaranteed (Float) -> Float) -> @owned @differentiable(linear) @callee_guaranteed (Float) -> Float {
40+
// CHECK: bb0([[FN:%.*]] : @guaranteed $@differentiable(linear) @callee_guaranteed (Float) -> Float):
41+
// CHECK: [[COPIED_FN:%.*]] = copy_value [[FN]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float
42+
// CHECK: return [[COPIED_FN]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float
43+
// CHECK: }
44+
45+
// CHECK-LABEL: sil hidden [ossa] @differentiable_noDerivative : $@convention(thin) (@guaranteed @differentiable @callee_guaranteed (Float, @noDerivative Float) -> Float) -> @owned @differentiable @callee_guaranteed (Float, @noDerivative Float) -> Float {
46+
// CHECK: bb0([[FN:%.*]] : @guaranteed $@differentiable @callee_guaranteed (Float, @noDerivative Float) -> Float):
47+
// CHECK: [[COPIED_FN:%.*]] = copy_value [[FN]] : $@differentiable @callee_guaranteed (Float, @noDerivative Float) -> Float
48+
// CHECK: return [[COPIED_FN]] : $@differentiable @callee_guaranteed (Float, @noDerivative Float) -> Float
49+
// CHECK: }
50+
51+
// CHECK-LABEL: sil hidden [ossa] @linear_noDerivative : $@convention(thin) (@guaranteed @differentiable(linear) @callee_guaranteed (Float, @noDerivative Float) -> Float) -> @owned @differentiable(linear) @callee_guaranteed (Float, @noDerivative Float) -> Float {
52+
// CHECK: bb0([[FN:%.*]] : @guaranteed $@differentiable(linear) @callee_guaranteed (Float, @noDerivative Float) -> Float):
53+
// CHECK: [[COPIED_FN:%.*]] = copy_value [[FN]] : $@differentiable(linear) @callee_guaranteed (Float, @noDerivative Float) -> Float
54+
// CHECK: return [[COPIED_FN]] : $@differentiable(linear) @callee_guaranteed (Float, @noDerivative Float) -> Float
55+
// CHECK: }

0 commit comments

Comments
 (0)