Skip to content

Commit 8c38065

Browse files
committed
[WIP] TF-1021: verify that retrodiff works.
Added a simple four file test: - A: defines original function. - B: registers derivative for original function. - C: also registers derivative for original function. - D: imports A, B, and C, and differentiates original function.
1 parent 979914e commit 8c38065

File tree

4 files changed

+27
-0
lines changed

4 files changed

+27
-0
lines changed

lib/Sema/TypeCheckAttr.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3906,6 +3906,7 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
39063906
return;
39073907
}
39083908

3909+
#if 0
39093910
// Reject different-file retroactive derivatives.
39103911
// TODO(TF-136): Lift this restriction now that SIL differentiability witness
39113912
// infrastructure is ready.
@@ -3914,6 +3915,7 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
39143915
diag::derivative_attr_not_in_same_file_as_original);
39153916
return;
39163917
}
3918+
#endif
39173919

39183920
// Valid `@derivative` attributes are uniqued by original function and
39193921
// parameter indices. Reject duplicate attributes.
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
import derivative_registration_original_module
2+
3+
@derivative(of: id)
4+
func vjpId(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
5+
return (id(x), { $0 })
6+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
public func id(_ x: Float) -> Float { x }
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// Verify that cross-file derivative registration works.
2+
3+
// RUN: %empty-directory(%t)
4+
// RUN: %target-swift-frontend -emit-module -primary-file %S/../Inputs/derivative_registration_original_module.swift -emit-module-path %t/derivative_registration_original_module.swiftmodule
5+
// RUN: %target-swift-frontend -I %t -emit-module -primary-file %S/../Inputs/derivative_registration_derivative_module.swift -emit-module-path %t/derivative_registration_derivative_module.swiftmodule
6+
// RUN: %target-swift-frontend -I %t -emit-module -primary-file %S/../Inputs/derivative_registration_derivative_module2.swift -emit-module-path %t/derivative_registration_derivative_module2.swiftmodule
7+
// RUN: %target-swift-emit-sil -I %t -emit-module %s
8+
9+
import derivative_registration_original_module
10+
import derivative_registration_derivative_module
11+
import derivative_registration_derivative_module2
12+
13+
func x(_ x: Float) -> Float {
14+
// TODO: Check which derivative for `id` is used:
15+
// - The one from `derivative_registration_derivative_module`, or
16+
// - The one from `derivative_registration_derivative_module2`.
17+
return gradient(at: x, in: id)
18+
}

0 commit comments

Comments
 (0)