Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 22 additions & 94 deletions compiler/rustc_ast/src/expand/autodiff_attrs.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
//! This crate handles the user facing autodiff macro. For each `#[autodiff(...)]` attribute,
//! we create an [`AutoDiffItem`] which contains the source and target function names. The source
//! we create an `RustcAutodiff` which contains the source and target function names. The source
//! is the function to which the autodiff attribute is applied, and the target is the function
//! getting generated by us (with a name given by the user as the first autodiff arg).

use std::fmt::{self, Display, Formatter};
use std::str::FromStr;

use crate::expand::typetree::TypeTree;
use rustc_span::{Symbol, sym};

use crate::expand::{Decodable, Encodable, HashStable_Generic};
use crate::{Ty, TyKind};

Expand All @@ -31,6 +32,12 @@ pub enum DiffMode {
Reverse,
}

impl DiffMode {
pub fn all_modes() -> &'static [Symbol] {
&[sym::Source, sym::Forward, sym::Reverse]
}
}

/// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity.
/// However, under forward mode we overwrite the previous shadow value, while for reverse mode
/// we add to the previous shadow value. To not surprise users, we picked different names.
Expand Down Expand Up @@ -76,43 +83,20 @@ impl DiffActivity {
use DiffActivity::*;
matches!(self, |Dual| DualOnly | Dualv | DualvOnly | Const)
}
}
/// We generate one of these structs for each `#[autodiff(...)]` attribute.
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct AutoDiffItem {
/// The name of the function getting differentiated
pub source: String,
/// The name of the function being generated
pub target: String,
pub attrs: AutoDiffAttrs,
pub inputs: Vec<TypeTree>,
pub output: TypeTree,
}

#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct AutoDiffAttrs {
/// Conceptually either forward or reverse mode AD, as described in various autodiff papers and
/// e.g. in the [JAX
/// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
pub mode: DiffMode,
/// A user-provided, batching width. If not given, we will default to 1 (no batching).
/// Calling a differentiated, non-batched function through a loop 100 times is equivalent to:
/// - Calling the function 50 times with a batch size of 2
/// - Calling the function 25 times with a batch size of 4,
/// etc. A batched function takes more (or longer) arguments, and might be able to benefit from
/// cache locality, better re-usal of primal values, and other optimizations.
/// We will (before LLVM's vectorizer runs) just generate most LLVM-IR instructions `width`
/// times, so this massively increases code size. As such, values like 1024 are unlikely to
/// work. We should consider limiting this to u8 or u16, but will leave it at u32 for
/// experiments for now and focus on documenting the implications of a large width.
pub width: u32,
pub ret_activity: DiffActivity,
pub input_activity: Vec<DiffActivity>,
}

impl AutoDiffAttrs {
pub fn has_primal_ret(&self) -> bool {
matches!(self.ret_activity, DiffActivity::Active | DiffActivity::Dual)
pub fn all_activities() -> &'static [Symbol] {
&[
sym::None,
sym::Active,
sym::ActiveOnly,
sym::Const,
sym::Dual,
sym::Dualv,
sym::DualOnly,
sym::DualvOnly,
sym::Duplicated,
sym::DuplicatedOnly,
]
}
}

Expand Down Expand Up @@ -241,59 +225,3 @@ impl FromStr for DiffActivity {
}
}
}

impl AutoDiffAttrs {
pub fn has_ret_activity(&self) -> bool {
self.ret_activity != DiffActivity::None
}
pub fn has_active_only_ret(&self) -> bool {
self.ret_activity == DiffActivity::ActiveOnly
}

pub const fn error() -> Self {
AutoDiffAttrs {
mode: DiffMode::Error,
width: 0,
ret_activity: DiffActivity::None,
input_activity: Vec::new(),
}
}
pub fn source() -> Self {
AutoDiffAttrs {
mode: DiffMode::Source,
width: 0,
ret_activity: DiffActivity::None,
input_activity: Vec::new(),
}
}

pub fn is_active(&self) -> bool {
self.mode != DiffMode::Error
}

pub fn is_source(&self) -> bool {
self.mode == DiffMode::Source
}
pub fn apply_autodiff(&self) -> bool {
!matches!(self.mode, DiffMode::Error | DiffMode::Source)
}

pub fn into_item(
self,
source: String,
target: String,
inputs: Vec<TypeTree>,
output: TypeTree,
) -> AutoDiffItem {
AutoDiffItem { source, target, inputs, output, attrs: self }
}
}

impl fmt::Display for AutoDiffItem {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Differentiating {} -> {}", self.source, self.target)?;
write!(f, " with attributes: {:?}", self.attrs)?;
write!(f, " with inputs: {:?}", self.inputs)?;
write!(f, " with output: {:?}", self.output)
}
}
117 changes: 117 additions & 0 deletions compiler/rustc_attr_parsing/src/attributes/autodiff.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
use std::str::FromStr;

use rustc_ast::LitKind;
use rustc_ast::expand::autodiff_attrs::{DiffActivity, DiffMode};
use rustc_feature::{AttributeTemplate, template};
use rustc_hir::attrs::{AttributeKind, RustcAutodiff};
use rustc_hir::{MethodKind, Target};
use rustc_span::{Symbol, sym};
use thin_vec::ThinVec;

use crate::attributes::prelude::Allow;
use crate::attributes::{AttributeOrder, OnDuplicate, SingleAttributeParser};
use crate::context::{AcceptContext, Stage};
use crate::parser::{ArgParser, MetaItemOrLitParser};
use crate::target_checking::AllowedTargets;

pub(crate) struct RustcAutodiffParser;

impl<S: Stage> SingleAttributeParser<S> for RustcAutodiffParser {
const PATH: &[Symbol] = &[sym::rustc_autodiff];
const ATTRIBUTE_ORDER: AttributeOrder = AttributeOrder::KeepInnermost;
const ON_DUPLICATE: OnDuplicate<S> = OnDuplicate::Error;
const ALLOWED_TARGETS: AllowedTargets = AllowedTargets::AllowList(&[
Allow(Target::Fn),
Allow(Target::Method(MethodKind::Inherent)),
Allow(Target::Method(MethodKind::Trait { body: true })),
Allow(Target::Method(MethodKind::TraitImpl)),
]);
const TEMPLATE: AttributeTemplate = template!(
List: &["MODE", "WIDTH", "INPUT_ACTIVITIES", "OUTPUT_ACTIVITY"],
"https://doc.rust-lang.org/std/autodiff/index.html"
);

fn convert(cx: &mut AcceptContext<'_, '_, S>, args: &ArgParser) -> Option<AttributeKind> {
let list = match args {
ArgParser::NoArgs => return Some(AttributeKind::RustcAutodiff(None)),
ArgParser::List(list) => list,
ArgParser::NameValue(_) => {
cx.expected_list_or_no_args(cx.attr_span);
return None;
}
};

let mut items = list.mixed().peekable();

// Parse name
let Some(mode) = items.next() else {
cx.expected_at_least_one_argument(list.span);
return None;
};
let Some(mode) = mode.meta_item() else {
cx.expected_identifier(mode.span());
return None;
};
let Ok(()) = mode.args().no_args() else {
cx.expected_identifier(mode.span());
return None;
};
let Some(mode) = mode.path().word() else {
cx.expected_identifier(mode.span());
return None;
};
let Ok(mode) = DiffMode::from_str(mode.as_str()) else {
cx.expected_specific_argument(mode.span, DiffMode::all_modes());
return None;
};

// Parse width
let width = if let Some(width) = items.peek()
&& let MetaItemOrLitParser::Lit(width) = width
&& let LitKind::Int(width, _) = width.kind
&& let Ok(width) = width.0.try_into()
{
_ = items.next();
width
} else {
1
};

// Parse activities
let mut activities = ThinVec::new();
for activity in items {
let MetaItemOrLitParser::MetaItemParser(activity) = activity else {
cx.expected_specific_argument(activity.span(), DiffActivity::all_activities());
return None;
};
let Ok(()) = activity.args().no_args() else {
cx.expected_specific_argument(activity.span(), DiffActivity::all_activities());
return None;
};
let Some(activity) = activity.path().word() else {
cx.expected_specific_argument(activity.span(), DiffActivity::all_activities());
return None;
};
let Ok(activity) = DiffActivity::from_str(activity.as_str()) else {
cx.expected_specific_argument(activity.span, DiffActivity::all_activities());
return None;
};

activities.push(activity);
}
let Some(ret_activity) = activities.pop() else {
cx.expected_specific_argument(
list.span.with_lo(list.span.hi()),
DiffActivity::all_activities(),
);
return None;
};

Some(AttributeKind::RustcAutodiff(Some(Box::new(RustcAutodiff {
mode,
width,
input_activity: activities,
ret_activity,
}))))
}
}
1 change: 1 addition & 0 deletions compiler/rustc_attr_parsing/src/attributes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use crate::target_checking::AllowedTargets;
mod prelude;

pub(crate) mod allow_unstable;
pub(crate) mod autodiff;
pub(crate) mod body;
pub(crate) mod cfg;
pub(crate) mod cfg_select;
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_attr_parsing/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use rustc_span::{ErrorGuaranteed, Span, Symbol};
use crate::AttributeParser;
// Glob imports to avoid big, bitrotty import lists
use crate::attributes::allow_unstable::*;
use crate::attributes::autodiff::*;
use crate::attributes::body::*;
use crate::attributes::cfi_encoding::*;
use crate::attributes::codegen_attrs::*;
Expand Down Expand Up @@ -202,6 +203,7 @@ attribute_parsers!(
Single<ReexportTestHarnessMainParser>,
Single<RustcAbiParser>,
Single<RustcAllocatorZeroedVariantParser>,
Single<RustcAutodiffParser>,
Single<RustcBuiltinMacroParser>,
Single<RustcDefPath>,
Single<RustcDeprecatedSafe2024Parser>,
Expand Down
18 changes: 9 additions & 9 deletions compiler/rustc_builtin_macros/src/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ mod llvm_enzyme {
use std::string::String;

use rustc_ast::expand::autodiff_attrs::{
AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ret_activity,
valid_ty_for_activity,
DiffActivity, DiffMode, valid_input_activity, valid_ret_activity, valid_ty_for_activity,
};
use rustc_ast::token::{Lit, LitKind, Token, TokenKind};
use rustc_ast::tokenstream::*;
Expand All @@ -20,6 +19,7 @@ mod llvm_enzyme {
MetaItemInner, MgcaDisambiguation, PatKind, Path, PathSegment, TyKind, Visibility,
};
use rustc_expand::base::{Annotatable, ExtCtxt};
use rustc_hir::attrs::RustcAutodiff;
use rustc_span::{Ident, Span, Symbol, sym};
use thin_vec::{ThinVec, thin_vec};
use tracing::{debug, trace};
Expand Down Expand Up @@ -87,7 +87,7 @@ mod llvm_enzyme {
meta_item: &ThinVec<MetaItemInner>,
has_ret: bool,
mode: DiffMode,
) -> AutoDiffAttrs {
) -> RustcAutodiff {
let dcx = ecx.sess.dcx();

// Now we check, whether the user wants autodiff in batch/vector mode, or scalar mode.
Expand All @@ -105,7 +105,7 @@ mod llvm_enzyme {
span: meta_item[1].span(),
width: x,
});
return AutoDiffAttrs::error();
return RustcAutodiff::error();
}
}
} else {
Expand All @@ -129,7 +129,7 @@ mod llvm_enzyme {
};
}
if errors {
return AutoDiffAttrs::error();
return RustcAutodiff::error();
}

// If a return type exist, we need to split the last activity,
Expand All @@ -145,11 +145,11 @@ mod llvm_enzyme {
(&DiffActivity::None, activities.as_slice())
};

AutoDiffAttrs {
RustcAutodiff {
mode,
width,
ret_activity: *ret_activity,
input_activity: input_activity.to_vec(),
input_activity: input_activity.iter().cloned().collect(),
}
}

Expand Down Expand Up @@ -309,7 +309,7 @@ mod llvm_enzyme {
ts.pop();
let ts: TokenStream = TokenStream::from_iter(ts);

let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret, mode);
let x: RustcAutodiff = from_ast(ecx, &meta_item_vec, has_ret, mode);
if !x.is_active() {
// We encountered an error, so we return the original item.
// This allows us to potentially parse other attributes.
Expand Down Expand Up @@ -603,7 +603,7 @@ mod llvm_enzyme {
fn gen_enzyme_decl(
ecx: &ExtCtxt<'_>,
sig: &ast::FnSig,
x: &AutoDiffAttrs,
x: &RustcAutodiff,
span: Span,
) -> ast::FnSig {
let dcx = ecx.sess.dcx();
Expand Down
8 changes: 5 additions & 3 deletions compiler/rustc_codegen_llvm/src/builder/autodiff.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use std::ptr;

use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
use rustc_ast::expand::autodiff_attrs::{DiffActivity, DiffMode};
use rustc_ast::expand::typetree::FncTree;
use rustc_codegen_ssa::common::TypeKind;
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
use rustc_data_structures::thin_vec::ThinVec;
use rustc_hir::attrs::RustcAutodiff;
use rustc_middle::ty::{Instance, PseudoCanonicalInput, TyCtxt, TypingEnv};
use rustc_middle::{bug, ty};
use rustc_target::callconv::PassMode;
Expand All @@ -18,7 +20,7 @@ pub(crate) fn adjust_activity_to_abi<'tcx>(
tcx: TyCtxt<'tcx>,
instance: Instance<'tcx>,
typing_env: TypingEnv<'tcx>,
da: &mut Vec<DiffActivity>,
da: &mut ThinVec<DiffActivity>,
) {
let fn_ty = instance.ty(tcx, typing_env);

Expand Down Expand Up @@ -295,7 +297,7 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
outer_name: &str,
ret_ty: &'ll Type,
fn_args: &[&'ll Value],
attrs: AutoDiffAttrs,
attrs: &RustcAutodiff,
dest: PlaceRef<'tcx, &'ll Value>,
fnc_tree: FncTree,
) {
Expand Down
Loading
Loading