Skip to content

Commit f3c06f2

Browse files
committed
move old tests, add sret test
1 parent 5486059 commit f3c06f2

File tree

3 files changed

+47
-0
lines changed

3 files changed

+47
-0
lines changed
File renamed without changes.
File renamed without changes.

tests/codegen/autodiff/sret.rs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
2+
//@ no-prefer-dynamic
3+
//@ needs-enzyme
4+
5+
// This test is almost identical to the scalar.rs one,
6+
// but we intentionally add a few more floats.
7+
// `df` would ret `{ f64, f32, f32 }`, but is lowered as an sret.
8+
// We therefore use this test to verify some of our sret handling.
9+
10+
11+
#![feature(autodiff)]
12+
13+
use std::autodiff::autodiff;
14+
15+
#[no_mangle]
16+
#[autodiff(df, Reverse, Active, Active, Active)]
17+
fn primal(x: f32, y: f32) -> f64 {
18+
(x * x * y) as f64
19+
}
20+
21+
// CHECK:define internal fastcc void @_ZN4sret2df17h93be4316dd8ea006E(ptr dead_on_unwind noalias nocapture noundef nonnull writable writeonly align 8 dereferenceable(16) initializes((0, 16)) %_0, float noundef %x, float noundef %y)
22+
// CHECK-NEXT:start:
23+
// CHECK-NEXT: %0 = tail call fastcc { double, float, float } @diffeprimal(float %x, float %y)
24+
// CHECK-NEXT: %.elt = extractvalue { double, float, float } %0, 0
25+
// CHECK-NEXT: store double %.elt, ptr %_0, align 8
26+
// CHECK-NEXT: %_0.repack1 = getelementptr inbounds nuw i8, ptr %_0, i64 8
27+
// CHECK-NEXT: %.elt2 = extractvalue { double, float, float } %0, 1
28+
// CHECK-NEXT: store float %.elt2, ptr %_0.repack1, align 8
29+
// CHECK-NEXT: %_0.repack3 = getelementptr inbounds nuw i8, ptr %_0, i64 12
30+
// CHECK-NEXT: %.elt4 = extractvalue { double, float, float } %0, 2
31+
// CHECK-NEXT: store float %.elt4, ptr %_0.repack3, align 4
32+
// CHECK-NEXT: ret void
33+
// CHECK-NEXT:}
34+
35+
36+
fn main() {
37+
let x = std::hint::black_box(3.0);
38+
let y = std::hint::black_box(2.5);
39+
let scalar = std::hint::black_box(1.0);
40+
let (r1, r2, r3) = df(x, y, scalar);
41+
// 3*3*1.5 = 22.5
42+
assert_eq!(r1, 22.5);
43+
// 2*x*y = 2*3*2.5 = 15.0
44+
assert_eq!(r2, 15.0);
45+
// x*x*1 = 3*3 = 9
46+
assert_eq!(r3, 9.0);
47+
}

0 commit comments

Comments
 (0)