Skip to content

add sret handling for scalar autodiff #139465

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions compiler/rustc_ast/src/expand/autodiff_attrs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ pub struct AutoDiffAttrs {
pub input_activity: Vec<DiffActivity>,
}

impl AutoDiffAttrs {
    /// Returns `true` when `ret_activity` is `DiffActivity::Active` or
    /// `DiffActivity::Dual`, and `false` for any other activity.
    ///
    /// NOTE(review): judging by the name, these are presumably the return
    /// activities for which the primal (original) return value is handed back
    /// alongside the derivative information; confirm against the callers.
    pub fn has_primal_ret(&self) -> bool {
        // Borrow the field so the check never moves out of `self`.
        let activity = &self.ret_activity;
        matches!(activity, DiffActivity::Active | DiffActivity::Dual)
    }
}

impl DiffMode {
pub fn is_rev(&self) -> bool {
matches!(self, DiffMode::Reverse)
Expand Down
24 changes: 22 additions & 2 deletions compiler/rustc_codegen_llvm/src/builder/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,23 @@ fn compute_enzyme_fn_ty<'ll>(
}

if attrs.width == 1 {
todo!("Handle sret for scalar ad");
// Enzyme returns a struct of style:
// `{ original_ret(if requested), float, float, ... }`
let mut struct_elements = vec![];
if attrs.has_primal_ret() {
struct_elements.push(inner_ret_ty);
}
// Next, we push the list of active floats, since they will be lowered to `enzyme_out`,
// and therefore part of the return struct.
let param_tys = cx.func_params_types(fn_ty);
for (act, param_ty) in attrs.input_activity.iter().zip(param_tys) {
if matches!(act, DiffActivity::Active) {
// The zipped `param_ty` is the type of this active parameter in `fn_ty`,
// which tells us what (f16/f32/f64/...) to add to the struct.
struct_elements.push(param_ty);
}
}
ret_ty = cx.type_struct(&struct_elements, false);
} else {
// First we check if we also have to deal with the primal return.
match attrs.mode {
Expand Down Expand Up @@ -388,7 +404,11 @@ fn generate_enzyme_call<'ll>(
// now store the result of the enzyme call into the sret pointer.
let sret_ptr = outer_args[0];
let call_ty = cx.val_ty(call);
assert_eq!(cx.type_kind(call_ty), TypeKind::Array);
if attrs.width == 1 {
assert_eq!(cx.type_kind(call_ty), TypeKind::Struct);
} else {
assert_eq!(cx.type_kind(call_ty), TypeKind::Array);
}
llvm::LLVMBuildStore(&builder.llbuilder, call, sret_ptr);
}
builder.ret_void();
Expand Down
File renamed without changes.
File renamed without changes.
45 changes: 45 additions & 0 deletions tests/codegen/autodiff/sret.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
//@ no-prefer-dynamic
//@ needs-enzyme

// This test is almost identical to the scalar.rs one,
// but we intentionally add a few more floats.
// `df` would return `{ f64, f32, f32 }`, but the return is lowered to an sret pointer.
// We therefore use this test to verify some of our sret handling.

#![feature(autodiff)]

use std::autodiff::autodiff;

#[no_mangle]
#[autodiff(df, Reverse, Active, Active, Active)]
// Computes `x * x * y` in `f32`, then widens the product to `f64`.
fn primal(x: f32, y: f32) -> f64 {
    // Keep the multiplication in `f32`; only the final result is converted.
    let product: f32 = x * x * y;
    product as f64
}

// CHECK:define internal fastcc void @_ZN4sret2df17h93be4316dd8ea006E(ptr dead_on_unwind noalias nocapture noundef nonnull writable writeonly align 8 dereferenceable(16) initializes((0, 16)) %_0, float noundef %x, float noundef %y)
// CHECK-NEXT:start:
// CHECK-NEXT: %0 = tail call fastcc { double, float, float } @diffeprimal(float %x, float %y)
// CHECK-NEXT: %.elt = extractvalue { double, float, float } %0, 0
// CHECK-NEXT: store double %.elt, ptr %_0, align 8
// CHECK-NEXT: %_0.repack1 = getelementptr inbounds nuw i8, ptr %_0, i64 8
// CHECK-NEXT: %.elt2 = extractvalue { double, float, float } %0, 1
// CHECK-NEXT: store float %.elt2, ptr %_0.repack1, align 8
// CHECK-NEXT: %_0.repack3 = getelementptr inbounds nuw i8, ptr %_0, i64 12
// CHECK-NEXT: %.elt4 = extractvalue { double, float, float } %0, 2
// CHECK-NEXT: store float %.elt4, ptr %_0.repack3, align 4
// CHECK-NEXT: ret void
// CHECK-NEXT:}

fn main() {
    use std::hint::black_box;

    // Route every input through `black_box` so the optimizer cannot
    // constant-fold the call to `df` away.
    let x = black_box(3.0);
    let y = black_box(2.5);
    // Third argument of `df`; presumably the seed for the active return
    // value (TODO confirm against the autodiff macro docs).
    let seed = black_box(1.0);

    let (primal_val, grad_x, grad_y) = df(x, y, seed);

    // Primal result: x*x*y = 3*3*2.5 = 22.5
    assert_eq!(primal_val, 22.5);
    // Derivative w.r.t. x: 2*x*y = 2*3*2.5 = 15.0
    assert_eq!(grad_x, 15.0);
    // Derivative w.r.t. y: x*x*1 = 3*3 = 9
    assert_eq!(grad_y, 9.0);
}
Loading