@@ -26,6 +26,16 @@ mod llvm_enzyme {
2626
2727    use  crate :: errors; 
2828
29+     pub ( crate )  fn  outer_normal_attr ( 
30+         kind :  & P < rustc_ast:: NormalAttr > , 
31+         id :  rustc_ast:: AttrId , 
32+         span :  Span , 
33+     )  -> rustc_ast:: Attribute  { 
34+         let  style = rustc_ast:: AttrStyle :: Outer ; 
35+         let  kind = rustc_ast:: AttrKind :: Normal ( kind. clone ( ) ) ; 
36+         rustc_ast:: Attribute  {  kind,  id,  style,  span } 
37+     } 
38+ 
2939    // If we have a default `()` return type or explicitly `()` return type, 
3040    // then we often can skip doing some work. 
3141    fn  has_ret ( ty :  & FnRetTy )  -> bool  { 
@@ -224,20 +234,8 @@ mod llvm_enzyme {
224234            . filter ( |a| * * a == DiffActivity :: Active  || * * a == DiffActivity :: ActiveOnly ) 
225235            . count ( )  as  u32 ; 
226236        let  ( d_sig,  new_args,  idents,  errored)  = gen_enzyme_decl ( ecx,  & sig,  & x,  span) ; 
227-         let  new_decl_span = d_sig. span ; 
228237        let  d_body = gen_enzyme_body ( 
229-             ecx, 
230-             & x, 
231-             n_active, 
232-             & sig, 
233-             & d_sig, 
234-             primal, 
235-             & new_args, 
236-             span, 
237-             sig_span, 
238-             new_decl_span, 
239-             idents, 
240-             errored, 
238+             ecx,  & x,  n_active,  & sig,  & d_sig,  primal,  & new_args,  span,  sig_span,  idents,  errored, 
241239        ) ; 
242240        let  d_ident = first_ident ( & meta_item_vec[ 0 ] ) ; 
243241
@@ -270,36 +268,39 @@ mod llvm_enzyme {
270268        } ; 
271269        let  inline_never_attr = P ( ast:: NormalAttr  {  item :  inline_item,  tokens :  None  } ) ; 
272270        let  new_id = ecx. sess . psess . attr_id_generator . mk_attr_id ( ) ; 
273-         let  attr:  ast:: Attribute  = ast:: Attribute  { 
274-             kind :  ast:: AttrKind :: Normal ( rustc_ad_attr. clone ( ) ) , 
275-             id :  new_id, 
276-             style :  ast:: AttrStyle :: Outer , 
277-             span, 
278-         } ; 
271+         let  attr = outer_normal_attr ( & rustc_ad_attr,  new_id,  span) ; 
279272        let  new_id = ecx. sess . psess . attr_id_generator . mk_attr_id ( ) ; 
280-         let  inline_never:  ast:: Attribute  = ast:: Attribute  { 
281-             kind :  ast:: AttrKind :: Normal ( inline_never_attr) , 
282-             id :  new_id, 
283-             style :  ast:: AttrStyle :: Outer , 
284-             span, 
285-         } ; 
273+         let  inline_never = outer_normal_attr ( & inline_never_attr,  new_id,  span) ; 
274+ 
275+         // We're avoid duplicating the attributes `#[rustc_autodiff]` and `#[inline(never)]`. 
276+         fn  same_attribute ( attr :  & ast:: AttrKind ,  item :  & ast:: AttrKind )  -> bool  { 
277+             match  ( attr,  item)  { 
278+                 ( ast:: AttrKind :: Normal ( a) ,  ast:: AttrKind :: Normal ( b) )  => { 
279+                     let  a = & a. item . path ; 
280+                     let  b = & b. item . path ; 
281+                     a. segments . len ( )  == b. segments . len ( ) 
282+                         && a. segments . iter ( ) . zip ( b. segments . iter ( ) ) . all ( |( a,  b) | a. ident  == b. ident ) 
283+                 } 
284+                 _ => false , 
285+             } 
286+         } 
286287
287288        // Don't add it multiple times: 
288289        let  orig_annotatable:  Annotatable  = match  item { 
289290            Annotatable :: Item ( ref  mut  iitem)  => { 
290-                 if  !iitem. attrs . iter ( ) . any ( |a| a . id  ==  attr. id )  { 
291+                 if  !iitem. attrs . iter ( ) . any ( |a| same_attribute ( & a . kind ,   & attr. kind ) )  { 
291292                    iitem. attrs . push ( attr) ; 
292293                } 
293-                 if  !iitem. attrs . iter ( ) . any ( |a| a . id  ==  inline_never. id )  { 
294+                 if  !iitem. attrs . iter ( ) . any ( |a| same_attribute ( & a . kind ,   & inline_never. kind ) )  { 
294295                    iitem. attrs . push ( inline_never. clone ( ) ) ; 
295296                } 
296297                Annotatable :: Item ( iitem. clone ( ) ) 
297298            } 
298299            Annotatable :: AssocItem ( ref  mut  assoc_item,  i @ Impl )  => { 
299-                 if  !assoc_item. attrs . iter ( ) . any ( |a| a . id  ==  attr. id )  { 
300+                 if  !assoc_item. attrs . iter ( ) . any ( |a| same_attribute ( & a . kind ,   & attr. kind ) )  { 
300301                    assoc_item. attrs . push ( attr) ; 
301302                } 
302-                 if  !assoc_item. attrs . iter ( ) . any ( |a| a . id  ==  inline_never. id )  { 
303+                 if  !assoc_item. attrs . iter ( ) . any ( |a| same_attribute ( & a . kind ,   & inline_never. kind ) )  { 
303304                    assoc_item. attrs . push ( inline_never. clone ( ) ) ; 
304305                } 
305306                Annotatable :: AssocItem ( assoc_item. clone ( ) ,  i) 
@@ -314,13 +315,7 @@ mod llvm_enzyme {
314315            delim :  rustc_ast:: token:: Delimiter :: Parenthesis , 
315316            tokens :  ts, 
316317        } ) ; 
317-         let  d_attr:  ast:: Attribute  = ast:: Attribute  { 
318-             kind :  ast:: AttrKind :: Normal ( rustc_ad_attr. clone ( ) ) , 
319-             id :  new_id, 
320-             style :  ast:: AttrStyle :: Outer , 
321-             span, 
322-         } ; 
323- 
318+         let  d_attr = outer_normal_attr ( & rustc_ad_attr,  new_id,  span) ; 
324319        let  d_annotatable = if  is_impl { 
325320            let  assoc_item:  AssocItemKind  = ast:: AssocItemKind :: Fn ( asdf) ; 
326321            let  d_fn = P ( ast:: AssocItem  { 
@@ -361,30 +356,27 @@ mod llvm_enzyme {
361356        ty
362357    } 
363358
364-     /// We only want this function to type-check, since we will replace the body 
365-      /// later on llvm level. Using `loop {}` does not cover all return types anymore, 
366-      /// so instead we build something that should pass. We also add a inline_asm 
367-      /// line, as one more barrier for rustc to prevent inlining of this function. 
368-      /// FIXME(ZuseZ4): We still have cases of incorrect inlining across modules, see 
369-      /// <https://github.com/EnzymeAD/rust/issues/173>, so this isn't sufficient. 
370-      /// It also triggers an Enzyme crash if we due to a bug ever try to differentiate 
371-      /// this function (which should never happen, since it is only a placeholder). 
372-      /// Finally, we also add back_box usages of all input arguments, to prevent rustc 
373-      /// from optimizing any arguments away. 
374-      fn  gen_enzyme_body ( 
359+     // Will generate a body of the type: 
360+     // ``` 
361+     // { 
362+     //   unsafe { 
363+     //   asm!("NOP"); 
364+     //   } 
365+     //   ::core::hint::black_box(primal(args)); 
366+     //   ::core::hint::black_box((args, ret)); 
367+     //   <This part remains to be done by following function> 
368+     // } 
369+     // ``` 
370+     fn  init_body_helper ( 
375371        ecx :  & ExtCtxt < ' _ > , 
376-         x :  & AutoDiffAttrs , 
377-         n_active :  u32 , 
378-         sig :  & ast:: FnSig , 
379-         d_sig :  & ast:: FnSig , 
372+         span :  Span , 
380373        primal :  Ident , 
381374        new_names :  & [ String ] , 
382-         span :  Span , 
383375        sig_span :  Span , 
384376        new_decl_span :  Span , 
385-         idents :  Vec < Ident > , 
377+         idents :  & [ Ident ] , 
386378        errored :  bool , 
387-     )  -> P < ast:: Block >  { 
379+     )  -> ( P < ast:: Block > ,   P < ast :: Expr > ,   P < ast :: Expr > ,   P < ast :: Expr > )  { 
388380        let  blackbox_path = ecx. std_path ( & [ sym:: hint,  sym:: black_box] ) ; 
389381        let  noop = ast:: InlineAsm  { 
390382            asm_macro :  ast:: AsmMacro :: Asm , 
@@ -433,6 +425,51 @@ mod llvm_enzyme {
433425        } 
434426        body. stmts . push ( ecx. stmt_semi ( black_box_remaining_args) ) ; 
435427
428+         ( body,  primal_call,  black_box_primal_call,  blackbox_call_expr) 
429+     } 
430+ 
431+     /// We only want this function to type-check, since we will replace the body 
432+      /// later on llvm level. Using `loop {}` does not cover all return types anymore, 
433+      /// so instead we manually build something that should pass the type checker. 
434+      /// We also add an inline_asm line, as one more barrier for rustc to prevent inlining 
435+      /// or const propagation. inline_asm will also trigger an Enzyme crash if, due to another 
436+      /// bug, we ever accidentally try to differentiate this placeholder function body. 
437+      /// Finally, we also add black_box usages of all input arguments, to prevent rustc 
438+      /// from optimizing any arguments away. 
439+      fn  gen_enzyme_body ( 
440+         ecx :  & ExtCtxt < ' _ > , 
441+         x :  & AutoDiffAttrs , 
442+         n_active :  u32 , 
443+         sig :  & ast:: FnSig , 
444+         d_sig :  & ast:: FnSig , 
445+         primal :  Ident , 
446+         new_names :  & [ String ] , 
447+         span :  Span , 
448+         sig_span :  Span , 
449+         idents :  Vec < Ident > , 
450+         errored :  bool , 
451+     )  -> P < ast:: Block >  { 
452+         let  new_decl_span = d_sig. span ; 
453+ 
454+         // Just adding some default inline-asm and black_box usages to prevent early inlining 
455+         // and optimizations which alter the function signature. 
456+         // 
457+         // The bb_primal_call is the black_box call of the primal function. We keep it around, 
458+         // since it has the convenient property of returning the type of the primal function. 
459+         // Remember, we only care to match types here. 
460+         // No matter which return we pick, we always wrap it into a std::hint::black_box call, 
461+         // to prevent rustc from propagating it into the caller. 
462+         let  ( mut  body,  primal_call,  bb_primal_call,  bb_call_expr)  = init_body_helper ( 
463+             ecx, 
464+             span, 
465+             primal, 
466+             new_names, 
467+             sig_span, 
468+             new_decl_span, 
469+             & idents, 
470+             errored, 
471+         ) ; 
472+ 
436473        if  !has_ret ( & d_sig. decl . output )  { 
437474            // there is no return type that we have to match, () works fine. 
438475            return  body; 
@@ -444,7 +481,7 @@ mod llvm_enzyme {
444481
445482        if  primal_ret && n_active == 0  && x. mode . is_rev ( )  { 
446483            // We only have the primal ret. 
447-             body. stmts . push ( ecx. stmt_expr ( black_box_primal_call ) ) ; 
484+             body. stmts . push ( ecx. stmt_expr ( bb_primal_call ) ) ; 
448485            return  body; 
449486        } 
450487
@@ -536,11 +573,11 @@ mod llvm_enzyme {
536573                return  body; 
537574            } 
538575            [ arg]  => { 
539-                 ret = ecx. expr_call ( new_decl_span,  blackbox_call_expr ,  thin_vec ! [ arg. clone( ) ] ) ; 
576+                 ret = ecx. expr_call ( new_decl_span,  bb_call_expr ,  thin_vec ! [ arg. clone( ) ] ) ; 
540577            } 
541578            args => { 
542579                let  ret_tuple:  P < ast:: Expr >  = ecx. expr_tuple ( span,  args. into ( ) ) ; 
543-                 ret = ecx. expr_call ( new_decl_span,  blackbox_call_expr ,  thin_vec ! [ ret_tuple] ) ; 
580+                 ret = ecx. expr_call ( new_decl_span,  bb_call_expr ,  thin_vec ! [ ret_tuple] ) ; 
544581            } 
545582        } 
546583        assert ! ( has_ret( & d_sig. decl. output) ) ; 
@@ -553,7 +590,7 @@ mod llvm_enzyme {
553590        ecx :  & ExtCtxt < ' _ > , 
554591        span :  Span , 
555592        primal :  Ident , 
556-         idents :  Vec < Ident > , 
593+         idents :  & [ Ident ] , 
557594    )  -> P < ast:: Expr >  { 
558595        let  has_self = idents. len ( )  > 0  && idents[ 0 ] . name  == kw:: SelfLower ; 
559596        if  has_self { 
0 commit comments