Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/rustc_ast_lowering/src/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ impl<'a, 'hir> Visitor<'hir> for NodeCollector<'a, 'hir> {

fn visit_const_arg(&mut self, const_arg: &'hir ConstArg<'hir, AmbigArg>) {
self.insert(
const_arg.as_unambig_ct().span(),
const_arg.as_unambig_ct().span,
const_arg.hir_id,
Node::ConstArg(const_arg.as_unambig_ct()),
);
Expand Down
63 changes: 53 additions & 10 deletions compiler/rustc_ast_lowering/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2285,8 +2285,12 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
// `ExprKind::Paren(ExprKind::Underscore)` and should also be lowered to `GenericArg::Infer`
match c.value.peel_parens().kind {
ExprKind::Underscore => {
let ct_kind = hir::ConstArgKind::Infer(self.lower_span(c.value.span), ());
self.arena.alloc(hir::ConstArg { hir_id: self.lower_node_id(c.id), kind: ct_kind })
let ct_kind = hir::ConstArgKind::Infer(());
self.arena.alloc(hir::ConstArg {
hir_id: self.lower_node_id(c.id),
kind: ct_kind,
span: self.lower_span(c.value.span),
})
}
_ => self.lower_anon_const_to_const_arg_and_alloc(c),
}
Expand Down Expand Up @@ -2356,7 +2360,11 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
hir::ConstArgKind::Anon(ct)
};

self.arena.alloc(hir::ConstArg { hir_id: self.next_id(), kind: ct_kind })
self.arena.alloc(hir::ConstArg {
hir_id: self.next_id(),
kind: ct_kind,
span: self.lower_span(span),
})
}

fn lower_const_item_rhs(
Expand All @@ -2373,9 +2381,9 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
let const_arg = ConstArg {
hir_id: self.next_id(),
kind: hir::ConstArgKind::Error(
DUMMY_SP,
self.dcx().span_delayed_bug(DUMMY_SP, "no block"),
),
span: DUMMY_SP,
};
hir::ConstItemRhs::TypeConst(self.arena.alloc(const_arg))
}
Expand All @@ -2388,13 +2396,15 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {

#[instrument(level = "debug", skip(self), ret)]
fn lower_expr_to_const_arg_direct(&mut self, expr: &Expr) -> hir::ConstArg<'hir> {
let span = self.lower_span(expr.span);

let overly_complex_const = |this: &mut Self| {
let e = this.dcx().struct_span_err(
expr.span,
"complex const arguments must be placed inside of a `const` block",
);

ConstArg { hir_id: this.next_id(), kind: hir::ConstArgKind::Error(expr.span, e.emit()) }
ConstArg { hir_id: this.next_id(), kind: hir::ConstArgKind::Error(e.emit()), span }
};

match &expr.kind {
Expand Down Expand Up @@ -2425,8 +2435,26 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
ConstArg {
hir_id: self.next_id(),
kind: hir::ConstArgKind::TupleCall(qpath, lowered_args),
span,
}
}
ExprKind::Tup(exprs) => {
let exprs = self.arena.alloc_from_iter(exprs.iter().map(|expr| {
let expr = if let ExprKind::ConstBlock(anon_const) = &expr.kind {
let def_id = self.local_def_id(anon_const.id);
let def_kind = self.tcx.def_kind(def_id);
assert_eq!(DefKind::AnonConst, def_kind);

self.lower_anon_const_to_const_arg(anon_const)
} else {
self.lower_expr_to_const_arg_direct(&expr)
};

&*self.arena.alloc(expr)
}));

ConstArg { hir_id: self.next_id(), kind: hir::ConstArgKind::Tup(exprs), span }
}
ExprKind::Path(qself, path) => {
let qpath = self.lower_qpath(
expr.id,
Expand All @@ -2439,7 +2467,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
None,
);

ConstArg { hir_id: self.next_id(), kind: hir::ConstArgKind::Path(qpath) }
ConstArg { hir_id: self.next_id(), kind: hir::ConstArgKind::Path(qpath), span }
}
ExprKind::Struct(se) => {
let path = self.lower_qpath(
Expand Down Expand Up @@ -2480,11 +2508,16 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
})
}));

ConstArg { hir_id: self.next_id(), kind: hir::ConstArgKind::Struct(path, fields) }
ConstArg {
hir_id: self.next_id(),
kind: hir::ConstArgKind::Struct(path, fields),
span,
}
}
ExprKind::Underscore => ConstArg {
hir_id: self.lower_node_id(expr.id),
kind: hir::ConstArgKind::Infer(expr.span, ()),
kind: hir::ConstArgKind::Infer(()),
span,
},
ExprKind::Block(block, _) => {
if let [stmt] = block.stmts.as_slice()
Expand All @@ -2495,6 +2528,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
| ExprKind::Path(..)
| ExprKind::Struct(..)
| ExprKind::Call(..)
| ExprKind::Tup(..)
)
{
return self.lower_expr_to_const_arg_direct(expr);
Expand Down Expand Up @@ -2528,7 +2562,11 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
return match anon.mgca_disambiguation {
MgcaDisambiguation::AnonConst => {
let lowered_anon = self.lower_anon_const_to_anon_const(anon);
ConstArg { hir_id: self.next_id(), kind: hir::ConstArgKind::Anon(lowered_anon) }
ConstArg {
hir_id: self.next_id(),
kind: hir::ConstArgKind::Anon(lowered_anon),
span: lowered_anon.span,
}
}
MgcaDisambiguation::Direct => self.lower_expr_to_const_arg_direct(&anon.value),
};
Expand Down Expand Up @@ -2565,11 +2603,16 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
return ConstArg {
hir_id: self.lower_node_id(anon.id),
kind: hir::ConstArgKind::Path(qpath),
span: self.lower_span(expr.span),
};
}

let lowered_anon = self.lower_anon_const_to_anon_const(anon);
ConstArg { hir_id: self.next_id(), kind: hir::ConstArgKind::Anon(lowered_anon) }
ConstArg {
hir_id: self.next_id(),
kind: hir::ConstArgKind::Anon(lowered_anon),
span: self.lower_span(expr.span),
}
}

/// See [`hir::ConstArg`] for when to use this function vs
Expand Down
3 changes: 2 additions & 1 deletion compiler/rustc_ast_lowering/src/pat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
self.arena.alloc(hir::ConstArg {
hir_id: self.next_id(),
kind: hir::ConstArgKind::Anon(self.arena.alloc(anon_const)),
span,
})
}

Expand Down Expand Up @@ -557,6 +558,6 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
})
});
let hir_id = self.next_id();
self.arena.alloc(hir::ConstArg { kind: hir::ConstArgKind::Anon(ct), hir_id })
self.arena.alloc(hir::ConstArg { kind: hir::ConstArgKind::Anon(ct), hir_id, span })
}
}
5 changes: 4 additions & 1 deletion compiler/rustc_codegen_llvm/messages.ftl
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
codegen_llvm_autodiff_component_unavailable = failed to load our autodiff backend. Did you install it via rustup?
codegen_llvm_autodiff_component_missing = autodiff backend not found in the sysroot: {$err}
.note = it will be distributed via rustup in the future

codegen_llvm_autodiff_component_unavailable = failed to load our autodiff backend: {$err}

codegen_llvm_autodiff_without_enable = using the autodiff feature requires -Z autodiff=Enable
codegen_llvm_autodiff_without_lto = using the autodiff feature requires setting `lto="fat"` in your Cargo.toml
Expand Down
11 changes: 10 additions & 1 deletion compiler/rustc_codegen_llvm/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,16 @@ impl<G: EmissionGuarantee> Diagnostic<'_, G> for ParseTargetMachineConfig<'_> {

#[derive(Diagnostic)]
#[diag(codegen_llvm_autodiff_component_unavailable)]
pub(crate) struct AutoDiffComponentUnavailable;
pub(crate) struct AutoDiffComponentUnavailable {
pub err: String,
}

#[derive(Diagnostic)]
#[diag(codegen_llvm_autodiff_component_missing)]
#[note]
pub(crate) struct AutoDiffComponentMissing {
pub err: String,
}

#[derive(Diagnostic)]
#[diag(codegen_llvm_autodiff_without_lto)]
Expand Down
10 changes: 8 additions & 2 deletions compiler/rustc_codegen_llvm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,14 @@ impl CodegenBackend for LlvmCodegenBackend {

use crate::back::lto::enable_autodiff_settings;
if sess.opts.unstable_opts.autodiff.contains(&AutoDiff::Enable) {
if let Err(_) = llvm::EnzymeWrapper::get_or_init(&sess.opts.sysroot) {
sess.dcx().emit_fatal(crate::errors::AutoDiffComponentUnavailable);
match llvm::EnzymeWrapper::get_or_init(&sess.opts.sysroot) {
Ok(_) => {}
Err(llvm::EnzymeLibraryError::NotFound { err }) => {
sess.dcx().emit_fatal(crate::errors::AutoDiffComponentMissing { err });
}
Err(llvm::EnzymeLibraryError::LoadFailed { err }) => {
sess.dcx().emit_fatal(crate::errors::AutoDiffComponentUnavailable { err });
}
}
enable_autodiff_settings(&sess.opts.unstable_opts.autodiff);
}
Expand Down
34 changes: 25 additions & 9 deletions compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ pub(crate) mod Enzyme_AD {
fn load_ptr_by_symbol_mut_void(
lib: &libloading::Library,
bytes: &[u8],
) -> Result<*mut c_void, Box<dyn std::error::Error>> {
) -> Result<*mut c_void, libloading::Error> {
unsafe {
let s: libloading::Symbol<'_, *mut c_void> = lib.get(bytes)?;
// libloading = 0.9.0: try_as_raw_ptr always succeeds and returns Some
Expand Down Expand Up @@ -192,15 +192,27 @@ pub(crate) mod Enzyme_AD {

static ENZYME_INSTANCE: OnceLock<Mutex<EnzymeWrapper>> = OnceLock::new();

#[derive(Debug)]
pub(crate) enum EnzymeLibraryError {
NotFound { err: String },
LoadFailed { err: String },
}

impl From<libloading::Error> for EnzymeLibraryError {
fn from(err: libloading::Error) -> Self {
Self::LoadFailed { err: format!("{err:?}") }
}
}

impl EnzymeWrapper {
/// Initialize EnzymeWrapper with the given sysroot if not already initialized.
/// Safe to call multiple times - subsequent calls are no-ops due to OnceLock.
pub(crate) fn get_or_init(
sysroot: &rustc_session::config::Sysroot,
) -> Result<MutexGuard<'static, Self>, Box<dyn std::error::Error>> {
) -> Result<MutexGuard<'static, Self>, EnzymeLibraryError> {
let mtx: &'static Mutex<EnzymeWrapper> = ENZYME_INSTANCE.get_or_try_init(|| {
let w = Self::call_dynamic(sysroot)?;
Ok::<_, Box<dyn std::error::Error>>(Mutex::new(w))
Ok::<_, EnzymeLibraryError>(Mutex::new(w))
})?;

Ok(mtx.lock().unwrap())
Expand Down Expand Up @@ -351,7 +363,7 @@ pub(crate) mod Enzyme_AD {
#[allow(non_snake_case)]
fn call_dynamic(
sysroot: &rustc_session::config::Sysroot,
) -> Result<Self, Box<dyn std::error::Error>> {
) -> Result<Self, EnzymeLibraryError> {
let enzyme_path = Self::get_enzyme_path(sysroot)?;
let lib = unsafe { libloading::Library::new(enzyme_path)? };

Expand Down Expand Up @@ -416,7 +428,7 @@ pub(crate) mod Enzyme_AD {
})
}

fn get_enzyme_path(sysroot: &Sysroot) -> Result<String, String> {
fn get_enzyme_path(sysroot: &Sysroot) -> Result<String, EnzymeLibraryError> {
let llvm_version_major = unsafe { LLVMRustVersionMajor() };

let path_buf = sysroot
Expand All @@ -434,15 +446,19 @@ pub(crate) mod Enzyme_AD {
.map(|p| p.join("lib").display().to_string())
.collect::<Vec<String>>()
.join("\n* ");
format!(
"failed to find a `libEnzyme-{llvm_version_major}` folder \
EnzymeLibraryError::NotFound {
err: format!(
"failed to find a `libEnzyme-{llvm_version_major}` folder \
in the sysroot candidates:\n* {candidates}"
)
),
}
})?;

Ok(path_buf
.to_str()
.ok_or_else(|| format!("invalid UTF-8 in path: {}", path_buf.display()))?
.ok_or_else(|| EnzymeLibraryError::LoadFailed {
err: format!("invalid UTF-8 in path: {}", path_buf.display()),
})?
.to_string())
}
}
Expand Down
5 changes: 3 additions & 2 deletions compiler/rustc_codegen_ssa/src/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ use std::time::{Duration, Instant};
use itertools::Itertools;
use rustc_abi::FIRST_VARIANT;
use rustc_ast::expand::allocator::{
ALLOC_ERROR_HANDLER, ALLOCATOR_METHODS, AllocatorKind, AllocatorMethod, AllocatorTy,
ALLOC_ERROR_HANDLER, ALLOCATOR_METHODS, AllocatorKind, AllocatorMethod, AllocatorMethodInput,
AllocatorTy,
};
use rustc_data_structures::fx::{FxHashMap, FxIndexSet};
use rustc_data_structures::profiling::{get_resident_set_size, print_time_passes_entry};
Expand Down Expand Up @@ -671,7 +672,7 @@ pub fn allocator_shim_contents(tcx: TyCtxt<'_>, kind: AllocatorKind) -> Vec<Allo
methods.push(AllocatorMethod {
name: ALLOC_ERROR_HANDLER,
special: None,
inputs: &[],
inputs: &[AllocatorMethodInput { name: "layout", ty: AllocatorTy::Layout }],
output: AllocatorTy::Never,
});
}
Expand Down
23 changes: 7 additions & 16 deletions compiler/rustc_hir/src/hir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ impl<'hir> ConstItemRhs<'hir> {
pub fn span<'tcx>(&self, tcx: impl crate::intravisit::HirTyCtxt<'tcx>) -> Span {
match self {
ConstItemRhs::Body(body_id) => tcx.hir_body(*body_id).value.span,
ConstItemRhs::TypeConst(ct_arg) => ct_arg.span(),
ConstItemRhs::TypeConst(ct_arg) => ct_arg.span,
}
}
}
Expand All @@ -447,6 +447,7 @@ pub struct ConstArg<'hir, Unambig = ()> {
#[stable_hasher(ignore)]
pub hir_id: HirId,
pub kind: ConstArgKind<'hir, Unambig>,
pub span: Span,
}

impl<'hir> ConstArg<'hir, AmbigArg> {
Expand Down Expand Up @@ -475,7 +476,7 @@ impl<'hir> ConstArg<'hir> {
/// Functions accepting ambiguous consts will not handle the [`ConstArgKind::Infer`] variant, if
/// infer consts are relevant to you then care should be taken to handle them separately.
pub fn try_as_ambig_ct(&self) -> Option<&ConstArg<'hir, AmbigArg>> {
if let ConstArgKind::Infer(_, ()) = self.kind {
if let ConstArgKind::Infer(()) = self.kind {
return None;
}

Expand All @@ -494,23 +495,13 @@ impl<'hir, Unambig> ConstArg<'hir, Unambig> {
_ => None,
}
}

pub fn span(&self) -> Span {
match self.kind {
ConstArgKind::Struct(path, _) => path.span(),
ConstArgKind::Path(path) => path.span(),
ConstArgKind::TupleCall(path, _) => path.span(),
ConstArgKind::Anon(anon) => anon.span,
ConstArgKind::Error(span, _) => span,
ConstArgKind::Infer(span, _) => span,
}
}
}

/// See [`ConstArg`].
#[derive(Clone, Copy, Debug, HashStable_Generic)]
#[repr(u8, C)]
pub enum ConstArgKind<'hir, Unambig = ()> {
Tup(&'hir [&'hir ConstArg<'hir, Unambig>]),
/// **Note:** Currently this is only used for bare const params
/// (`N` where `fn foo<const N: usize>(...)`),
/// not paths to any const (`N` where `const N: usize = ...`).
Expand All @@ -523,10 +514,10 @@ pub enum ConstArgKind<'hir, Unambig = ()> {
/// Tuple constructor variant
TupleCall(QPath<'hir>, &'hir [&'hir ConstArg<'hir>]),
/// Error const
Error(Span, ErrorGuaranteed),
Error(ErrorGuaranteed),
/// This variant is not always used to represent inference consts, sometimes
/// [`GenericArg::Infer`] is used instead.
Infer(Span, Unambig),
Infer(Unambig),
}

#[derive(Clone, Copy, Debug, HashStable_Generic)]
Expand Down Expand Up @@ -572,7 +563,7 @@ impl GenericArg<'_> {
match self {
GenericArg::Lifetime(l) => l.ident.span,
GenericArg::Type(t) => t.span,
GenericArg::Const(c) => c.span(),
GenericArg::Const(c) => c.span,
GenericArg::Infer(i) => i.span,
}
}
Expand Down
Loading
Loading