Skip to content

Commit 1e403ec

Browse files
authored
[AutoDiff] Support custom derivatives for @_alwaysEmitIntoClient functions (#78908)
Consider an `@_alwaysEmitIntoClient` function and a custom derivative defined for it. Previously, such a combination resulted different errors under different circumstances. Sometimes, there were linker errors due to missing derivative function symbol - these occurred when we tried to find the derivative in a module, while it should have been emitted into client's code (and it did not happen). Sometimes, there were SIL verification failures like this: ``` SIL verification failed: internal/private function cannot be serialized or serializable: !F->isAnySerialized() || embedded ``` Linkage and serialization options for the derivative were not handled properly, and, instead of PublicNonABI linkage, we had Private one which is unsupported for serialization - but we need to serialize `@_alwaysEmitIntoClient` functions so the client's code is able to see them. This patch resolves the issue and adds proper handling of custom derivatives of `@_alwaysEmitIntoClient` functions. Note that either both the function and its custom derivative or none of them should have `@_alwaysEmitIntoClient` attribute, mismatch in this attribute is not supported. The following cases are handled (assume that in each case client's code uses the derivative). 1. Both the function and its derivative are defined in a single file in one module. 2. Both the function and its derivative are defined in different files which are compiled to a single module. 3. The function is defined in one module, its derivative is defined in another module. 4. The function and the derivative are defined as members of a protocol extension in two separate modules - one for the function and one for the derivative. A struct conforming the protocol is defined in the third module. 5. The function and the derivative are defined as members of a struct extension in two separate modules - one for the function and one for the derivative. The changes allow to define derivatives for methods of `SIMD`. Fixes #54445 <!-- If this pull request is targeting a release branch, please fill out the following form: https://github.com/swiftlang/.github/blob/main/PULL_REQUEST_TEMPLATE/release.md?plain=1 Otherwise, replace this comment with a description of your changes and rationale. Provide links to external references/discussions if appropriate. If this pull request resolves any GitHub issues, link them like so: Resolves <link to issue>, resolves <link to another issue>. For more information about linking a pull request to an issue, see: https://docs.github.com/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue --> <!-- Before merging this pull request, you must run the Swift continuous integration tests. For information about triggering CI builds via @swift-ci, see: https://github.com/apple/swift/blob/main/docs/ContinuousIntegration.md#swift-ci Thank you for your contribution to Swift! -->
1 parent 29a9fb0 commit 1e403ec

30 files changed

+476
-31
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4387,6 +4387,9 @@ NOTE(derivative_attr_fix_access,none,
43874387
"mark the derivative function as "
43884388
"'%select{private|fileprivate|internal|package|@usableFromInline|@usableFromInline}0' "
43894389
"to match the original function", (AccessLevel))
4390+
ERROR(derivative_attr_always_emit_into_client_mismatch,none,
4391+
"either both or none of derivative and original function must have "
4392+
"@alwaysEmitIntoClient attribute", ())
43904393
ERROR(derivative_attr_static_method_mismatch_original,none,
43914394
"unexpected derivative function declaration; "
43924395
"%0 requires the derivative function %1 to be %select{an instance|a 'static'}2 method",

lib/SIL/IR/Linker.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,23 @@ void SILLinkerVisitor::maybeAddFunctionToWorklist(
159159
// HiddenExternal linkage when they are declarations, then they
160160
// become Shared after the body has been deserialized.
161161
// So try deserializing HiddenExternal functions too.
162-
if (linkage == SILLinkage::HiddenExternal)
163-
return deserializeAndPushToWorklist(F);
164-
162+
if (linkage == SILLinkage::HiddenExternal) {
163+
deserializeAndPushToWorklist(F);
164+
if (!F->markedAsAlwaysEmitIntoClient())
165+
return;
166+
// For @_alwaysEmitIntoClient functions, we need to lookup its
167+
// differentiability witness and, if present, ask SILLoader to obtain its
168+
// definition. Otherwise, a linker error would occur due to undefined
169+
// reference to these symbols.
170+
for (SILDifferentiabilityWitness *witness :
171+
F->getModule().lookUpDifferentiabilityWitnessesForFunction(
172+
F->getName())) {
173+
F->getModule().getSILLoader()->lookupDifferentiabilityWitness(
174+
witness->getKey());
175+
}
176+
return;
177+
}
178+
165179
// Update the linkage of the function in case it's different in the serialized
166180
// SIL than derived from the AST. This can be the case with cross-module-
167181
// optimizations.

lib/SILGen/SILGen.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1435,14 +1435,19 @@ void SILGenModule::emitDifferentiabilityWitness(
14351435
auto *diffWitness = M.lookUpDifferentiabilityWitness(key);
14361436
if (!diffWitness) {
14371437
// Differentiability witnesses have the same linkage as the original
1438-
// function, stripping external.
1439-
auto linkage = stripExternalFromLinkage(originalFunction->getLinkage());
1438+
// function, stripping external. For @_alwaysEmitIntoClient original
1439+
// functions, force PublicNonABI linkage of the differentiability witness so
1440+
// we can serialize it (the original function itself might be HiddenExternal
1441+
// in this case if we only have declaration without definition).
1442+
auto linkage =
1443+
originalFunction->markedAsAlwaysEmitIntoClient()
1444+
? SILLinkage::PublicNonABI
1445+
: stripExternalFromLinkage(originalFunction->getLinkage());
14401446
diffWitness = SILDifferentiabilityWitness::createDefinition(
14411447
M, linkage, originalFunction, diffKind, silConfig.parameterIndices,
14421448
silConfig.resultIndices, config.derivativeGenericSignature,
14431449
/*jvp*/ nullptr, /*vjp*/ nullptr,
1444-
/*isSerialized*/ hasPublicVisibility(originalFunction->getLinkage()),
1445-
attr);
1450+
/*isSerialized*/ hasPublicVisibility(linkage), attr);
14461451
}
14471452

14481453
// Set derivative function in differentiability witness.

lib/SILGen/SILGenPoly.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6498,8 +6498,14 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
64986498
auto loc = customDerivativeFn->getLocation();
64996499
SILGenFunctionBuilder fb(*this);
65006500
// Derivative thunks have the same linkage as the original function, stripping
6501-
// external.
6502-
auto linkage = stripExternalFromLinkage(originalFn->getLinkage());
6501+
// external. For @_alwaysEmitIntoClient original functions, force PublicNonABI
6502+
// linkage of derivative thunks so we can serialize them (the original
6503+
// function itself might be HiddenExternal in this case if we only have
6504+
// declaration without definition).
6505+
auto linkage = originalFn->markedAsAlwaysEmitIntoClient()
6506+
? SILLinkage::PublicNonABI
6507+
: stripExternalFromLinkage(originalFn->getLinkage());
6508+
65036509
auto *thunk = fb.getOrCreateFunction(
65046510
loc, name, linkage, thunkFnTy, IsBare, IsNotTransparent,
65056511
customDerivativeFn->getSerializedKind(),

lib/SILOptimizer/Differentiation/Common.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -538,9 +538,14 @@ SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness(
538538
"definitions with explicit differentiable attributes");
539539

540540
return SILDifferentiabilityWitness::createDeclaration(
541-
module, SILLinkage::PublicExternal, original, kind,
542-
minimalConfig->parameterIndices, minimalConfig->resultIndices,
543-
minimalConfig->derivativeGenericSignature);
541+
module,
542+
// Witness for @_alwaysEmitIntoClient original function must be emitted,
543+
// otherwise a linker error would occur due to undefined reference to the
544+
// witness symbol.
545+
original->markedAsAlwaysEmitIntoClient() ? SILLinkage::PublicNonABI
546+
: SILLinkage::PublicExternal,
547+
original, kind, minimalConfig->parameterIndices,
548+
minimalConfig->resultIndices, minimalConfig->derivativeGenericSignature);
544549
}
545550

546551
} // end namespace autodiff

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -999,10 +999,14 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness(
999999

10001000
// We can generate empty JVP / VJP for functions available externally. These
10011001
// functions have the same linkage as the original ones sans `external`
1002-
// flag. Important exception here hidden_external functions as they are
1003-
// serializable but corresponding hidden ones would be not and the SIL
1004-
// verifier will fail. Patch `serializeFunctions` for this case.
1005-
if (orig->getLinkage() == SILLinkage::HiddenExternal)
1002+
// flag. Important exception here hidden_external non-@_alwaysEmitIntoClient
1003+
// functions as they are serializable but corresponding hidden ones would be
1004+
// not and the SIL verifier will fail. Patch `serializeFunctions` for this
1005+
// case. For @_alwaysEmitIntoClient original functions (which might be
1006+
// HiddenExternal if we only have declaration without definition), we want
1007+
// derivatives to be serialized and do not patch `serializeFunctions`.
1008+
if (orig->getLinkage() == SILLinkage::HiddenExternal &&
1009+
!orig->markedAsAlwaysEmitIntoClient())
10061010
serializeFunctions = IsNotSerialized;
10071011

10081012
// If the JVP doesn't exist, need to synthesize it.

lib/Sema/TypeCheckAttr.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6990,6 +6990,13 @@ static bool typeCheckDerivativeAttr(DerivativeAttr *attr) {
69906990
return true;
69916991
}
69926992

6993+
if (originalAFD->getAttrs().hasAttribute<AlwaysEmitIntoClientAttr>() !=
6994+
derivative->getAttrs().hasAttribute<AlwaysEmitIntoClientAttr>()) {
6995+
diags.diagnose(derivative->getLoc(),
6996+
diag::derivative_attr_always_emit_into_client_mismatch);
6997+
return true;
6998+
}
6999+
69937000
// Get the resolved differentiability parameter indices.
69947001
auto *resolvedDiffParamIndices = attr->getParameterIndices();
69957002

stdlib/public/Differentiation/SIMDDifferentiation.swift.gyb

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -405,9 +405,6 @@ where
405405
}
406406
}
407407

408-
// FIXME(TF-1103): Derivative registration does not yet support
409-
// `@_alwaysEmitIntoClient` original functions like `SIMD.sum()`.
410-
/*
411408
extension SIMD
412409
where
413410
Self: Differentiable,
@@ -417,6 +414,7 @@ where
417414
TangentVector == Self
418415
{
419416
@inlinable
417+
@_alwaysEmitIntoClient
420418
@derivative(of: sum)
421419
func _vjpSum() -> (
422420
value: Scalar, pullback: (Scalar.TangentVector) -> TangentVector
@@ -425,14 +423,14 @@ where
425423
}
426424

427425
@inlinable
426+
@_alwaysEmitIntoClient
428427
@derivative(of: sum)
429428
func _jvpSum() -> (
430429
value: Scalar, differential: (TangentVector) -> Scalar.TangentVector
431430
) {
432431
return (sum(), { v in Scalar.TangentVector(v.sum()) })
433432
}
434433
}
435-
*/
436434

437435
extension SIMD
438436
where

test/AutoDiff/SILGen/nil_coalescing.swift

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
// RUN: %target-swift-frontend -Xllvm -sil-print-types -emit-sil -verify %s | %FileCheck %s
1+
/// Note: -primary-file prevents non_abi->shared linkage change in `removeSerializedFlagFromAllFunctions`
2+
// RUN: %target-swift-frontend -Xllvm -sil-print-types -emit-sil -verify -primary-file %s | %FileCheck %s
23

34
import _Differentiation
45

5-
// CHECK: sil @test_nil_coalescing
6+
// CHECK: sil non_abi @test_nil_coalescing
67
// CHECK: bb0(%{{.*}} : $*T, %[[ARG_OPT:.*]] : $*Optional<T>, %[[ARG_PB:.*]] :
78
// CHECK: $@noescape @callee_guaranteed @substituted <τ_0_0> () -> (@out τ_0_0, @error any Error) for <T>):
89
// CHECK: %[[ALLOC_OPT:.*]] = alloc_stack [lexical] $Optional<T>
@@ -15,7 +16,7 @@ import _Differentiation
1516
//
1617
@_silgen_name("test_nil_coalescing")
1718
@derivative(of: ??)
18-
@usableFromInline
19+
@_alwaysEmitIntoClient
1920
func nilCoalescing<T: Differentiable>(optional: T?, defaultValue: @autoclosure () throws -> T)
2021
rethrows -> (value: T, pullback: (T.TangentVector) -> Optional<T>.TangentVector)
2122
{

test/AutoDiff/Sema/derivative_attr_type_checking.swift

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,6 +1062,15 @@ func _internal_original_inlinable_derivative(_ x: Float) -> (value: Float, pullb
10621062
fatalError()
10631063
}
10641064

1065+
func internal_original_alwaysemitintoclient_derivative_error(_ x: Float) -> Float { x }
1066+
@_alwaysEmitIntoClient
1067+
@derivative(of: internal_original_alwaysemitintoclient_derivative_error)
1068+
// expected-error @+1 {{either both or none of derivative and original function must have @alwaysEmitIntoClient attribute}}
1069+
func _internal_original_alwaysemitintoclient_derivative_error(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
1070+
fatalError()
1071+
}
1072+
1073+
@_alwaysEmitIntoClient
10651074
func internal_original_alwaysemitintoclient_derivative(_ x: Float) -> Float { x }
10661075
@_alwaysEmitIntoClient
10671076
@derivative(of: internal_original_alwaysemitintoclient_derivative)
@@ -1084,6 +1093,15 @@ package func _package_original_inlinable_derivative(_ x: Float) -> (value: Float
10841093
fatalError()
10851094
}
10861095

1096+
@_alwaysEmitIntoClient
1097+
package func package_original_alwaysemitintoclient_derivative_error(_ x: Float) -> Float { x }
1098+
@derivative(of: package_original_alwaysemitintoclient_derivative_error)
1099+
// expected-error @+1 {{either both or none of derivative and original function must have @alwaysEmitIntoClient attribute}}
1100+
package func _package_original_alwaysemitintoclient_derivative_error(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
1101+
fatalError()
1102+
}
1103+
1104+
@_alwaysEmitIntoClient
10871105
package func package_original_alwaysemitintoclient_derivative(_ x: Float) -> Float { x }
10881106
@_alwaysEmitIntoClient
10891107
@derivative(of: package_original_alwaysemitintoclient_derivative)

test/AutoDiff/stdlib/simd.swift

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,6 @@ SIMDTests.test("init(repeating:)") {
1919
expectEqual(8, pb1(g))
2020
}
2121

22-
// FIXME(TF-1103): Derivative registration does not yet support
23-
// `@_alwaysEmitIntoClient` original functions.
24-
/*
2522
SIMDTests.test("Sum") {
2623
let a = SIMD4<Float>(1, 2, 3, 4)
2724

@@ -32,7 +29,6 @@ SIMDTests.test("Sum") {
3229
expectEqual(10, val1)
3330
expectEqual(SIMD4<Float>(3, 3, 3, 3), pb1(3))
3431
}
35-
*/
3632

3733
SIMDTests.test("Identity") {
3834
let a = SIMD4<Float>(1, 2, 3, 4)
@@ -289,9 +285,6 @@ SIMDTests.test("Generics") {
289285
expectEqual(SIMD3<Double>(5, 10, 15), val4)
290286
expectEqual((SIMD3<Double>(5, 5, 5), 6), pb4(g))
291287

292-
// FIXME(TF-1103): Derivative registration does not yet support
293-
// `@_alwaysEmitIntoClient` original functions like `SIMD.sum()`.
294-
/*
295288
func testSum<Scalar, SIMDType: SIMD>(x: SIMDType) -> Scalar
296289
where SIMDType.Scalar == Scalar,
297290
SIMDType : Differentiable,
@@ -304,7 +297,6 @@ SIMDTests.test("Generics") {
304297
let (val5, pb5) = valueWithPullback(at: a, of: simd3Sum)
305298
expectEqual(6, val5)
306299
expectEqual(SIMD3<Double>(7, 7, 7), pb5(7))
307-
*/
308300
}
309301

310302
runAllTests()
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
@_alwaysEmitIntoClient
2+
public func f(_ x: Float) -> Float {
3+
x
4+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import _Differentiation
2+
3+
@derivative(of: f)
4+
@_alwaysEmitIntoClient
5+
public func df(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
6+
(x, { 42 * $0 })
7+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
@_alwaysEmitIntoClient
2+
public func f(_ x: Float) -> Float {
3+
x
4+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import MultiModule1
2+
import _Differentiation
3+
4+
@derivative(of: f)
5+
@_alwaysEmitIntoClient
6+
public func df(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
7+
(x, { 42 * $0 })
8+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import _Differentiation
2+
3+
public protocol Protocol {
4+
var x : Float {get set}
5+
init()
6+
}
7+
8+
extension Protocol {
9+
public init(_ val: Float) {
10+
self.init()
11+
x = val
12+
}
13+
14+
@_alwaysEmitIntoClient
15+
public func sum() -> Float { x }
16+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import MultiModuleProtocol1
2+
import _Differentiation
3+
4+
extension Protocol where Self: Differentiable, Self.TangentVector == Self {
5+
@_alwaysEmitIntoClient
6+
@derivative(of: sum)
7+
public func _vjpSum() -> (
8+
value: Float, pullback: (Float) -> Self.TangentVector
9+
) {
10+
(value: self.x, pullback: { Self.TangentVector(42 * $0) })
11+
}
12+
13+
@_alwaysEmitIntoClient
14+
@derivative(of: sum)
15+
public func _jvpSum() -> (
16+
value: Float, differential: (Self.TangentVector) -> Float
17+
) {
18+
(value: self.x, differential: { 42 * $0.x })
19+
}
20+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import MultiModuleProtocol1
2+
import MultiModuleProtocol2
3+
import _Differentiation
4+
5+
public struct Struct : Protocol {
6+
private var _x : Float
7+
public var x : Float {
8+
get { _x }
9+
set { _x = newValue }
10+
}
11+
public init() { _x = 0 }
12+
}
13+
14+
extension Struct : AdditiveArithmetic {
15+
public static func +(lhs: Self, rhs: Self) -> Self { Self(lhs.x + rhs.x) }
16+
public static func -(lhs: Self, rhs: Self) -> Self { Self(lhs.x - rhs.x) }
17+
public static func +=(a: inout Self, b: Self) { a.x = a.x + b.x }
18+
public static func -=(a: inout Self, b: Self) { a.x = a.x - b.x }
19+
public static var zero: Self { Self(0) }
20+
}
21+
22+
extension Struct : Differentiable {
23+
public typealias TangentVector = Self
24+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
public struct Struct {
2+
public var x : Float
3+
public typealias TangentVector = Self
4+
public init() { x = 0 }
5+
}
6+
7+
extension Struct {
8+
public init(_ val: Float) {
9+
self.init()
10+
x = val
11+
}
12+
13+
@_alwaysEmitIntoClient
14+
public func sum() -> Float { x }
15+
}
16+
17+
extension Struct : AdditiveArithmetic {
18+
public static func +(lhs: Self, rhs: Self) -> Self { Self(lhs.x + rhs.x) }
19+
public static func -(lhs: Self, rhs: Self) -> Self { Self(lhs.x - rhs.x) }
20+
public static func +=(a: inout Self, b: Self) { a.x = a.x + b.x }
21+
public static func -=(a: inout Self, b: Self) { a.x = a.x - b.x }
22+
public static var zero: Self { Self(0) }
23+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import MultiModuleStruct1
2+
import _Differentiation
3+
4+
extension Struct : Differentiable {
5+
@_alwaysEmitIntoClient
6+
@derivative(of: sum)
7+
public func _vjpSum() -> (
8+
value: Float, pullback: (Float) -> Self.TangentVector
9+
) {
10+
(value: self.x, pullback: { Self.TangentVector(42 * $0) })
11+
}
12+
13+
@_alwaysEmitIntoClient
14+
@derivative(of: sum)
15+
public func _jvpSum() -> (
16+
value: Float, differential: (Self.TangentVector) -> Float
17+
) {
18+
(value: self.x, differential: { 42 * $0.x })
19+
}
20+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import MultiModuleStruct1
2+
import _Differentiation
3+
4+
extension Struct : Differentiable {
5+
@_alwaysEmitIntoClient
6+
@derivative(of: sum)
7+
public func _vjpSum() -> (
8+
value: Float, pullback: (Float) -> Self.TangentVector
9+
) {
10+
(value: self.x, pullback: { Self.TangentVector(42 * $0) })
11+
}
12+
}

0 commit comments

Comments
 (0)