Skip to content

Commit 4f16d77

Browse files
committed
[WIP] TF-1021: verify that retrodiff works.
Added a simple three file test: - A: defines original function. - B: registers derivative. - C: imports A and B, differentiates original function.
1 parent 979914e commit 4f16d77

File tree

4 files changed

+22
-0
lines changed

4 files changed

+22
-0
lines changed

lib/Sema/TypeCheckAttr.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3906,6 +3906,7 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
39063906
return;
39073907
}
39083908

3909+
#if 0
39093910
// Reject different-file retroactive derivatives.
39103911
// TODO(TF-136): Lift this restriction now that SIL differentiability witness
39113912
// infrastructure is ready.
@@ -3914,6 +3915,7 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
39143915
diag::derivative_attr_not_in_same_file_as_original);
39153916
return;
39163917
}
3918+
#endif
39173919

39183920
// Valid `@derivative` attributes are uniqued by original function and
39193921
// parameter indices. Reject duplicate attributes.
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
import derivative_registration_original_module
2+
3+
@derivative(of: id)
4+
func vjpId(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
5+
return (id(x), { $0 })
6+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
public func id(_ x: Float) -> Float { x }
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// Verify that cross-file derivative registration works.
2+
3+
// RUN: %empty-directory(%t)
4+
// RUN: %target-swift-frontend -emit-module -primary-file %S/../Inputs/derivative_registration_original_module.swift -emit-module-path %t/derivative_registration_original_module.swiftmodule
5+
// RUN: %target-swift-frontend -I %t -emit-module -primary-file %S/../Inputs/derivative_registration_derivative_module.swift -emit-module-path %t/derivative_registration_derivative_module.swiftmodule
6+
// RUN: %target-swift-emit-sil -I %t -emit-module %s
7+
8+
import derivative_registration_original_module
9+
import derivative_registration_derivative_module
10+
11+
func x(_ x: Float) -> Float {
12+
return gradient(at: x, in: id)
13+
}

0 commit comments

Comments
 (0)