Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion compiler/rustc_builtin_macros/messages.ftl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ builtin_macros_assert_requires_expression = macro requires an expression as an a

builtin_macros_autodiff = autodiff must be applied to function
builtin_macros_autodiff_missing_config = autodiff requires at least a name and mode
builtin_macros_autodiff_mode = unknown Mode: `{$mode}`. Use `Forward` or `Reverse`
builtin_macros_autodiff_mode_activity = {$act} can not be used in {$mode} Mode
builtin_macros_autodiff_not_build = this rustc version does not support autodiff
builtin_macros_autodiff_number_activities = expected {$expected} activities, but found {$found}
Expand Down
67 changes: 47 additions & 20 deletions compiler/rustc_builtin_macros/src/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,27 +86,23 @@ mod llvm_enzyme {
ecx: &mut ExtCtxt<'_>,
meta_item: &ThinVec<MetaItemInner>,
has_ret: bool,
mode: DiffMode,
) -> AutoDiffAttrs {
let dcx = ecx.sess.dcx();
let mode = name(&meta_item[1]);
let Ok(mode) = DiffMode::from_str(&mode) else {
dcx.emit_err(errors::AutoDiffInvalidMode { span: meta_item[1].span(), mode });
return AutoDiffAttrs::error();
};

// Now we check, whether the user wants autodiff in batch/vector mode, or scalar mode.
// If he doesn't specify an integer (=width), we default to scalar mode, thus width=1.
let mut first_activity = 2;
let mut first_activity = 1;

let width = if let [_, _, x, ..] = &meta_item[..]
let width = if let [_, x, ..] = &meta_item[..]
&& let Some(x) = width(x)
{
first_activity = 3;
first_activity = 2;
match x.try_into() {
Ok(x) => x,
Err(_) => {
dcx.emit_err(errors::AutoDiffInvalidWidth {
span: meta_item[2].span(),
span: meta_item[1].span(),
width: x,
});
return AutoDiffAttrs::error();
Expand Down Expand Up @@ -165,6 +161,24 @@ mod llvm_enzyme {
ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
}

pub(crate) fn expand_forward(
ecx: &mut ExtCtxt<'_>,
expand_span: Span,
meta_item: &ast::MetaItem,
item: Annotatable,
) -> Vec<Annotatable> {
expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Forward)
}

pub(crate) fn expand_reverse(
ecx: &mut ExtCtxt<'_>,
expand_span: Span,
meta_item: &ast::MetaItem,
item: Annotatable,
) -> Vec<Annotatable> {
expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Reverse)
}

/// We expand the autodiff macro to generate a new placeholder function which passes
/// type-checking and can be called by users. The function body of the placeholder function will
/// later be replaced on LLVM-IR level, so the design of the body is less important and for now
Expand Down Expand Up @@ -198,11 +212,12 @@ mod llvm_enzyme {
/// ```
/// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
/// in CI.
pub(crate) fn expand(
pub(crate) fn expand_with_mode(
ecx: &mut ExtCtxt<'_>,
expand_span: Span,
meta_item: &ast::MetaItem,
mut item: Annotatable,
mode: DiffMode,
) -> Vec<Annotatable> {
if cfg!(not(llvm_enzyme)) {
ecx.sess.dcx().emit_err(errors::AutoDiffSupportNotBuild { span: meta_item.span });
Expand Down Expand Up @@ -245,29 +260,41 @@ mod llvm_enzyme {
// create TokenStream from vec elemtents:
// meta_item doesn't have a .tokens field
let mut ts: Vec<TokenTree> = vec![];
if meta_item_vec.len() < 2 {
// At the bare minimum, we need a fnc name and a mode, even for a dummy function with no
// input and output args.
if meta_item_vec.len() < 1 {
// At the bare minimum, we need a fnc name.
dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
return vec![item];
}

meta_item_inner_to_ts(&meta_item_vec[1], &mut ts);
let mode_symbol = match mode {
DiffMode::Forward => sym::Forward,
DiffMode::Reverse => sym::Reverse,
_ => unreachable!("Unsupported mode: {:?}", mode),
};

// Insert mode token
let mode_token = Token::new(TokenKind::Ident(mode_symbol, false.into()), Span::default());
ts.insert(0, TokenTree::Token(mode_token, Spacing::Joint));
ts.insert(
1,
TokenTree::Token(Token::new(TokenKind::Comma, Span::default()), Spacing::Alone),
);

// Now, if the user gave a width (vector aka batch-mode ad), then we copy it.
// If it is not given, we default to 1 (scalar mode).
let start_position;
let kind: LitKind = LitKind::Integer;
let symbol;
if meta_item_vec.len() >= 3
&& let Some(width) = width(&meta_item_vec[2])
if meta_item_vec.len() >= 2
&& let Some(width) = width(&meta_item_vec[1])
{
start_position = 3;
start_position = 2;
symbol = Symbol::intern(&width.to_string());
} else {
start_position = 2;
start_position = 1;
symbol = sym::integer(1);
}

let l: Lit = Lit { kind, symbol, suffix: None };
let t = Token::new(TokenKind::Literal(l), Span::default());
let comma = Token::new(TokenKind::Comma, Span::default());
Expand All @@ -289,7 +316,7 @@ mod llvm_enzyme {
ts.pop();
let ts: TokenStream = TokenStream::from_iter(ts);

let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret);
let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret, mode);
if !x.is_active() {
// We encountered an error, so we return the original item.
// This allows us to potentially parse other attributes.
Expand Down Expand Up @@ -1017,4 +1044,4 @@ mod llvm_enzyme {
}
}

pub(crate) use llvm_enzyme::expand;
pub(crate) use llvm_enzyme::{expand_forward, expand_reverse};
8 changes: 0 additions & 8 deletions compiler/rustc_builtin_macros/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,14 +187,6 @@ mod autodiff {
pub(crate) act: String,
}

#[derive(Diagnostic)]
#[diag(builtin_macros_autodiff_mode)]
pub(crate) struct AutoDiffInvalidMode {
#[primary_span]
pub(crate) span: Span,
pub(crate) mode: String,
}

#[derive(Diagnostic)]
#[diag(builtin_macros_autodiff_width)]
pub(crate) struct AutoDiffInvalidWidth {
Expand Down
5 changes: 3 additions & 2 deletions compiler/rustc_builtin_macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
#![allow(internal_features)]
#![allow(rustc::diagnostic_outside_of_impl)]
#![allow(rustc::untranslatable_diagnostic)]
#![cfg_attr(not(bootstrap), feature(autodiff))]
#![doc(html_root_url = "https://doc.rust-lang.org/nightly/nightly-rustc/")]
#![doc(rust_logo)]
#![feature(assert_matches)]
#![feature(autodiff)]
#![feature(box_patterns)]
#![feature(decl_macro)]
#![feature(if_let_guard)]
Expand Down Expand Up @@ -112,7 +112,8 @@ pub fn register_builtin_macros(resolver: &mut dyn ResolverExpand) {

register_attr! {
alloc_error_handler: alloc_error_handler::expand,
autodiff: autodiff::expand,
autodiff_forward: autodiff::expand_forward,
autodiff_reverse: autodiff::expand_reverse,
bench: test::expand_bench,
cfg_accessible: cfg_accessible::Expander,
cfg_eval: cfg_eval::expand,
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_passes/src/check_attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ impl<'tcx> CheckAttrVisitor<'tcx> {
self.check_generic_attr(hir_id, attr, target, Target::Fn);
self.check_proc_macro(hir_id, target, ProcMacroKind::Derive)
}
[sym::autodiff, ..] => {
[sym::autodiff_forward, ..] | [sym::autodiff_reverse, ..] => {
self.check_autodiff(hir_id, attr, span, target)
}
[sym::coroutine, ..] => {
Expand Down
5 changes: 4 additions & 1 deletion compiler/rustc_span/src/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ symbols! {
FnMut,
FnOnce,
Formatter,
Forward,
From,
FromIterator,
FromResidual,
Expand Down Expand Up @@ -348,6 +349,7 @@ symbols! {
Result,
ResumeTy,
Return,
Reverse,
Right,
Rust,
RustaceansAreAwesome,
Expand Down Expand Up @@ -531,7 +533,8 @@ symbols! {
audit_that,
augmented_assignments,
auto_traits,
autodiff,
autodiff_forward,
autodiff_reverse,
automatically_derived,
avx,
avx10_target_feature,
Expand Down
3 changes: 2 additions & 1 deletion library/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,11 @@ pub mod assert_matches {

// We don't export this through #[macro_export] for now, to avoid breakage.
#[unstable(feature = "autodiff", issue = "124509")]
#[cfg(not(bootstrap))]
/// Unstable module containing the unstable `autodiff` macro.
pub mod autodiff {
#[unstable(feature = "autodiff", issue = "124509")]
pub use crate::macros::builtin::autodiff;
pub use crate::macros::builtin::{autodiff_forward, autodiff_reverse};
}

#[unstable(feature = "contracts", issue = "128044")]
Expand Down
40 changes: 30 additions & 10 deletions library/core/src/macros/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1519,20 +1519,40 @@ pub(crate) mod builtin {
($file:expr $(,)?) => {{ /* compiler built-in */ }};
}

/// the derivative of a given function in the forward mode of differentiation.
/// It may only be applied to a function.
///
/// The expected usage syntax is:
/// `#[autodiff_forward(NAME, INPUT_ACTIVITIES, OUTPUT_ACTIVITY)]`
///
/// - `NAME`: A string that represents a valid function name.
/// - `INPUT_ACTIVITIES`: Specifies one valid activity for each input parameter.
/// - `OUTPUT_ACTIVITY`: Must not be set if the function implicitly returns nothing
/// (or explicitly returns `-> ()`). Otherwise, it must be set to one of the allowed activities.
#[unstable(feature = "autodiff", issue = "124509")]
#[allow_internal_unstable(rustc_attrs)]
#[rustc_builtin_macro]
#[cfg(not(bootstrap))]
pub macro autodiff_forward($item:item) {
/* compiler built-in */
}

/// Automatic Differentiation macro which allows generating a new function to compute
/// the derivative of a given function. It may only be applied to a function.
/// The expected usage syntax is
/// `#[autodiff(NAME, MODE, INPUT_ACTIVITIES, OUTPUT_ACTIVITY)]`
/// where:
/// NAME is a string that represents a valid function name.
/// MODE is any of Forward, Reverse, ForwardFirst, ReverseFirst.
/// INPUT_ACTIVITIES consists of one valid activity for each input parameter.
/// OUTPUT_ACTIVITY must not be set if we implicitly return nothing (or explicitly return
/// `-> ()`). Otherwise it must be set to one of the allowed activities.
/// the derivative of a given function in the reverse mode of differentiation.
/// It may only be applied to a function.
///
/// The expected usage syntax is:
/// `#[autodiff_reverse(NAME, INPUT_ACTIVITIES, OUTPUT_ACTIVITY)]`
///
/// - `NAME`: A string that represents a valid function name.
/// - `INPUT_ACTIVITIES`: Specifies one valid activity for each input parameter.
/// - `OUTPUT_ACTIVITY`: Must not be set if the function implicitly returns nothing
/// (or explicitly returns `-> ()`). Otherwise, it must be set to one of the allowed activities.
#[unstable(feature = "autodiff", issue = "124509")]
#[allow_internal_unstable(rustc_attrs)]
#[rustc_builtin_macro]
pub macro autodiff($item:item) {
#[cfg(not(bootstrap))]
pub macro autodiff_reverse($item:item) {
/* compiler built-in */
}

Expand Down
7 changes: 5 additions & 2 deletions library/std/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,12 +276,12 @@
// tidy-alphabetical-start

// stabilization was reverted after it hit beta
#![cfg_attr(not(bootstrap), feature(autodiff))]
#![feature(alloc_error_handler)]
#![feature(allocator_internals)]
#![feature(allow_internal_unsafe)]
#![feature(allow_internal_unstable)]
#![feature(asm_experimental_arch)]
#![feature(autodiff)]
#![feature(cfg_sanitizer_cfi)]
#![feature(cfg_target_thread_local)]
#![feature(cfi_encoding)]
Expand Down Expand Up @@ -636,12 +636,15 @@ pub mod simd {
#[doc(inline)]
pub use crate::std_float::StdFloat;
}

#[unstable(feature = "autodiff", issue = "124509")]
#[cfg(not(bootstrap))]
/// This module provides support for automatic differentiation.
pub mod autodiff {
/// This macro handles automatic differentiation.
pub use core::autodiff::autodiff;
pub use core::autodiff::{autodiff_forward, autodiff_reverse};
}

#[stable(feature = "futures_api", since = "1.36.0")]
pub mod task {
//! Types and Traits for working with asynchronous tasks.
Expand Down
8 changes: 4 additions & 4 deletions tests/codegen/autodiff/batched.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@

#![feature(autodiff)]

use std::autodiff::autodiff;
use std::autodiff::autodiff_forward;

#[autodiff(d_square3, Forward, Dual, DualOnly)]
#[autodiff(d_square2, Forward, 4, Dual, DualOnly)]
#[autodiff(d_square1, Forward, 4, Dual, Dual)]
#[autodiff_forward(d_square3, Dual, DualOnly)]
#[autodiff_forward(d_square2, 4, Dual, DualOnly)]
#[autodiff_forward(d_square1, 4, Dual, Dual)]
#[no_mangle]
fn square(x: &f32) -> f32 {
x * x
Expand Down
4 changes: 2 additions & 2 deletions tests/codegen/autodiff/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
//@ needs-enzyme
#![feature(autodiff)]

use std::autodiff::autodiff;
use std::autodiff::autodiff_reverse;

#[autodiff(d_square, Reverse, Duplicated, Active)]
#[autodiff_reverse(d_square, Duplicated, Active)]
fn square<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T {
*x * *x
}
Expand Down
6 changes: 3 additions & 3 deletions tests/codegen/autodiff/identical_fnc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@
// identical function calls in the LLVM-IR, while having two different calls in the Rust code.
#![feature(autodiff)]

use std::autodiff::autodiff;
use std::autodiff::autodiff_reverse;

#[autodiff(d_square, Reverse, Duplicated, Active)]
#[autodiff_reverse(d_square, Duplicated, Active)]
fn square(x: &f64) -> f64 {
x * x
}

#[autodiff(d_square2, Reverse, Duplicated, Active)]
#[autodiff_reverse(d_square2, Duplicated, Active)]
fn square2(x: &f64) -> f64 {
x * x
}
Expand Down
4 changes: 2 additions & 2 deletions tests/codegen/autodiff/inline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

#![feature(autodiff)]

use std::autodiff::autodiff;
use std::autodiff::autodiff_reverse;

#[autodiff(d_square, Reverse, Duplicated, Active)]
#[autodiff_reverse(d_square, Duplicated, Active)]
fn square(x: &f64) -> f64 {
x * x
}
Expand Down
4 changes: 2 additions & 2 deletions tests/codegen/autodiff/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
//@ needs-enzyme
#![feature(autodiff)]

use std::autodiff::autodiff;
use std::autodiff::autodiff_reverse;

#[autodiff(d_square, Reverse, Duplicated, Active)]
#[autodiff_reverse(d_square, Duplicated, Active)]
#[no_mangle]
fn square(x: &f64) -> f64 {
x * x
Expand Down
4 changes: 2 additions & 2 deletions tests/codegen/autodiff/sret.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@

#![feature(autodiff)]

use std::autodiff::autodiff;
use std::autodiff::autodiff_reverse;

#[no_mangle]
#[autodiff(df, Reverse, Active, Active, Active)]
#[autodiff_reverse(df, Active, Active, Active)]
fn primal(x: f32, y: f32) -> f64 {
(x * x * y) as f64
}
Expand Down
Loading
Loading