Skip to content

Commit 8dc69e7

Browse files
committed
Add tests
1 parent f3d7f30 commit 8dc69e7

File tree

2 files changed

+71
-0
lines changed

2 files changed

+71
-0
lines changed
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// RUN: %target-swift-emit-sil -Xllvm -debug-only=differentiation -o /dev/null 2>&1 %s | %FileCheck %s
2+
3+
// The differentiability witness for y in s(h:) will be generated by silgen. However, later the capture
4+
// promotion pass would specialize it since it only captures an integer and therefore does not need to
5+
// box the capture. Ensure we create differentiability witness for specialized function. In addition to
6+
// this, since the original function is not used anymore, the body of it is removed (with only unreachable
7+
// terminator inside). Remove original differentiability witness as it would lead to non-differentiable
8+
// diagnostics further on.
9+
10+
// CHECK-LABEL: differentiability witness for specialized y #1 (_:) in s(h:)
11+
// CHECK: sil_differentiability_witness private [reverse] [parameters 0] [results 0] @$s4null1s1hAA1BVAE_tF1yL_yAA1WVAHFTf2ni_n : $@convention(thin) (@guaranteed W, Int) -> @owned W {
12+
// CHECK-NOT: sil_differentiability_witness private [reverse] [parameters 0] [results 0] @$s4null1s1hAA1BVAE_tF1yL_yAA1WVAHF : $@convention(thin) (@guaranteed W, @guaranteed { var Int }) -> @owned W {
13+
14+
import _Differentiation
15+
struct B: Differentiable{}
16+
struct X { var j = [Float]()}
17+
struct W: Differentiable {
18+
@noDerivative var z: X
19+
var h: B
20+
}
21+
func o<T, R>(_ x: T, _ f: @differentiable(reverse) (T) -> R) -> R {f(x)}
22+
func m<T, R>(_ f: @escaping @differentiable(reverse) (T) -> R) -> @differentiable(reverse) (T) -> R {{ x in o(x, f) }}
23+
@differentiable(reverse)
24+
func s(h: B) -> B {
25+
var (_, e) = (0,0)
26+
@differentiable(reverse)
27+
func y(_ i: W) -> W {
28+
let _ = e;
29+
return i
30+
}
31+
let w = m(y)
32+
return B()
33+
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// RUN: %empty-directory(%t)
2+
// RUN: %target-swift-emit-sil -Xllvm -debug-only=differentiation -emit-module -module-name M -emit-module-path %t/M.swiftmodule 2>&1 %s | %FileCheck %s
3+
4+
// The original function Tensor.subscriptIndexPath() is not marked as @differentiable. As a result, no explicit differentiable witness is generated for it.
5+
// However, the witness is generated as a side effect of providing a derivative via @derivative(of: subscriptIndexPath) on _vjpSubscriptIndexPath.
6+
// Since _vjpSubscriptIndexPath is not emitted when -emit-module is used, we need to ensure we still generate a wittness.
7+
8+
import _Differentiation
9+
10+
// CHECK-LABEL: differentiability witness for Tensor.subscriptIndexPath()
11+
// CHECK: sil_differentiability_witness [serialized] [reverse] [parameters 0] [results 0] @$s1M6TensorV18subscriptIndexPathACyF : $@convention(method) (Tensor) -> Tensor {
12+
// CHECK: vjp: @$s1M6TensorV18subscriptIndexPathACyFTJrSpSr : $@convention(method) (Tensor) -> (Tensor, @owned @callee_guaranteed (Tensor) -> Tensor)
13+
14+
// CHECK-LABEL: reverse-mode derivative of Tensor.subscriptIndexPath()
15+
// CHECK: sil [thunk] [always_inline] [ossa] @$s1M6TensorV18subscriptIndexPathACyFTJrSpSr : $@convention(method) (Tensor) -> (Tensor, @owned @callee_guaranteed (Tensor) -> Tensor) {
16+
// CHECK: function_ref Tensor._vjpSubscriptIndexPath()
17+
// CHECK: function_ref @$s1M6TensorV22_vjpSubscriptIndexPathAC5value_A2Cc8pullbacktyF : $@convention(method) (Tensor) -> (Tensor, @owned @callee_guaranteed (Tensor) -> Tensor)
18+
19+
public struct Tensor: Differentiable & AdditiveArithmetic {
20+
@inlinable
21+
func subscriptIndexPath() -> Tensor {
22+
fatalError()
23+
}
24+
25+
@inlinable
26+
@differentiable(reverse, wrt: self)
27+
func subscriptRanges() -> Tensor {
28+
subscriptIndexPath()
29+
}
30+
31+
@usableFromInline
32+
@derivative(of: subscriptIndexPath)
33+
func _vjpSubscriptIndexPath() -> (
34+
value: Tensor, pullback: (Tensor) -> Tensor
35+
) {
36+
fatalError()
37+
}
38+
}

0 commit comments

Comments
 (0)