11//@ compile-flags: -Zautodiff=Enable -C opt-level=3  -Clto=fat 
22//@ no-prefer-dynamic 
33//@ needs-enzyme 
4+ 
45#![ feature( autodiff) ]  
56
67use  std:: autodiff:: autodiff; 
78
8- #[ autodiff( d_square,  Reverse ,  4 ,  Duplicated ,  Active ) ]  
9+ #[ autodiff( d_square3,  Forward ,  Dual ,  DualOnly ) ]  
10+ #[ no_mangle]  
11+ fn  squaref ( x :  & f32 )  -> f32  { 
12+     2.0  *  x *  x
13+ } 
14+ 
15+ 
16+ #[ autodiff( d_square2,  Forward ,  4 ,  Dual ,  DualOnly ) ]  
17+ #[ autodiff( d_square,  Forward ,  4 ,  Dual ,  Dual ) ]  
918#[ no_mangle]  
10- fn  square ( x :  & f64 )  -> f64  { 
19+ fn  square ( x :  & f32 )  -> f32  { 
1120    x *  x
1221} 
1322
@@ -33,21 +42,31 @@ fn square(x: &f64) -> f64 {
3342// CHECK-NEXT:} 
3443
3544fn  main ( )  { 
36-     let  x = 3.0 ; 
45+     let  x = std :: hint :: black_box ( 3.0 ) ; 
3746    let  output = square ( & x) ; 
47+     dbg ! ( & output) ; 
3848    assert_eq ! ( 9.0 ,  output) ; 
49+     dbg ! ( squaref( & x) ) ; 
3950
40-     let  mut  df_dx1 = 0 .0; 
41-     let  mut  df_dx2 = 0 .0; 
42-     let  mut  df_dx3 = 0 .0; 
51+     let  mut  df_dx1 = 1 .0; 
52+     let  mut  df_dx2 = 2 .0; 
53+     let  mut  df_dx3 = 3 .0; 
4354    let  mut  df_dx4 = 0.0 ; 
44-     let  [ o1,  o2,  o3,  o4]  = d_square ( & x,  & mut  df_dx1,  & mut  df_dx2,  & mut  df_dx3,  & mut  df_dx4,  1.0 ) ; 
45-     assert_eq ! ( output,  o1) ; 
46-     assert_eq ! ( output,  o2) ; 
47-     assert_eq ! ( output,  o3) ; 
48-     assert_eq ! ( output,  o4) ; 
49-     assert_eq ! ( 6.0 ,  df_dx1) ; 
50-     assert_eq ! ( 6.0 ,  df_dx2) ; 
51-     assert_eq ! ( 6.0 ,  df_dx3) ; 
52-     assert_eq ! ( 6.0 ,  df_dx4) ; 
55+     let  [ o1, o2, o3, o4]  = d_square2 ( & x,  & mut  df_dx1,  & mut  df_dx2,  & mut  df_dx3,   & mut  df_dx4) ; 
56+     dbg ! ( o1,  o2,  o3,  o4) ; 
57+     let  [ output2,  o1, o2, o3, o4]  = d_square ( & x,  & mut  df_dx1,  & mut  df_dx2,  & mut  df_dx3,   & mut  df_dx4) ; 
58+     dbg ! ( o1,  o2,  o3,  o4) ; 
59+     assert_eq ! ( output,  output2) ; 
60+     assert ! ( ( 6.0  - o1) . abs( )  < 1e-10 ) ; 
61+     assert ! ( ( 12.0  - o2) . abs( )  < 1e-10 ) ; 
62+     assert ! ( ( 18.0  - o3) . abs( )  < 1e-10 ) ; 
63+     assert ! ( ( 0.0  - o4) . abs( )  < 1e-10 ) ; 
64+     assert_eq ! ( 1.0 ,  df_dx1) ; 
65+     assert_eq ! ( 2.0 ,  df_dx2) ; 
66+     assert_eq ! ( 3.0 ,  df_dx3) ; 
67+     assert_eq ! ( 0.0 ,  df_dx4) ; 
68+     assert_eq ! ( d_square3( & x,  & mut  df_dx1) ,  2.0  *  o1) ; 
69+     assert_eq ! ( d_square3( & x,  & mut  df_dx2) ,  2.0  *  o2) ; 
70+     assert_eq ! ( d_square3( & x,  & mut  df_dx3) ,  2.0  *  o3) ; 
71+     assert_eq ! ( d_square3( & x,  & mut  df_dx4) ,  2.0  *  o4) ; 
5372} 
0 commit comments