Skip to content

Commit a165957

Browse files
committed
WIP typetree impl
1 parent 1b61d43 commit a165957

File tree

8 files changed

+451
-9
lines changed

8 files changed

+451
-9
lines changed

compiler/rustc_ast/src/expand/autodiff_attrs.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use std::str::FromStr;
99
use crate::expand::{Decodable, Encodable, HashStable_Generic};
1010
use crate::ptr::P;
1111
use crate::{Ty, TyKind};
12+
use crate::expand::typetree::TypeTree;
1213

1314
/// Forward and Reverse Mode are well known names for automatic differentiation implementations.
1415
/// Enzyme does support both, but with different semantics, see DiffActivity. The First variants
@@ -85,6 +86,9 @@ pub struct AutoDiffItem {
8586
/// The name of the function being generated
8687
pub target: String,
8788
pub attrs: AutoDiffAttrs,
89+
// --- TypeTree support ---
90+
pub inputs: Vec<TypeTree>,
91+
pub output: TypeTree,
8892
}
8993

9094
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
@@ -112,6 +116,10 @@ impl AutoDiffAttrs {
112116
pub fn has_primal_ret(&self) -> bool {
113117
matches!(self.ret_activity, DiffActivity::Active | DiffActivity::Dual)
114118
}
119+
/// New constructor for type tree support
120+
pub fn into_item(self, source: String, target: String, inputs: Vec<TypeTree>, output: TypeTree) -> AutoDiffItem {
121+
AutoDiffItem { source, target, attrs: self, inputs, output }
122+
}
115123
}
116124

117125
impl DiffMode {
@@ -284,6 +292,8 @@ impl AutoDiffAttrs {
284292
impl fmt::Display for AutoDiffItem {
285293
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
286294
write!(f, "Differentiating {} -> {}", self.source, self.target)?;
287-
write!(f, " with attributes: {:?}", self.attrs)
295+
write!(f, " with attributes: {:?}", self.attrs)?;
296+
write!(f, " with inputs: {:?}", self.inputs)?;
297+
write!(f, " with output: {:?}", self.output)
288298
}
289299
}

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ mod llvm_enzyme {
1111
AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ret_activity,
1212
valid_ty_for_activity,
1313
};
14+
use rustc_ast::expand::typetree::{TypeTree, Type, Kind};
1415
use rustc_ast::ptr::P;
16+
use crate::typetree::construct_typetree_from_fnsig;
1517
use rustc_ast::token::{Lit, LitKind, Token, TokenKind};
1618
use rustc_ast::tokenstream::*;
1719
use rustc_ast::visit::AssocCtxt::*;
@@ -324,6 +326,17 @@ mod llvm_enzyme {
324326
}
325327
let span = ecx.with_def_site_ctxt(expand_span);
326328

329+
// Construct real type trees from function signature
330+
let (inputs, output) = construct_typetree_from_fnsig(&sig);
331+
332+
// Use the new into_item method to construct the AutoDiffItem
333+
let autodiff_item = x.clone().into_item(
334+
primal.to_string(),
335+
first_ident(&meta_item_vec[0]).to_string(),
336+
inputs,
337+
output,
338+
);
339+
327340
let n_active: u32 = x
328341
.input_activity
329342
.iter()
@@ -1045,5 +1058,3 @@ mod llvm_enzyme {
10451058
(d_sig, new_inputs, idents, false)
10461059
}
10471060
}
1048-
1049-
pub(crate) use llvm_enzyme::{expand_forward, expand_reverse};

compiler/rustc_builtin_macros/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ mod pattern_type;
5151
mod source_util;
5252
mod test;
5353
mod trace_macros;
54+
mod typetree;
5455

5556
pub mod asm;
5657
pub mod cmdline_attrs;

0 commit comments

Comments
 (0)