Skip to content

Commit 1035486

Browse files
committed
Add codegen tests
1 parent fce7d46 commit 1035486

File tree

1 file changed

+116
-0
lines changed

1 file changed

+116
-0
lines changed
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
2+
//@ no-prefer-dynamic
3+
//@ needs-enzyme
4+
5+
// This does only test the funtion attribute handling for autodiff.
6+
// Function argument changes are troublesome for Enzyme, so we have to
7+
// ensure that arguments remain the same, or if we change them, be aware
8+
// of the changes to handle it correctly.
9+
10+
#![feature(autodiff)]
11+
12+
use std::autodiff::{autodiff_forward, autodiff_reverse};
13+
14+
#[derive(Copy, Clone)]
15+
struct Input {
16+
x: f32,
17+
y: f32,
18+
}
19+
20+
#[derive(Copy, Clone)]
21+
struct Wrapper {
22+
z: f32,
23+
}
24+
25+
#[derive(Copy, Clone)]
26+
struct NestedInput {
27+
x: f32,
28+
y: Wrapper,
29+
}
30+
31+
fn square(x: f32) -> f32 {
32+
x * x
33+
}
34+
35+
// CHECK: ; abi_handling::f1
36+
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}}
37+
// CHECK-NEXT: define dso_local noundef float @_ZN12abi_handling2f1{{.*}}(ptr noalias nocapture noundef readonly align 4 dereferenceable(8) %x)
38+
#[autodiff_forward(df1, Dual, Dual)]
39+
fn f1(x: &[f32; 2]) -> f32 {
40+
x[0] + x[1]
41+
}
42+
43+
// CHECK: ; abi_handling::f2
44+
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}}
45+
// CHECK-NEXT: define dso_local noundef float @_ZN12abi_handling2f217h33732e9f83c91bc9E(ptr nocapture noundef nonnull readonly %f, float noundef %x)
46+
#[autodiff_reverse(df2, Const, Active, Active)]
47+
fn f2(f: fn(f32) -> f32, x: f32) -> f32 {
48+
f(x)
49+
}
50+
51+
// CHECK: ; abi_handling::f3
52+
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}}
53+
// CHECK-NEXT: define dso_local noundef float @_ZN12abi_handling2f317h9cd1fc602b0815a4E(ptr noalias nocapture noundef readonly align 4 dereferenceable(4) %x, ptr noalias nocapture noundef readonly align 4 dereferenceable(4) %y)
54+
#[autodiff_forward(df3, Dual, Dual, Dual)]
55+
fn f3<'a>(x: &'a f32, y: &'a f32) -> f32 {
56+
*x * *y
57+
}
58+
59+
// CHECK: ; abi_handling::f4
60+
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}}
61+
// CHECK-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f417h2f4a9a7492d91e9fE(float noundef %x.0, float noundef %x.1)
62+
#[autodiff_forward(df4, Dual, Dual)]
63+
fn f4(x: (f32, f32)) -> f32 {
64+
x.0 * x.1
65+
}
66+
67+
// CHECK: ; abi_handling::f5
68+
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}}
69+
// CHECK-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f517hf8d4ac4d2c2a3976E(float noundef %i.0, float noundef %i.1)
70+
#[autodiff_forward(df5, Dual, Dual)]
71+
fn f5(i: Input) -> f32 {
72+
i.x + i.y
73+
}
74+
75+
// CHECK: ; abi_handling::f6
76+
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}}
77+
// CHECK-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f617h5784b207bbb2483eE(float noundef %i.0, float noundef %i.1)
78+
#[autodiff_forward(df6, Dual, Dual)]
79+
fn f6(i: NestedInput) -> f32 {
80+
i.x + i.y.z * i.y.z
81+
}
82+
83+
fn main() {
84+
let x = std::hint::black_box(2.0);
85+
let y = std::hint::black_box(3.0);
86+
87+
let in_f1 = [x, y];
88+
dbg!(f1(&in_f1));
89+
let dx1 = std::hint::black_box(&[1.0, 0.0]);
90+
let res_f1 = df1(&in_f1, dx1);
91+
dbg!(res_f1);
92+
93+
dbg!(f2(square, x));
94+
let res_f2 = df2(square, x, 1.0);
95+
dbg!(res_f2);
96+
97+
static Y: f32 = std::hint::black_box(3.2);
98+
dbg!(f3(&x, &Y));
99+
let res_f3 = df3(&x, &Y, &1.0, &0.0);
100+
dbg!(res_f3);
101+
102+
let in_f4 = (x, y);
103+
dbg!(f4(in_f4));
104+
let res_f4 = df4(in_f4, (1.0, 0.0));
105+
dbg!(res_f4);
106+
107+
let in_f5 = Input { x, y };
108+
dbg!(f5(in_f5));
109+
let res_f5 = df5(in_f5, Input { x: 1.0, y: 0.0 });
110+
dbg!(res_f5);
111+
112+
let in_f6 = NestedInput { x, y: Wrapper { z: y } };
113+
dbg!(f6(in_f6));
114+
let res_f6 = df6(in_f6, NestedInput { x: 1.0, y: Wrapper { z: 0.0 } });
115+
dbg!(res_f6);
116+
}

0 commit comments

Comments
 (0)