diff --git a/Cargo.toml b/Cargo.toml index 25c4484..5cf43ab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,17 +8,17 @@ exclude = ["Enzyme/enzyme/benchmarks"] authors = ["Martin Robinson "] repository = "https://github.com/martinjrobins/diffsl" -[[bin]] -name = "diffsl" - # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -llvm13-0 = ["inkwell-130", "llvm-sys-130"] -llvm14-0 = ["inkwell-140", "llvm-sys-140"] -llvm15-0 = ["inkwell-150", "llvm-sys-150"] -llvm16-0 = ["inkwell-160", "llvm-sys-160"] -llvm17-0 = ["inkwell-170", "llvm-sys-170"] +llvm13-0 = ["inkwell-130", "llvm-sys-130", "inkwell_internals", "llvm", "enzyme"] +llvm14-0 = ["inkwell-140", "llvm-sys-140", "inkwell_internals", "llvm", "enzyme"] +llvm15-0 = ["inkwell-150", "llvm-sys-150", "inkwell_internals", "llvm", "enzyme"] +llvm16-0 = ["inkwell-160", "llvm-sys-160", "inkwell_internals", "llvm", "enzyme"] +llvm17-0 = ["inkwell-170", "llvm-sys-170", "inkwell_internals", "llvm", "enzyme"] +llvm18-0 = ["inkwell-180", "llvm-sys-180", "inkwell_internals", "llvm", "enzyme"] +enzyme = ["bindgen", "cmake"] +llvm = [] [dependencies] ndarray = { version = ">=0.15.0", features = ["approx-0_5"] } @@ -27,24 +27,31 @@ approx = ">=0.5" pest = ">=2.1.3" pest_derive = ">=2.1.0" itertools = ">=0.10.3" -ouroboros = ">=0.17" clap = { version = "4.3.23", features = ["derive"] } uid = "0.1.7" -inkwell-130 = { package = "inkwell", version = ">=0.4.0", features = ["llvm13-0"], optional = true } -inkwell-140 = { package = "inkwell", version = ">=0.4.0", features = ["llvm14-0"], optional = true } -inkwell-150 = { package = "inkwell", version = ">=0.4.0", features = ["llvm15-0"], optional = true } -inkwell-160 = { package = "inkwell", version = ">=0.4.0", features = ["llvm16-0"], optional = true } -inkwell-170 = { package = "inkwell", version = ">=0.4.0", features = ["llvm17-0"], optional = true } +inkwell-130 = { package = "inkwell", version = ">=0.5.0", features = ["llvm13-0"], optional = true } +inkwell-140 = { package = "inkwell", version = ">=0.5.0", features = ["llvm14-0"], optional = true } +inkwell-150 = { package = "inkwell", version = ">=0.5.0", features = ["llvm15-0"], optional = true } +inkwell-160 = { package = "inkwell", version = ">=0.5.0", features = ["llvm16-0"], optional = true } +inkwell-170 = { package = "inkwell", version = ">=0.5.0", features = ["llvm17-0"], optional = true } +inkwell-180 = { package = "inkwell", version = ">=0.5.0", features = ["llvm18-0"], optional = true } llvm-sys-130 = { package = "llvm-sys", version = "130.0.4", optional = true } llvm-sys-140 = { package = "llvm-sys", version = "140.0.2", optional = true } llvm-sys-150 = { package = "llvm-sys", version = "150.0.3", optional = true } llvm-sys-160 = { package = "llvm-sys", version = "160.1.0", optional = true } llvm-sys-170 = { package = "llvm-sys", version = "170.0.1", optional = true } -inkwell_internals = "0.9.0" +llvm-sys-180 = { package = "llvm-sys", version = "180.0.0", optional = true } +inkwell_internals = { version = "0.9.0", optional = true } +cranelift = "0.110.1" +cranelift-module = "0.110.1" +cranelift-jit = "0.110.1" +cranelift-native = "0.110.1" +target-lexicon = "0.12.16" +aliasable = "0.1.3" [build-dependencies] -cmake = "0.1.50" -bindgen = "0.69.4" +bindgen = { version = "0.69.4", optional = true } +cmake = { version = "0.1.50", optional = true } [dev-dependencies] divan = "0.1.14" diff --git a/benches/evaluation.rs b/benches/evaluation.rs index 5546228..5c2404d 100644 --- 
a/benches/evaluation.rs +++ b/benches/evaluation.rs @@ -1,4 +1,7 @@ -use diffsl::{discretise::DiscreteModel, execution::Compiler, parser::parse_ds_string}; +use diffsl::{ + discretise::DiscreteModel, execution::module::CodegenModule, parser::parse_ds_string, Compiler, + CraneliftModule, +}; use divan::Bencher; use ndarray::Array1; @@ -6,7 +9,7 @@ fn main() { divan::main(); } -fn setup(n: usize, f_text: &str, name: &str) -> Compiler { +fn setup(n: usize, f_text: &str, name: &str) -> Compiler { let u = vec![1.0; n]; let full_text = format!( " @@ -28,14 +31,32 @@ fn setup(n: usize, f_text: &str, name: &str) -> Compiler { ); let model = parse_ds_string(&full_text).unwrap(); let discrete_model = DiscreteModel::build(name, &model).unwrap(); - let out = format!("test_output/benches_evaluation_{}", name); - Compiler::from_discrete_model(&discrete_model, out.as_str()).unwrap() + Compiler::from_discrete_model(&discrete_model).unwrap() } +#[cfg(feature = "llvm")] #[divan::bench(consts = [1, 10, 100, 1000])] -fn add_scalar_diffsl(bencher: Bencher) { +fn add_scalar_diffsl_llvm(bencher: Bencher) { + use diffsl::LlvmModule; + + let n = N; + let compiler = setup::(n, "u_i + 1.0", "add_scalar"); + let mut data = compiler.get_new_data(); + compiler.set_inputs(&[], data.as_mut_slice()); + let mut u = vec![1.0; n]; + compiler.set_u0(u.as_mut_slice(), data.as_mut_slice()); + let mut rr = vec![0.0; n]; + let t = 0.0; + + bencher.bench_local(|| { + compiler.rhs(t, &u, &mut data, &mut rr); + }); +} + +#[divan::bench(consts = [1, 10, 100, 1000])] +fn add_scalar_diffsl_cranelift(bencher: Bencher) { let n = N; - let compiler = setup(n, "u_i + 1.0", "add_scalar"); + let compiler = setup::(n, "u_i + 1.0", "add_scalar"); let mut data = compiler.get_new_data(); compiler.set_inputs(&[], data.as_mut_slice()); let mut u = vec![1.0; n]; @@ -47,6 +68,7 @@ fn add_scalar_diffsl(bencher: Bencher) { compiler.rhs(t, &u, &mut data, &mut rr); }); } + #[divan::bench(consts = [1, 10, 100, 1000])] fn add_scalar_ndarray(bencher: Bencher) { let n = N; diff --git a/build.rs b/build.rs index dec0b56..8559f4a 100644 --- a/build.rs +++ b/build.rs @@ -1,74 +1,83 @@ -use bindgen::{BindgenError, Bindings, Builder}; -use std::{env, path::PathBuf}; +#[cfg(feature = "enzyme")] +mod enzyme { + use bindgen::{BindgenError, Bindings, Builder}; + use std::{env, path::PathBuf}; -fn compile_enzyme(llvm_dir: String) -> (String, String) { - let dst = cmake::Config::new("Enzyme/enzyme") - .define("ENZYME_STATIC_LIB", "ON") - .define("ENZYME_CLANG", "OFF") - .define("LLVM_DIR", llvm_dir) - .define( - "CMAKE_CXX_FLAGS", - "-Wno-comment -Wno-deprecated-declarations", - ) - .build(); - let dst_disp = dst.display(); - let lib_dir = format!("{}/lib", dst_disp); - let inc_dir = "Enzyme/enzyme".to_string(); - (lib_dir, inc_dir) -} + fn compile_enzyme(llvm_dir: String) -> (String, String) { + let dst = cmake::Config::new("Enzyme/enzyme") + .define("ENZYME_STATIC_LIB", "ON") + .define("ENZYME_CLANG", "OFF") + .define("LLVM_DIR", llvm_dir) + .define( + "CMAKE_CXX_FLAGS", + "-Wno-comment -Wno-deprecated-declarations", + ) + .build(); + let dst_disp = dst.display(); + let lib_dir = format!("{}/lib", dst_disp); + let inc_dir = "Enzyme/enzyme".to_string(); + (lib_dir, inc_dir) + } -fn enzyme_bindings(inc_dirs: &[String]) -> Result { - let mut builder = Builder::default() - .header("wrapper.h") - .generate_comments(false) - .clang_arg("-x") - .clang_arg("c++"); + fn enzyme_bindings(inc_dirs: &[String]) -> Result { + let mut builder = Builder::default() + 
.header("wrapper.h") + .generate_comments(false) + .clang_arg("-x") + .clang_arg("c++"); - // add include dirs - for dir in inc_dirs { - builder = builder.clang_arg(format!("-I{}", dir)) + // add include dirs + for dir in inc_dirs { + builder = builder.clang_arg(format!("-I{}", dir)) + } + builder.generate() } - builder.generate() -} -fn main() { - // get env vars matching DEP_LLVM_*_LIBDIR regex - let llvm_dirs: Vec<_> = env::vars() - .filter(|(k, _)| k.starts_with("DEP_LLVM_") && k.ends_with("_LIBDIR")) - .collect(); - // take first one - let llvm_lib_dir = llvm_dirs - .first() - .expect("DEP_LLVM_*_LIBDIR not set") - .1 - .clone(); - let llvm_env_key = llvm_dirs.first().unwrap().0.clone(); - let llvm_version = &llvm_env_key["DEP_LLVM_".len()..(llvm_env_key.len() - "_LIBDIR".len())]; - dbg!(llvm_version); + pub fn enzyme_main() { + // get env vars matching DEP_LLVM_*_LIBDIR regex + let llvm_dirs: Vec<_> = env::vars() + .filter(|(k, _)| k.starts_with("DEP_LLVM_") && k.ends_with("_LIBDIR")) + .collect(); + // take first one + let llvm_lib_dir = llvm_dirs + .first() + .expect("DEP_LLVM_*_LIBDIR not set") + .1 + .clone(); + let llvm_env_key = llvm_dirs.first().unwrap().0.clone(); + let llvm_version = &llvm_env_key["DEP_LLVM_".len()..(llvm_env_key.len() - "_LIBDIR".len())]; + dbg!(llvm_version); - // replace last "lib" with "include" - let llvm_inc_dir = llvm_lib_dir - .chars() - .take(llvm_lib_dir.len() - 3) - .collect::() - + "include"; + // replace last "lib" with "include" + let llvm_inc_dir = llvm_lib_dir + .chars() + .take(llvm_lib_dir.len() - 3) + .collect::() + + "include"; - // compile enzyme - let (libdir, incdir) = compile_enzyme(llvm_lib_dir.clone()); - let libnames = [format!("EnzymeStatic-{}", llvm_version)]; + // compile enzyme + let (libdir, incdir) = compile_enzyme(llvm_lib_dir.clone()); + let libnames = [format!("EnzymeStatic-{}", llvm_version)]; - // bind enzyme api - let bindings_rs = PathBuf::from(env::var("OUT_DIR").unwrap()).join("bindings.rs"); - let bindings = enzyme_bindings(&[llvm_inc_dir, incdir]).expect("Couldn't generate bindings!"); - bindings - .write_to_file(bindings_rs) - .expect("Couldn't write file bindings.rs!"); + // bind enzyme api + let bindings_rs = PathBuf::from(env::var("OUT_DIR").unwrap()).join("bindings.rs"); + let bindings = + enzyme_bindings(&[llvm_inc_dir, incdir]).expect("Couldn't generate bindings!"); + bindings + .write_to_file(bindings_rs) + .expect("Couldn't write file bindings.rs!"); - println!("cargo:rustc-link-search=native={}", libdir); - println!("cargo:rustc-link-search=native={}", llvm_lib_dir); - for libname in libnames.iter() { - println!("cargo:rustc-link-lib={}", libname); + println!("cargo:rustc-link-search=native={}", libdir); + println!("cargo:rustc-link-search=native={}", llvm_lib_dir); + for libname in libnames.iter() { + println!("cargo:rustc-link-lib={}", libname); + } + println!("cargo:rustc-link-lib=LLVMDemangle"); + println!("cargo:rerun-if-changed=wrapper.h"); } - println!("cargo:rustc-link-lib=LLVMDemangle"); - println!("cargo:rerun-if-changed=wrapper.h"); +} + +fn main() { + #[cfg(feature = "enzyme")] + enzyme::enzyme_main(); } diff --git a/src/ast/mod.rs b/src/ast/mod.rs index e134806..6aab7f5 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -69,13 +69,17 @@ pub struct RateEquation<'a> { } #[derive(Debug, Clone)] -pub struct IndexedName<'a> { +pub struct Name<'a> { pub name: &'a str, pub indices: Vec, + pub is_tangent: bool, } -impl<'a> fmt::Display for IndexedName<'a> { +impl<'a> fmt::Display for Name<'a> { fn 
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + if self.is_tangent { + write!(f, "d_")?; + } write!(f, "{}", self.name)?; if !self.indices.is_empty() { write!(f, "_")?; @@ -140,6 +144,7 @@ pub struct Monop<'a> { pub struct Call<'a> { pub fn_name: &'a str, pub args: Vec>>, + pub is_tangent: bool, } #[derive(Debug, Clone)] @@ -303,10 +308,9 @@ pub enum AstKind<'a> { CallArg(CallArg<'a>), Index(Index<'a>), Slice(Slice<'a>), - IndexedName(IndexedName<'a>), + Name(Name<'a>), Number(f64), Integer(i64), - Name(&'a str), NamedGradient(NamedGradient<'a>), } @@ -365,10 +369,9 @@ impl<'a> AstKind<'a> { _ => None, } } - pub fn as_name(&self) -> Option<&str> { + pub fn as_name(&self) -> Option<&Name> { match self { AstKind::Name(n) => Some(n), - AstKind::IndexedName(n) => Some(n.name), _ => None, } } @@ -433,6 +436,7 @@ impl<'a> AstKind<'a> { AstKind::Call(Call { fn_name: "dot", args: vec![Box::new(child)], + is_tangent: false, }) } pub fn new_indice(first: Ast<'a>, last: Option>, sep: Option<&'a str>) -> Self { @@ -453,16 +457,31 @@ impl<'a> AstKind<'a> { data: data.into_iter().map(Box::new).collect(), }) } + pub fn new_indexed_name(name: &'a str, indices: Vec) -> Self { + AstKind::Name(Name { + name, + indices, + is_tangent: false, + }) + } pub fn new_name(name: &'a str) -> Self { - AstKind::Name(name) + AstKind::Name(Name { + name, + indices: Vec::new(), + is_tangent: false, + }) } - pub fn new_indexed_name(name: &'a str, indices: Vec) -> Self { - AstKind::IndexedName(IndexedName { name, indices }) + pub fn new_tangent_indexed_name(name: &'a str, indices: Vec) -> Self { + AstKind::Name(Name { + name, + indices, + is_tangent: true, + }) } - pub fn new_time_derivative(name: &'a str) -> Self { + pub fn new_time_derivative(name: &'a str, indices: Vec) -> Self { AstKind::NamedGradient(NamedGradient { gradient_of: Box::new(Ast { - kind: Self::new_name(name), + kind: Self::new_indexed_name(name, indices), span: None, }), gradient_wrt: Box::new(Ast { @@ -524,6 +543,114 @@ pub struct Ast<'a> { } impl<'a> Ast<'a> { + pub fn tangent(&self) -> Self { + match &self.kind { + AstKind::Binop(binop) => match binop.op { + '+' | '-' => Self::new_binop(binop.op, binop.left.tangent(), binop.right.tangent()), + '*' => { + let lhs = + Self::new_binop('*', binop.left.as_ref().clone(), binop.right.tangent()); + let rhs = + Self::new_binop('*', binop.left.tangent(), binop.right.as_ref().clone()); + Self::new_binop('+', lhs, rhs) + } + '/' => { + let left = + Self::new_binop('/', binop.left.tangent(), binop.right.as_ref().clone()); + let right_top = + Self::new_binop('*', binop.left.as_ref().clone(), binop.right.tangent()); + let right_bottom = Self::new_binop( + '*', + binop.right.as_ref().clone(), + binop.right.as_ref().clone(), + ); + let right = Self::new_binop('/', right_top, right_bottom); + Self::new_binop('-', left, right) + } + _ => panic!("Tangent not implemented for operator {}", binop.op), + }, + AstKind::Monop(monop) => match monop.op { + '-' => Self::new_monop('-', monop.child.tangent()), + _ => panic!("Tangent not implemented for operator {}", monop.op), + }, + AstKind::Call(call) => { + let mut args = Vec::new(); + for arg in call.args.iter() { + args.push(arg.as_ref().clone()); + args.push(arg.tangent()); + } + Self::new_call(call.fn_name, args, true) + } + AstKind::CallArg(arg) => Self::new_call_arg(arg.name, arg.expression.tangent()), + AstKind::Name(name) => { + if name.name == "t" { + Self::new_number(0.0) + } else { + Self::new_name(name.name, name.indices.clone(), true) + } + } + 
AstKind::Number(_) => Self::new_number(0.0), + AstKind::NamedGradient(gradient) => { + let gradient_of = gradient.gradient_of.tangent(); + let gradient_wrt = gradient.gradient_wrt.as_ref().clone(); + Self::new_named_gradient(gradient_of, gradient_wrt) + } + _ => panic!("Tangent not implemented for {:?}", self.kind), + } + } + + pub fn new_named_gradient(gradient_of: Ast<'a>, gradient_wrt: Ast<'a>) -> Self { + Ast { + kind: AstKind::NamedGradient(NamedGradient { + gradient_of: Box::new(gradient_of), + gradient_wrt: Box::new(gradient_wrt), + }), + span: None, + } + } + + pub fn new_name(name: &'a str, indices: Vec, is_tangent: bool) -> Self { + Ast { + kind: AstKind::Name(Name { + name, + indices, + is_tangent, + }), + span: None, + } + } + + pub fn new_call_arg(name: Option<&'a str>, expression: Ast<'a>) -> Self { + Ast { + kind: AstKind::CallArg(CallArg { + name, + expression: Box::new(expression), + }), + span: None, + } + } + + pub fn new_call(fn_name: &'a str, args: Vec>, is_tangent: bool) -> Self { + Ast { + kind: AstKind::Call(Call { + fn_name, + args: args.into_iter().map(Box::new).collect(), + is_tangent, + }), + span: None, + } + } + + pub fn new_monop(op: char, child: Ast<'a>) -> Self { + Ast { + kind: AstKind::Monop(Monop { + op, + child: Box::new(child), + }), + span: None, + } + } + pub fn new_binop(op: char, lhs: Ast<'a>, rhs: Ast<'a>) -> Self { Ast { kind: AstKind::new_binop(op, lhs, rhs), @@ -531,6 +658,13 @@ impl<'a> Ast<'a> { } } + pub fn new_number(num: f64) -> Self { + Ast { + kind: AstKind::Number(num), + span: None, + } + } + pub fn clone_and_subst<'b>(&self, replacements: &HashMap<&'a str, &'b Ast<'a>>) -> Self { let cloned_kind = match &self.kind { AstKind::Definition(dfn) => AstKind::Definition(Definition { @@ -569,6 +703,7 @@ impl<'a> Ast<'a> { .iter() .map(|m| Box::new(m.clone_and_subst(replacements))) .collect(), + is_tangent: call.is_tangent, }), AstKind::CallArg(arg) => AstKind::CallArg(CallArg { name: arg.name, @@ -576,15 +711,15 @@ impl<'a> Ast<'a> { }), AstKind::Number(num) => AstKind::Number(*num), AstKind::Integer(num) => AstKind::Integer(*num), - AstKind::IndexedName(name) => AstKind::IndexedName(IndexedName { - name: name.name, - indices: name.indices.clone(), - }), AstKind::Name(name) => { - if let Some(x) = replacements.get(name) { - x.kind.clone() + if name.indices.is_empty() && replacements.contains_key(name.name) { + replacements[name.name].kind.clone() } else { - AstKind::Name(name) + AstKind::Name(Name { + name: name.name, + indices: name.indices.clone(), + is_tangent: name.is_tangent, + }) } } AstKind::NamedGradient(gradient) => AstKind::NamedGradient(NamedGradient { @@ -654,10 +789,11 @@ impl<'a> Ast<'a> { AstKind::CallArg(arg) => { arg.expression.collect_deps(deps); } - AstKind::IndexedName(found_name) => { - deps.insert(found_name.name); - } - AstKind::Name(found_name) => { + AstKind::Name(Name { + name: found_name, + indices: _, + is_tangent: _, + }) => { deps.insert(found_name); } AstKind::NamedGradient(gradient) => { @@ -728,7 +864,7 @@ impl<'a> Ast<'a> { AstKind::CallArg(arg) => { arg.expression.collect_indices(indices); } - AstKind::IndexedName(found_name) => { + AstKind::Name(found_name) => { indices.extend(found_name.indices.iter().cloned()); } AstKind::Index(index) => { @@ -751,7 +887,6 @@ impl<'a> Ast<'a> { gradient.gradient_of.collect_indices(indices); gradient.gradient_wrt.collect_indices(indices); } - AstKind::Name(_) => (), AstKind::DsModel(_) => (), AstKind::Number(_) => (), AstKind::Integer(_) => (), @@ -778,8 +913,9 @@ impl<'a> 
fmt::Display for Ast<'a> { model.name, model.unknowns, model.statements ) } - AstKind::Name(name) => write!(f, "{}", name), - AstKind::IndexedName(name) => write!(f, "{}", name), + AstKind::Name(name) => { + write!(f, "{}", name) + } AstKind::Number(num) => write!(f, "{}", num), AstKind::Integer(num) => write!(f, "{}", num), AstKind::Unknown(unknown) => write!( diff --git a/src/bin/diffsl.rs b/src/bin/diffsl.rs deleted file mode 100644 index fb01e12..0000000 --- a/src/bin/diffsl.rs +++ /dev/null @@ -1,46 +0,0 @@ -use anyhow::Result; -use clap::Parser; -use diffsl::{compile, CompilerOptions}; - -/// compiles a model in continuous (.cs) or discrete (.ds) format to an object file -#[derive(Parser, Debug)] -#[command(author, version, about, long_about = None)] -struct Args { - /// Input filename - input: String, - - /// Output filename - #[arg(short, long)] - out: Option, - - /// Model to build (only for continuous model files) - #[arg(short, long)] - model: Option, - - /// Compile object file only - #[arg(short, long)] - compile: bool, - - /// Compile to WASM - #[arg(short, long)] - wasm: bool, - - /// Compile to standalone executable - #[arg(short, long)] - standalone: bool, -} - -fn main() -> Result<()> { - let cli = Args::parse(); - let options = CompilerOptions { - bitcode_only: cli.compile, - wasm: cli.wasm, - standalone: cli.standalone, - }; - compile( - &cli.input, - cli.out.as_deref(), - cli.model.as_deref(), - options, - ) -} diff --git a/src/continuous/builder.rs b/src/continuous/builder.rs index 99ccad9..185c138 100644 --- a/src/continuous/builder.rs +++ b/src/continuous/builder.rs @@ -126,7 +126,7 @@ impl<'s> Variable<'s> { let bounds = match &unknown.codomain { Some(r) => match &r.kind { AstKind::Range(r) => (r.lower, r.upper), - AstKind::Name(name) => match *name { + AstKind::Name(name) => match name.name { "NonNegative" => (0.0, f64::INFINITY), "R" => (-f64::INFINITY, f64::INFINITY), _ => { @@ -237,7 +237,12 @@ impl<'s> ModelInfo<'s> { // - the lhs is a call with a name equal to one of the variables, // - that variable has a dependent t, // - there is a number equal to the lower bound of time in the argument corresponding to time - if let AstKind::Call(ast::Call { fn_name, args }) = &eqn.lhs.kind { + if let AstKind::Call(ast::Call { + fn_name, + args, + is_tangent: _, + }) = &eqn.lhs.kind + { if let Some(v_cell) = self.variables.get(fn_name) { let v = v_cell.borrow(); if let Some(time_index) = v.time_index { @@ -516,7 +521,7 @@ impl<'s> ModelInfo<'s> { if !self .variables .iter() - .any(|(var_name, _)| *var_name == *name) + .any(|(var_name, _)| *var_name == name.name) { self.errors .push(Output::new(format!("name {} not found", name), expr.span)) diff --git a/src/discretise/discrete_model.rs b/src/discretise/discrete_model.rs index 1c49fb6..1c05331 100644 --- a/src/discretise/discrete_model.rs +++ b/src/discretise/discrete_model.rs @@ -387,7 +387,9 @@ impl<'s> DiscreteModel<'s> { let geq_deps = geq.get_dependents(); ret.is_algebraic.push(true); if let Some(sp_name) = sp.name() { - if Some(sp_name) == feq.kind.as_name() && !geq_deps.contains(sp_name) { + if Some(sp_name) == feq.kind.as_name().map(|n| n.name) + && !geq_deps.contains(sp_name) + { ret.is_algebraic[i] = false; } } @@ -453,7 +455,10 @@ impl<'s> DiscreteModel<'s> { panic!("state var should have an equation") }; let (f_astkind, g_astkind) = match ast_eqn.kind { - AstKind::RateEquation(eqn) => (AstKind::new_time_derivative(state.name), eqn.rhs.kind), + AstKind::RateEquation(eqn) => ( + 
AstKind::new_time_derivative(state.name, vec![]), + eqn.rhs.kind, + ), AstKind::Equation(eqn) => ( AstKind::new_num(0.0), AstKind::new_binop('-', *eqn.rhs, *eqn.lhs), @@ -499,7 +504,7 @@ impl<'s> DiscreteModel<'s> { kind: AstKind::new_num(0.0), span: None, }; - let named_gradient_str = AstKind::new_time_derivative(state.name) + let named_gradient_str = AstKind::new_time_derivative(state.name, vec![]) .as_named_gradient() .unwrap() .to_string(); diff --git a/src/discretise/env.rs b/src/discretise/env.rs index de5dc34..fc0dffd 100644 --- a/src/discretise/env.rs +++ b/src/discretise/env.rs @@ -245,10 +245,7 @@ impl Env { AstKind::Number(_) => Some(Layout::new_scalar()), AstKind::Integer(_) => Some(Layout::new_scalar()), AstKind::Domain(d) => Some(Layout::new_dense(Shape::zeros(1) + d.dim)), - AstKind::IndexedName(name) => { - self.get_layout_name(name.name, ast, &name.indices, indices) - } - AstKind::Name(name) => self.get_layout_name(name, ast, &[], indices), + AstKind::Name(name) => self.get_layout_name(name.name, ast, &name.indices, indices), _ => panic!("unrecognised ast node {:#?}", ast.kind), } } diff --git a/src/discretise/tensor.rs b/src/discretise/tensor.rs index 49ad3fa..4ff671f 100644 --- a/src/discretise/tensor.rs +++ b/src/discretise/tensor.rs @@ -17,6 +17,7 @@ pub struct TensorBlock<'s> { layout: RcLayout, expr_layout: RcLayout, expr: Ast<'s>, + tangent_expr: Ast<'s>, } impl<'s> TensorBlock<'s> { @@ -34,6 +35,7 @@ impl<'s> TensorBlock<'s> { indices, layout, expr_layout, + tangent_expr: expr.tangent(), expr, } } @@ -44,6 +46,7 @@ impl<'s> TensorBlock<'s> { start: Index::from_vec(vec![start]), layout: layout.clone(), expr_layout: layout, + tangent_expr: expr.tangent(), expr, indices: Vec::new(), } @@ -69,6 +72,10 @@ impl<'s> TensorBlock<'s> { &self.expr } + pub fn tangent_expr(&self) -> &Ast<'s> { + &self.tangent_expr + } + pub fn rank(&self) -> usize { self.shape().len() } diff --git a/src/execution/compiler.rs b/src/execution/compiler.rs index 1bea08f..2ef3e47 100644 --- a/src/execution/compiler.rs +++ b/src/execution/compiler.rs @@ -1,437 +1,160 @@ -use anyhow::anyhow; -use inkwell::{ - passes::PassBuilderOptions, - targets::{CodeModel, InitializationConfig, RelocMode, Target, TargetMachine}, +use crate::{ + discretise::DiscreteModel, + execution::interface::{ + CalcOutFunc, GetDimsFunc, GetOutFunc, MassFunc, RhsFunc, SetIdFunc, SetInputsFunc, + StopFunc, U0Func, + }, + parser::parse_ds_string, }; -use std::env; -use std::path::Path; -use uid::Id; - -use crate::discretise::DiscreteModel; -use crate::parser::parse_ds_string; -use crate::utils::find_executable; -use crate::utils::find_runtime_path; -use anyhow::Result; -use inkwell::{ - context::Context, - execution_engine::{ExecutionEngine, JitFunction, UnsafeFunctionPointer}, - targets::{FileType, TargetTriple}, - OptimizationLevel, -}; -use ouroboros::self_referencing; -use std::process::Command; -use super::codegen::CompileGradientArgType; -use super::codegen::GetDimsFunc; -use super::codegen::GetOutFunc; -use super::codegen::SetIdFunc; -use super::codegen::SetInputsFunc; -use super::codegen::SetInputsGradientFunc; -use super::codegen::U0GradientFunc; -use super::codegen::{CalcOutGradientFunc, MassFunc, RhsFunc, RhsGradientFunc}; use super::{ - codegen::{CalcOutFunc, StopFunc, U0Func}, - data_layout::DataLayout, - CodeGen, + interface::{CalcOutGradientFunc, RhsGradientFunc, SetInputsGradientFunc, U0GradientFunc}, + module::CodegenModule, }; +use anyhow::Result; +use target_lexicon::Triple; +use uid::Id; -struct 
JitFunctions<'ctx> { - set_u0: JitFunction<'ctx, U0Func>, - rhs: JitFunction<'ctx, RhsFunc>, - mass: JitFunction<'ctx, MassFunc>, - calc_out: JitFunction<'ctx, CalcOutFunc>, - calc_stop: JitFunction<'ctx, StopFunc>, - set_id: JitFunction<'ctx, SetIdFunc>, - get_dims: JitFunction<'ctx, GetDimsFunc>, - set_inputs: JitFunction<'ctx, SetInputsFunc>, - get_out: JitFunction<'ctx, GetOutFunc>, -} - -struct JitGradFunctions<'ctx> { - set_u0_grad: JitFunction<'ctx, U0GradientFunc>, - rhs_grad: JitFunction<'ctx, RhsGradientFunc>, - calc_out_grad: JitFunction<'ctx, CalcOutGradientFunc>, - set_inputs_grad: JitFunction<'ctx, SetInputsGradientFunc>, +struct JitFunctions { + set_u0: U0Func, + rhs: RhsFunc, + mass: MassFunc, + calc_out: CalcOutFunc, + calc_stop: StopFunc, + set_id: SetIdFunc, + get_dims: GetDimsFunc, + set_inputs: SetInputsFunc, + get_out: GetOutFunc, } -struct CompilerData<'ctx> { - codegen: CodeGen<'ctx>, - jit_functions: JitFunctions<'ctx>, - jit_grad_functions: JitGradFunctions<'ctx>, +struct JitGradFunctions { + set_u0_grad: U0GradientFunc, + rhs_grad: RhsGradientFunc, + calc_out_grad: CalcOutGradientFunc, + set_inputs_grad: SetInputsGradientFunc, } -#[self_referencing] -pub struct Compiler { - context: Context, - - #[borrows(context)] - #[not_covariant] - data: CompilerData<'this>, +pub struct Compiler { + module: M, + jit_functions: JitFunctions, + jit_grad_functions: JitGradFunctions, number_of_states: usize, number_of_parameters: usize, number_of_outputs: usize, has_mass: bool, - data_layout: DataLayout, - output_base_filename: String, } -impl Compiler { - const OPT_VARIENTS: [&'static str; 2] = ["opt-14", "opt"]; - const CLANG_VARIENTS: [&'static str; 2] = ["clang", "clang-14"]; - fn find_opt() -> Result<&'static str> { - find_executable(&Compiler::OPT_VARIENTS) - } - fn find_clang() -> Result<&'static str> { - find_executable(&Compiler::CLANG_VARIENTS) - } - /// search for the enzyme library in the environment variables - fn find_enzyme_lib() -> Result { - let env_vars = ["LD_LIBRARY_PATH", "DYLD_LIBRARY_PATH", "PATH"]; - for var in env_vars.iter() { - if let Ok(val) = env::var(var) { - for path in val.split(':') { - // check that LLVMEnzype*.so exists in this directory - if let Ok(entries) = std::fs::read_dir(path) { - for entry in entries.flatten() { - if let Some(filename) = entry.file_name().to_str() { - if filename.starts_with("LLVMEnzyme") && filename.ends_with(".so") { - return Ok(entry.path().to_str().unwrap().to_owned()); - } - } - } - } - } - } - } - Err(anyhow!( - "LLVMEnzyme*.so not found in any of: {:?}", - env_vars - )) - } +impl Compiler { pub fn from_discrete_str(code: &str) -> Result { let uid = Id::::new(); let name = format!("diffsl_{}", uid); let model = parse_ds_string(code).unwrap(); let model = DiscreteModel::build(name.as_str(), &model) .unwrap_or_else(|e| panic!("{}", e.as_error_message(code))); - let dir = env::temp_dir(); - let path = dir.join(name.clone()); - Compiler::from_discrete_model(&model, path.to_str().unwrap()) + Self::from_discrete_model(&model) } - pub fn from_discrete_model(model: &DiscreteModel, out: &str) -> Result { + pub fn from_discrete_model(model: &DiscreteModel) -> Result { let number_of_states = *model.state().shape().first().unwrap_or(&1); let input_names = model .inputs() .iter() .map(|input| input.name().to_owned()) .collect::>(); - let data_layout = DataLayout::new(model); - let context = Context::create(); + let mut module = M::new(Triple::host(), model)?; let number_of_parameters = input_names.iter().fold(0, |acc, name| { - 
acc + data_layout.get_data_length(name).unwrap() + acc + module.layout().get_data_length(name).unwrap() }); - let number_of_outputs = data_layout.get_data_length("out").unwrap(); + let number_of_outputs = module.layout().get_data_length("out").unwrap(); let has_mass = model.lhs().is_some(); - CompilerTryBuilder { - data_layout, - number_of_states, - number_of_parameters, - number_of_outputs, - context, - has_mass, - output_base_filename: out.to_owned(), - data_builder: |context| { - let module = context.create_module(model.name()); - let real_type = context.f64_type(); - let real_type_str = "f64"; - - let mut codegen = CodeGen::new(model, context, module, real_type, real_type_str); - - let _set_u0 = codegen.compile_set_u0(model)?; - let _calc_stop = codegen.compile_calc_stop(model)?; - let _rhs = codegen.compile_rhs(model)?; - let _mass = codegen.compile_mass(model)?; - let _calc_out = codegen.compile_calc_out(model)?; - let _set_id = codegen.compile_set_id(model)?; - let _get_dims = codegen.compile_get_dims(model)?; - let _set_inputs = codegen.compile_set_inputs(model)?; - let _get_output = codegen.compile_get_tensor(model, "out")?; - // optimise at -O2 no unrolling before giving to enzyme - let pass_options = PassBuilderOptions::create(); - //pass_options.set_verify_each(true); - //pass_options.set_debug_logging(true); - //pass_options.set_loop_interleaving(true); - pass_options.set_loop_vectorization(false); - pass_options.set_loop_slp_vectorization(false); - pass_options.set_loop_unrolling(false); - //pass_options.set_forget_all_scev_in_loop_unroll(true); - //pass_options.set_licm_mssa_opt_cap(1); - //pass_options.set_licm_mssa_no_acc_for_promotion_cap(10); - //pass_options.set_call_graph_profile(true); - //pass_options.set_merge_functions(true); - - let initialization_config = &InitializationConfig::default(); - Target::initialize_all(initialization_config); - let triple = TargetMachine::get_default_triple(); - let target = Target::from_triple(&triple).unwrap(); - let machine = target - .create_target_machine( - &triple, - "generic", //TargetMachine::get_host_cpu_name().to_string().as_str(), - "", //TargetMachine::get_host_cpu_features().to_string().as_str(), - inkwell::OptimizationLevel::Default, - inkwell::targets::RelocMode::Default, - inkwell::targets::CodeModel::Default, - ) - .unwrap(); - - codegen - .module() - .run_passes("default", &machine, pass_options) - .unwrap(); - - let _rhs_grad = codegen.compile_gradient( - _rhs, - &[ - CompileGradientArgType::Const, - CompileGradientArgType::Dup, - CompileGradientArgType::Dup, - CompileGradientArgType::DupNoNeed, - ], - )?; - let _set_inputs_grad = codegen.compile_gradient( - _set_inputs, - &[CompileGradientArgType::Dup, CompileGradientArgType::Dup], - )?; - let _calc_out_grad = codegen.compile_gradient( - _calc_out, - &[ - CompileGradientArgType::Const, - CompileGradientArgType::Dup, - CompileGradientArgType::Dup, - ], - )?; - let _set_u0_grad = codegen.compile_gradient( - _set_u0, - &[CompileGradientArgType::Dup, CompileGradientArgType::Dup], - )?; - - let ee = codegen - .module() - .create_jit_execution_engine(OptimizationLevel::Aggressive) - .map_err(|e| anyhow::anyhow!("Error creating execution engine: {:?}", e))?; - - let set_u0 = Compiler::jit("set_u0", &ee)?; - let rhs = Compiler::jit("rhs", &ee)?; - let mass = Compiler::jit("mass", &ee)?; - let calc_stop = Compiler::jit("calc_stop", &ee)?; - let calc_out = Compiler::jit("calc_out", &ee)?; - let set_id = Compiler::jit("set_id", &ee)?; - let get_dims = 
Compiler::jit("get_dims", &ee)?; - let set_inputs = Compiler::jit("set_inputs", &ee)?; - let get_out = Compiler::jit("get_out", &ee)?; - - let set_inputs_grad = Compiler::jit("set_inputs_grad", &ee)?; - let calc_out_grad = Compiler::jit("calc_out_grad", &ee)?; - let rhs_grad = Compiler::jit("rhs_grad", &ee)?; - let set_u0_grad = Compiler::jit("set_u0_grad", &ee)?; - - let data = CompilerData { - codegen, - jit_functions: JitFunctions { - set_u0, - rhs, - mass, - calc_out, - set_id, - get_dims, - set_inputs, - get_out, - calc_stop, - }, - jit_grad_functions: JitGradFunctions { - set_u0_grad, - rhs_grad, - calc_out_grad, - set_inputs_grad, - }, - }; - Ok(data) - }, - } - .try_build() - } - - pub fn compile(&self, standalone: bool, wasm: bool) -> Result<()> { - let opt_name = Compiler::find_opt()?; - let clang_name = Compiler::find_clang()?; - let enzyme_lib = Compiler::find_enzyme_lib()?; - let out = self.borrow_output_base_filename(); - let object_filename = Compiler::get_object_filename(out); - let bitcodefilename = Compiler::get_bitcode_filename(out); - let mut command = Command::new(clang_name); - command - .arg(bitcodefilename.as_str()) - .arg("-c") - .arg(format!("-fplugin={}", enzyme_lib)) - .arg("-o") - .arg(object_filename.as_str()); - - if wasm { - command.arg("-target").arg("wasm32-unknown-emscripten"); - } - - let output = command.output().unwrap(); - - if let Some(code) = output.status.code() { - if code != 0 { - println!("{}", String::from_utf8_lossy(&output.stderr)); - return Err(anyhow!("{} returned error code {}", opt_name, code)); - } - } - - // link the object file and our runtime library - let mut command = if wasm { - let emcc_varients = ["emcc"]; - let command_name = find_executable(&emcc_varients)?; - let exported_functions = vec![ - "Vector_destroy", - "Vector_create", - "Vector_create_with_capacity", - "Vector_push", - "Options_destroy", - "Options_create", - "Sundials_destroy", - "Sundials_create", - "Sundials_init", - "Sundials_solve", - ]; - let mut linked_files = vec![ - "libdiffeq_runtime_lib.a", - "libsundials_idas.a", - "libsundials_sunlinsolklu.a", - "libklu.a", - "libamd.a", - "libcolamd.a", - "libbtf.a", - "libsuitesparseconfig.a", - "libsundials_sunmatrixsparse.a", - "libargparse.a", - ]; - if standalone { - linked_files.push("libdiffeq_runtime_wasm.a"); - } - let linked_files = linked_files; - let runtime_path = find_runtime_path(&linked_files)?; - let mut command = Command::new(command_name); - command.arg("-o").arg(out).arg(object_filename.as_str()); - for file in linked_files { - command.arg(Path::new(runtime_path.as_str()).join(file)); - } - if !standalone { - let exported_functions = exported_functions - .into_iter() - .map(|s| format!("_{}", s)) - .collect::>() - .join(","); - command - .arg("-s") - .arg(format!("EXPORTED_FUNCTIONS={}", exported_functions)); - command.arg("--no-entry"); - } - command - } else { - let mut command = Command::new(clang_name); - command.arg("-o").arg(out).arg(object_filename.as_str()); - if standalone { - command.arg("-ldiffeq_runtime"); - } else { - command.arg("-ldiffeq_runtime_lib"); - } - command + let set_u0 = module.compile_set_u0(model)?; + let calc_stop = module.compile_calc_stop(model)?; + let rhs = module.compile_rhs(model)?; + let mass = module.compile_mass(model)?; + let calc_out = module.compile_calc_out(model)?; + let set_id = module.compile_set_id(model)?; + let get_dims = module.compile_get_dims(model)?; + let set_inputs = module.compile_set_inputs(model)?; + let get_output = 
module.compile_get_tensor(model, "out")?; + + module.pre_autodiff_optimisation()?; + + let set_u0_grad = module.compile_set_u0_grad(&set_u0, model)?; + let rhs_grad = module.compile_rhs_grad(&rhs, model)?; + let calc_out_grad = module.compile_calc_out_grad(&calc_out, model)?; + let set_inputs_grad = module.compile_set_inputs_grad(&set_inputs, model)?; + + module.post_autodiff_optimisation()?; + + let set_u0 = unsafe { std::mem::transmute::<*const u8, U0Func>(module.jit(set_u0)?) }; + let rhs = unsafe { std::mem::transmute::<*const u8, RhsFunc>(module.jit(rhs)?) }; + let mass = unsafe { std::mem::transmute::<*const u8, MassFunc>(module.jit(mass)?) }; + let calc_out = + unsafe { std::mem::transmute::<*const u8, CalcOutFunc>(module.jit(calc_out)?) }; + let calc_stop = + unsafe { std::mem::transmute::<*const u8, StopFunc>(module.jit(calc_stop)?) }; + let set_id = unsafe { std::mem::transmute::<*const u8, SetIdFunc>(module.jit(set_id)?) }; + let get_dims = + unsafe { std::mem::transmute::<*const u8, GetDimsFunc>(module.jit(get_dims)?) }; + let set_inputs = + unsafe { std::mem::transmute::<*const u8, SetInputsFunc>(module.jit(set_inputs)?) }; + let get_out = + unsafe { std::mem::transmute::<*const u8, GetOutFunc>(module.jit(get_output)?) }; + + let set_u0_grad = + unsafe { std::mem::transmute::<*const u8, U0GradientFunc>(module.jit(set_u0_grad)?) }; + let rhs_grad = + unsafe { std::mem::transmute::<*const u8, RhsGradientFunc>(module.jit(rhs_grad)?) }; + let calc_out_grad = unsafe { + std::mem::transmute::<*const u8, CalcOutGradientFunc>(module.jit(calc_out_grad)?) }; - - let output = command.output(); - - let output = match output { - Ok(output) => output, - Err(e) => { - let args = command - .get_args() - .map(|s| s.to_str().unwrap()) - .collect::>() - .join(" "); - println!( - "{} {}", - command.get_program().to_os_string().to_str().unwrap(), - args - ); - return Err(anyhow!("Error linking in runtime: {}", e)); - } + let set_inputs_grad = unsafe { + std::mem::transmute::<*const u8, SetInputsGradientFunc>(module.jit(set_inputs_grad)?) 
}; - if let Some(code) = output.status.code() { - if code != 0 { - let args = command - .get_args() - .map(|s| s.to_str().unwrap()) - .collect::>() - .join(" "); - println!( - "{} {}", - command.get_program().to_os_string().to_str().unwrap(), - args - ); - println!("{}", String::from_utf8_lossy(&output.stderr)); - return Err(anyhow!( - "Error linking in runtime, returned error code {}", - code - )); - } - } - Ok(()) - } - - fn get_bitcode_filename(out: &str) -> String { - format!("{}.bc", out) - } - - fn get_object_filename(out: &str) -> String { - format!("{}.o", out) - } - - fn jit<'ctx, T>(name: &str, ee: &ExecutionEngine<'ctx>) -> Result> - where - T: UnsafeFunctionPointer, - { - let maybe_fn = unsafe { ee.get_function::(name) }; - match maybe_fn { - Ok(f) => Ok(f), - Err(err) => Err(anyhow!("Error during jit for {}: {}", name, err)), - } + Ok(Self { + module, + jit_functions: JitFunctions { + set_u0, + rhs, + mass, + calc_out, + calc_stop, + set_id, + get_dims, + set_inputs, + get_out, + }, + jit_grad_functions: JitGradFunctions { + set_u0_grad, + rhs_grad, + calc_out_grad, + set_inputs_grad, + }, + number_of_states, + number_of_parameters, + number_of_outputs, + has_mass, + }) } pub fn get_tensor_data<'a>(&self, name: &str, data: &'a [f64]) -> Option<&'a [f64]> { - let index = self.borrow_data_layout().get_data_index(name)?; - let nnz = self.borrow_data_layout().get_data_length(name)?; + let index = self.module.layout().get_data_index(name)?; + let nnz = self.module.layout().get_data_length(name)?; Some(&data[index..index + nnz]) } pub fn set_u0(&self, yy: &mut [f64], data: &mut [f64]) { - let number_of_states = *self.borrow_number_of_states(); - if yy.len() != number_of_states { - panic!("Expected {} states, got {}", number_of_states, yy.len()); + if yy.len() != self.number_of_states { + panic!( + "Expected {} states, got {}", + self.number_of_states, + yy.len() + ); } - self.with_data(|compiler| { - let yy_ptr = yy.as_mut_ptr(); - let data_ptr = data.as_mut_ptr(); - unsafe { - compiler.jit_functions.set_u0.call(data_ptr, yy_ptr); - } - }); + unsafe { (self.jit_functions.set_u0)(yy.as_mut_ptr(), data.as_mut_ptr()) }; } pub fn set_u0_grad( @@ -441,14 +164,17 @@ impl Compiler { data: &mut [f64], ddata: &mut [f64], ) { - let number_of_states = *self.borrow_number_of_states(); - if yy.len() != number_of_states { - panic!("Expected {} states, got {}", number_of_states, yy.len()); + if yy.len() != self.number_of_states { + panic!( + "Expected {} states, got {}", + self.number_of_states, + yy.len() + ); } - if dyy.len() != number_of_states { + if dyy.len() != self.number_of_states { panic!( "Expected {} states for dyy, got {}", - number_of_states, + self.number_of_states, dyy.len() ); } @@ -462,18 +188,14 @@ impl Compiler { ddata.len() ); } - self.with_data(|compiler| { - let yy_ptr = yy.as_mut_ptr(); - let data_ptr = data.as_mut_ptr(); - let dyy_ptr = dyy.as_mut_ptr(); - let ddata_ptr = ddata.as_mut_ptr(); - unsafe { - compiler - .jit_grad_functions - .set_u0_grad - .call(data_ptr, ddata_ptr, yy_ptr, dyy_ptr); - } - }); + unsafe { + (self.jit_grad_functions.set_u0_grad)( + yy.as_mut_ptr(), + dyy.as_mut_ptr(), + data.as_mut_ptr(), + ddata.as_mut_ptr(), + ) + }; } pub fn calc_stop(&self, t: f64, yy: &[f64], data: &mut [f64], stop: &mut [f64]) { @@ -487,81 +209,62 @@ impl Compiler { if stop.len() != n_stop { panic!("Expected {} stop, got {}", n_stop, stop.len()); } - self.with_data(|compiler| { - let yy_ptr = yy.as_ptr(); - let data_ptr = data.as_mut_ptr(); - let stop_ptr = 
stop.as_mut_ptr(); - unsafe { - compiler - .jit_functions - .calc_stop - .call(t, yy_ptr, data_ptr, stop_ptr); - } - }); + unsafe { + (self.jit_functions.calc_stop)(t, yy.as_ptr(), data.as_mut_ptr(), stop.as_mut_ptr()) + }; } pub fn rhs(&self, t: f64, yy: &[f64], data: &mut [f64], rr: &mut [f64]) { - let number_of_states = *self.borrow_number_of_states(); - if yy.len() != number_of_states { - panic!("Expected {} states, got {}", number_of_states, yy.len()); + if yy.len() != self.number_of_states { + panic!( + "Expected {} states, got {}", + self.number_of_states, + yy.len() + ); } - if rr.len() != number_of_states { + if rr.len() != self.number_of_states { panic!( "Expected {} residual states, got {}", - number_of_states, + self.number_of_states, rr.len() ); } if data.len() != self.data_len() { panic!("Expected {} data, got {}", self.data_len(), data.len()); } - self.with_data(|compiler| { - let yy_ptr = yy.as_ptr(); - let rr_ptr = rr.as_mut_ptr(); - let data_ptr = data.as_mut_ptr(); - unsafe { - compiler.jit_functions.rhs.call(t, yy_ptr, data_ptr, rr_ptr); - } - }); + unsafe { (self.jit_functions.rhs)(t, yy.as_ptr(), data.as_mut_ptr(), rr.as_mut_ptr()) }; } pub fn has_mass(&self) -> bool { - *self.borrow_has_mass() + self.has_mass } pub fn mass(&self, t: f64, yp: &[f64], data: &mut [f64], rr: &mut [f64]) { - if !self.borrow_has_mass() { + if !self.has_mass { panic!("Model does not have a mass function"); } - let number_of_states = *self.borrow_number_of_states(); - if yp.len() != number_of_states { - panic!("Expected {} states, got {}", number_of_states, yp.len()); + if yp.len() != self.number_of_states { + panic!( + "Expected {} states, got {}", + self.number_of_states, + yp.len() + ); } - if rr.len() != number_of_states { + if rr.len() != self.number_of_states { panic!( "Expected {} residual states, got {}", - number_of_states, + self.number_of_states, rr.len() ); } if data.len() != self.data_len() { panic!("Expected {} data, got {}", self.data_len(), data.len()); } - self.with_data(|compiler| { - let yp_ptr = yp.as_ptr(); - let rr_ptr = rr.as_mut_ptr(); - let data_ptr = data.as_mut_ptr(); - unsafe { - compiler - .jit_functions - .mass - .call(t, yp_ptr, data_ptr, rr_ptr); - } - }); + unsafe { (self.jit_functions.mass)(t, yp.as_ptr(), data.as_mut_ptr(), rr.as_mut_ptr()) }; } pub fn data_len(&self) -> usize { - self.with(|compiler| compiler.data_layout.data().len()) + self.module.layout().data().len() } pub fn get_new_data(&self) -> Vec { @@ -579,28 +282,31 @@ impl Compiler { rr: &mut [f64], drr: &mut [f64], ) { - let number_of_states = *self.borrow_number_of_states(); - if yy.len() != number_of_states { - panic!("Expected {} states, got {}", number_of_states, yy.len()); + if yy.len() != self.number_of_states { + panic!( + "Expected {} states, got {}", + self.number_of_states, + yy.len() + ); } - if rr.len() != number_of_states { + if rr.len() != self.number_of_states { panic!( "Expected {} residual states, got {}", - number_of_states, + self.number_of_states, rr.len() ); } - if dyy.len() != number_of_states { + if dyy.len() != self.number_of_states { panic!( "Expected {} states for dyy, got {}", - number_of_states, + self.number_of_states, dyy.len() ); } - if drr.len() != number_of_states { + if drr.len() != self.number_of_states { panic!( "Expected {} residual states for drr, got {}", - number_of_states, + self.number_of_states, drr.len() ); } @@ -614,37 +320,31 @@ impl Compiler { ddata.len() ); } - self.with_data(|compiler| { - let yy_ptr = yy.as_ptr(); - let rr_ptr = rr.as_mut_ptr(); 
- let dyy_ptr = dyy.as_ptr(); - let drr_ptr = drr.as_mut_ptr(); - let data_ptr = data.as_mut_ptr(); - let ddata_ptr = ddata.as_mut_ptr(); - unsafe { - compiler - .jit_grad_functions - .rhs_grad - .call(t, yy_ptr, dyy_ptr, data_ptr, ddata_ptr, rr_ptr, drr_ptr); - } - }); + unsafe { + (self.jit_grad_functions.rhs_grad)( + t, + yy.as_ptr(), + dyy.as_ptr(), + data.as_mut_ptr(), + ddata.as_mut_ptr(), + rr.as_mut_ptr(), + drr.as_mut_ptr(), + ) + }; } pub fn calc_out(&self, t: f64, yy: &[f64], data: &mut [f64]) { - let number_of_states = *self.borrow_number_of_states(); - if yy.len() != *self.borrow_number_of_states() { - panic!("Expected {} states, got {}", number_of_states, yy.len()); + if yy.len() != self.number_of_states { + panic!( + "Expected {} states, got {}", + self.number_of_states, + yy.len() + ); } if data.len() != self.data_len() { panic!("Expected {} data, got {}", self.data_len(), data.len()); } - self.with_data(|compiler| { - let yy_ptr = yy.as_ptr(); - let data_ptr = data.as_mut_ptr(); - unsafe { - compiler.jit_functions.calc_out.call(t, yy_ptr, data_ptr); - } - }); + unsafe { (self.jit_functions.calc_out)(t, yy.as_ptr(), data.as_mut_ptr()) }; } pub fn calc_out_grad( @@ -655,17 +355,20 @@ impl Compiler { data: &mut [f64], ddata: &mut [f64], ) { - let number_of_states = *self.borrow_number_of_states(); - if yy.len() != *self.borrow_number_of_states() { - panic!("Expected {} states, got {}", number_of_states, yy.len()); + if yy.len() != self.number_of_states { + panic!( + "Expected {} states, got {}", + self.number_of_states, + yy.len() + ); } if data.len() != self.data_len() { panic!("Expected {} data, got {}", self.data_len(), data.len()); } - if dyy.len() != *self.borrow_number_of_states() { + if dyy.len() != self.number_of_states { panic!( "Expected {} states for dyy, got {}", - number_of_states, + self.number_of_states, dyy.len() ); } @@ -676,18 +379,15 @@ impl Compiler { ddata.len() ); } - self.with_data(|compiler| { - let yy_ptr = yy.as_ptr(); - let data_ptr = data.as_mut_ptr(); - let dyy_ptr = dyy.as_ptr(); - let ddata_ptr = ddata.as_mut_ptr(); - unsafe { - compiler - .jit_grad_functions - .calc_out_grad - .call(t, yy_ptr, dyy_ptr, data_ptr, ddata_ptr); - } - }); + unsafe { + (self.jit_grad_functions.calc_out_grad)( + t, + yy.as_ptr(), + dyy.as_ptr(), + data.as_mut_ptr(), + ddata.as_mut_ptr(), + ) + }; } /// Get various dimensions of the model @@ -701,15 +401,15 @@ impl Compiler { let mut n_outputs = 0u32; let mut n_data = 0u32; let mut n_stop = 0u32; - self.with(|compiler| unsafe { - compiler.data.jit_functions.get_dims.call( + unsafe { + (self.jit_functions.get_dims)( &mut n_states, &mut n_inputs, &mut n_outputs, &mut n_data, &mut n_stop, - ); - }); + ) + }; ( n_states as usize, n_inputs as usize, @@ -727,15 +427,7 @@ impl Compiler { if data.len() != self.data_len() { panic!("Expected {} data, got {}", self.data_len(), data.len()); } - self.with_data(|compiler| { - let data_ptr = data.as_mut_ptr(); - unsafe { - compiler - .jit_functions - .set_inputs - .call(inputs.as_ptr(), data_ptr); - } - }); + unsafe { (self.jit_functions.set_inputs)(inputs.as_ptr(), data.as_mut_ptr()) }; } pub fn set_inputs_grad( @@ -766,19 +458,14 @@ impl Compiler { ddata.len() ); } - self.with_data(|compiler| { - let data_ptr = data.as_mut_ptr(); - let ddata_ptr = ddata.as_mut_ptr(); - let dinputs_ptr = dinputs.as_ptr(); - unsafe { - compiler.jit_grad_functions.set_inputs_grad.call( - inputs.as_ptr(), - dinputs_ptr, - data_ptr, - ddata_ptr, - ); - } - }); + unsafe { + 
(self.jit_grad_functions.set_inputs_grad)( + inputs.as_ptr(), + dinputs.as_ptr(), + data.as_mut_ptr(), + ddata.as_mut_ptr(), + ) + }; } pub fn get_out(&self, data: &[f64]) -> &[f64] { @@ -790,16 +477,9 @@ impl Compiler { let mut tensor_data_len = 0u32; let tensor_data_ptr_ptr: *mut *mut f64 = &mut tensor_data_ptr; let tensor_data_len_ptr: *mut u32 = &mut tensor_data_len; - self.with(|compiler| { - let data_ptr = data.as_ptr(); - unsafe { - compiler.data.jit_functions.get_out.call( - data_ptr, - tensor_data_ptr_ptr, - tensor_data_len_ptr, - ); - } - }); + unsafe { + (self.jit_functions.get_out)(data.as_ptr(), tensor_data_ptr_ptr, tensor_data_len_ptr) + }; assert!(tensor_data_len as usize == n_outputs); unsafe { std::slice::from_raw_parts(tensor_data_ptr, tensor_data_len as usize) } } @@ -809,126 +489,55 @@ impl Compiler { if n_states != id.len() { panic!("Expected {} states, got {}", n_states, id.len()); } - self.with_data(|compiler| { - unsafe { - compiler.jit_functions.set_id.call(id.as_mut_ptr()); - }; - }); - } - - fn get_native_machine() -> Result { - Target::initialize_native(&InitializationConfig::default()) - .map_err(|e| anyhow!("{}", e))?; - let opt = OptimizationLevel::Default; - let reloc = RelocMode::Default; - let model = CodeModel::Default; - let target_triple = TargetMachine::get_default_triple(); - let target = Target::from_triple(&target_triple).unwrap(); - let target_machine = target - .create_target_machine( - &target_triple, - TargetMachine::get_host_cpu_name().to_str().unwrap(), - TargetMachine::get_host_cpu_features().to_str().unwrap(), - opt, - reloc, - model, - ) - .unwrap(); - Ok(target_machine) - } - - fn get_wasm_machine() -> Result { - Target::initialize_webassembly(&InitializationConfig::default()); - let opt = OptimizationLevel::Default; - let reloc = RelocMode::Default; - let model = CodeModel::Default; - let target_triple = TargetTriple::create("wasm32-unknown-emscripten"); - let target = Target::from_triple(&target_triple).unwrap(); - let target_machine = target - .create_target_machine(&target_triple, "generic", "", opt, reloc, model) - .unwrap(); - Ok(target_machine) - } - - pub fn write_bitcode_to_path(&self, path: &Path) -> Result<()> { - self.with_data(|data| { - let result = data.codegen.module().write_bitcode_to_path(path); - if result { - Ok(()) - } else { - Err(anyhow!("Error writing bitcode to path")) - } - }) - } - - pub fn write_object_file(&self, path: &Path) -> Result<()> { - let target_machine = Compiler::get_native_machine()?; - self.with_data(|data| { - target_machine - .write_to_file(data.codegen.module(), FileType::Object, path) - .map_err(|e| anyhow::anyhow!("Error writing object file: {:?}", e)) - }) - } - - pub fn write_wasm_object_file(&self, path: &Path) -> Result<()> { - let target_machine = Compiler::get_wasm_machine()?; - self.with_data(|data| { - target_machine - .write_to_file(data.codegen.module(), FileType::Object, path) - .map_err(|e| anyhow::anyhow!("Error writing object file: {:?}", e)) - }) + unsafe { (self.jit_functions.set_id)(id.as_mut_ptr()) }; } pub fn number_of_states(&self) -> usize { - *self.borrow_number_of_states() + self.number_of_states } pub fn number_of_parameters(&self) -> usize { - *self.borrow_number_of_parameters() + self.number_of_parameters } pub fn number_of_outputs(&self) -> usize { - *self.borrow_number_of_outputs() + self.number_of_outputs } } #[cfg(test)] mod tests { - use crate::{ - continuous::ModelInfo, - parser::{parse_ds_string, parse_ms_string}, - }; + use crate::{parser::parse_ds_string, 
CraneliftModule}; use approx::assert_relative_eq; use super::*; + #[cfg(feature = "llvm")] #[test] - fn test_object_file() { + fn test_from_discrete_str_llvm() { + use crate::execution::llvm::codegen::LlvmModule; let text = " - model logistic_growth(r -> NonNegative, k -> NonNegative, y(t), z(t)) { - dot(y) = r * y * (1 - y / k) - y(0) = 1.0 - z = 2 * y - } + u { y = 1 } + F { -y } + out { y } "; - let models = parse_ms_string(text).unwrap(); - let model_info = ModelInfo::build("logistic_growth", &models).unwrap(); - assert_eq!(model_info.errors.len(), 0); - let discrete_model = DiscreteModel::from(&model_info); - let object = - Compiler::from_discrete_model(&discrete_model, "test_output/compiler_test_object_file") - .unwrap(); - let path = Path::new("main.o"); - object.write_object_file(path).unwrap(); + let compiler = Compiler::::from_discrete_str(text).unwrap(); + let mut u0 = vec![0.]; + let mut res = vec![0.]; + let mut data = compiler.get_new_data(); + compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); + assert_relative_eq!(u0.as_slice(), vec![1.].as_slice()); + compiler.rhs(0., u0.as_slice(), data.as_mut_slice(), res.as_mut_slice()); + assert_relative_eq!(res.as_slice(), vec![-1.].as_slice()); } #[test] - fn test_from_discrete_str() { + fn test_from_discrete_str_cranelift() { let text = " u { y = 1 } F { -y } out { y } "; - let compiler = Compiler::from_discrete_str(text).unwrap(); + let compiler = Compiler::::from_discrete_str(text).unwrap(); let mut u0 = vec![0.]; let mut res = vec![0.]; let mut data = compiler.get_new_data(); @@ -939,7 +548,17 @@ mod tests { } #[test] - fn test_stop() { + fn test_stop_cranelift() { + test_stop::(); + } + + #[cfg(feature = "llvm")] + #[test] + fn test_stop_llvm() { + test_stop::(); + } + + fn test_stop() { let full_text = " u_i { y = 1, @@ -962,9 +581,7 @@ mod tests { "; let model = parse_ds_string(full_text).unwrap(); let discrete_model = DiscreteModel::build("$name", &model).unwrap(); - let compiler = - Compiler::from_discrete_model(&discrete_model, "test_output/compiler_test_stop") - .unwrap(); + let compiler = Compiler::::from_discrete_model(&discrete_model).unwrap(); let mut u0 = vec![1.]; let mut res = vec![0.]; let mut stop = vec![0.]; @@ -976,7 +593,7 @@ mod tests { assert_eq!(stop.len(), 1); } - fn tensor_test_common(text: &str, tmp_loc: &str, tensor_name: &str) -> Vec> { + fn tensor_test_common(text: &str, tensor_name: &str) -> Vec> { let full_text = format!( " {} @@ -990,7 +607,7 @@ mod tests { panic!("{}", e.as_error_message(full_text.as_str())); } }; - let compiler = Compiler::from_discrete_model(&discrete_model, tmp_loc).unwrap(); + let compiler = Compiler::::from_discrete_model(&discrete_model).unwrap(); let mut u0 = vec![1.]; let mut res = vec![0.]; let mut data = compiler.get_new_data(); @@ -1072,8 +689,15 @@ mod tests { y, }} ", $text); - let tmp_loc = format!("test_output/compiler_tensor_test_{}", stringify!($name)); - let results = tensor_test_common(full_text.as_str(), tmp_loc.as_str(), $tensor_name); + + #[cfg(feature = "llvm")] + { + use crate::execution::llvm::codegen::LlvmModule; + let results = tensor_test_common::(full_text.as_str(), $tensor_name); + assert_relative_eq!(results[0].as_slice(), $expected_value.as_slice()); + } + + let results = tensor_test_common::(full_text.as_str(), $tensor_name); assert_relative_eq!(results[0].as_slice(), $expected_value.as_slice()); } )* @@ -1187,8 +811,15 @@ mod tests { y, }} ", $text); - let tmp_loc = format!("test_output/compiler_tensor_grad_test_{}", stringify!($name)); - let 
results = tensor_test_common(full_text.as_str(), tmp_loc.as_str(), $tensor_name); + + #[cfg(feature = "llvm")] + { + use crate::execution::llvm::codegen::LlvmModule; + let results = tensor_test_common::(full_text.as_str(), $tensor_name); + assert_relative_eq!(results[1].as_slice(), $expected_value.as_slice()); + } + + let results = tensor_test_common::(full_text.as_str(), $tensor_name); assert_relative_eq!(results[1].as_slice(), $expected_value.as_slice()); } )* @@ -1205,7 +836,17 @@ mod tests { } #[test] - fn test_repeated_grad() { + fn test_repeated_grad_cranelift() { + test_repeated_grad_common::(); + } + + #[cfg(feature = "llvm")] + #[test] + fn test_repeated_grad_llvm() { + test_repeated_grad_common::(); + } + + fn test_repeated_grad_common() { let full_text = " in = [p] p { @@ -1237,11 +878,7 @@ mod tests { panic!("{}", e.as_error_message(full_text)); } }; - let compiler = Compiler::from_discrete_model( - &discrete_model, - "test_output/compiler_test_repeated_grad", - ) - .unwrap(); + let compiler = Compiler::::from_discrete_model(&discrete_model).unwrap(); let mut u0 = vec![1.]; let mut du0 = vec![1.]; let mut res = vec![0.]; @@ -1250,7 +887,7 @@ mod tests { let mut ddata = compiler.get_new_data(); let (_n_states, n_inputs, _n_outputs, _n_data, _n_stop) = compiler.get_dims(); - for _ in 0..3 { + for _i in 0..3 { let inputs = vec![2.; n_inputs]; let dinputs = vec![1.; n_inputs]; compiler.set_inputs_grad( @@ -1309,11 +946,7 @@ mod tests { "; let model = parse_ds_string(full_text).unwrap(); let discrete_model = DiscreteModel::build("$name", &model).unwrap(); - let compiler = Compiler::from_discrete_model( - &discrete_model, - "test_output/compiler_test_additional_functions", - ) - .unwrap(); + let compiler = Compiler::::from_discrete_model(&discrete_model).unwrap(); let (n_states, n_inputs, n_outputs, n_data, _n_stop) = compiler.get_dims(); assert_eq!(n_states, 2); assert_eq!(n_inputs, 1); diff --git a/src/execution/cranelift/codegen.rs b/src/execution/cranelift/codegen.rs new file mode 100644 index 0000000..62e1807 --- /dev/null +++ b/src/execution/cranelift/codegen.rs @@ -0,0 +1,1680 @@ +use anyhow::{anyhow, Ok, Result}; +use codegen::ir::{FuncRef, GlobalValue, StackSlot}; +use cranelift::prelude::*; +use cranelift_jit::{JITBuilder, JITModule}; +use cranelift_module::{DataDescription, DataId, FuncId, Linkage, Module}; +use std::collections::HashMap; +use std::iter::zip; +use target_lexicon::{Endianness, PointerWidth, Triple}; + +use crate::ast::{Ast, AstKind}; +use crate::discretise::{DiscreteModel, Tensor, TensorBlock}; +use crate::execution::module::CodegenModule; +use crate::execution::{DataLayout, Translation, TranslationFrom, TranslationTo}; + +pub struct CraneliftModule { + /// The function builder context, which is reused across multiple + /// FunctionBuilder instances. + builder_context: FunctionBuilderContext, + + /// The main Cranelift context, which holds the state for codegen. Cranelift + /// separates this from `Module` to allow for parallel compilation, with a + /// context per thread, though this isn't in the simple demo here. + ctx: codegen::Context, + + /// The data description, which is to data objects what `ctx` is to functions. + //data_description: DataDescription, + + /// The module, with the jit backend, which manages the JIT'd + /// functions. 
+ module: JITModule, + + layout: DataLayout, + + indices_id: DataId, + + //triple: Triple, + int_type: types::Type, + real_type: types::Type, + real_ptr_type: types::Type, + int_ptr_type: types::Type, +} + +impl CraneliftModule { + fn declare_function(&mut self, name: &str) -> Result<FuncId> { + // Next, declare the function to jit. Functions must be declared + // before they can be called, or defined. + // + // TODO: This may be an area where the API should be streamlined; should + // we have a version of `declare_function` that automatically defines + // the function? + let id = self + .module + .declare_function(name, Linkage::Export, &self.ctx.func.signature)?; + + //println!("Declared function: {}", name); + //println!("IR:\n{}", self.ctx.func); + + // Define the function to jit. This finishes compilation, although + // there may be outstanding relocations to perform. The jit cannot + // finish relocations until all functions to be called are defined, + // so all definitions are finalized later, in + // `post_autodiff_optimisation`. + self.module.define_function(id, &mut self.ctx)?; + + // Now that compilation is finished, we can clear out the context state. + self.module.clear_context(&mut self.ctx); + + Ok(id) + } +} + +impl CodegenModule for CraneliftModule { + type FuncId = FuncId; + + fn compile_calc_out_grad( + &mut self, + _func_id: &Self::FuncId, + model: &DiscreteModel, + ) -> Result<Self::FuncId> { + let arg_types = &[ + self.real_type, + self.real_ptr_type, + self.real_ptr_type, + self.real_ptr_type, + self.real_ptr_type, + ]; + let arg_names = &["t", "u", "du", "data", "ddata"]; + let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); + + codegen.jit_compile_tensor(model.out(), None, false)?; + + codegen.jit_compile_tensor(model.out(), None, true)?; + codegen.builder.ins().return_(&[]); + codegen.builder.finalize(); + + self.declare_function("calc_out_grad") + } + + fn compile_rhs_grad( + &mut self, + _func_id: &Self::FuncId, + model: &DiscreteModel, + ) -> Result<Self::FuncId> { + let arg_types = &[ + self.real_type, + self.real_ptr_type, + self.real_ptr_type, + self.real_ptr_type, + self.real_ptr_type, + self.real_ptr_type, + self.real_ptr_type, + ]; + let arg_names = &["t", "u", "du", "data", "ddata", "rr", "drr"]; + let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); + + // calculate time dependent definitions + for tensor in model.time_dep_defns() { + codegen.jit_compile_tensor(tensor, None, false)?; + codegen.jit_compile_tensor(tensor, None, true)?; + } + + // TODO: could split state dep defns into before and after F + for a in model.state_dep_defns() { + codegen.jit_compile_tensor(a, None, false)?; + codegen.jit_compile_tensor(a, None, true)?; + } + + // F + let res = *codegen.variables.get("rr").unwrap(); + codegen.jit_compile_tensor(model.rhs(), Some(res), false)?; + let res = *codegen.variables.get("drr").unwrap(); + codegen.jit_compile_tensor(model.rhs(), Some(res), true)?; + + codegen.builder.ins().return_(&[]); + codegen.builder.finalize(); + self.declare_function("rhs_grad") + } + + fn compile_set_inputs_grad( + &mut self, + _func_id: &Self::FuncId, + model: &DiscreteModel, + ) -> Result<Self::FuncId> { + let arg_types = &[ + self.real_ptr_type, + self.real_ptr_type, + self.real_ptr_type, + self.real_ptr_type, + ]; + let arg_names = &["inputs", "dinputs", "data", "ddata"]; + let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); + + let base_data_ptr = codegen.variables.get("data").unwrap(); + let base_data_ptr =
codegen.builder.use_var(*base_data_ptr); + codegen.jit_compile_set_inputs(model, base_data_ptr, false); + + let base_data_ptr = codegen.variables.get("ddata").unwrap(); + let base_data_ptr = codegen.builder.use_var(*base_data_ptr); + codegen.jit_compile_set_inputs(model, base_data_ptr, true); + + codegen.builder.ins().return_(&[]); + codegen.builder.finalize(); + self.declare_function("set_inputs_grad") + } + + fn compile_set_u0_grad( + &mut self, + _func_id: &Self::FuncId, + model: &DiscreteModel, + ) -> Result<Self::FuncId> { + let arg_types = &[ + self.real_ptr_type, + self.real_ptr_type, + self.real_ptr_type, + self.real_ptr_type, + ]; + let arg_names = &["u0", "du0", "data", "ddata"]; + let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); + + for a in model.time_indep_defns() { + codegen.jit_compile_tensor(a, None, false)?; + codegen.jit_compile_tensor(a, None, true)?; + } + + codegen.jit_compile_tensor( + model.state(), + Some(*codegen.variables.get("u0").unwrap()), + false, + )?; + codegen.jit_compile_tensor( + model.state(), + Some(*codegen.variables.get("du0").unwrap()), + true, + )?; + + // Emit the return instruction. + codegen.builder.ins().return_(&[]); + codegen.builder.finalize(); + + self.declare_function("u0_grad") + } + + fn jit(&mut self, id: Self::FuncId) -> Result<*const u8> { + // We can now retrieve a pointer to the machine code. + let code = self.module.get_finalized_function(id); + Ok(code) + } + + fn layout(&self) -> &DataLayout { + &self.layout + } + + fn post_autodiff_optimisation(&mut self) -> Result<()> { + // Finalize the functions which we just defined, which resolves any + // outstanding relocations (patching in addresses, now that they're + // available). + self.module.finalize_definitions()?; + Ok(()) + } + + fn pre_autodiff_optimisation(&mut self) -> Result<()> { + Ok(()) + } + + fn new(triple: Triple, model: &DiscreteModel) -> Result<Self> { + let mut flag_builder = settings::builder(); + flag_builder.set("use_colocated_libcalls", "false").unwrap(); + flag_builder.set("is_pic", "false").unwrap(); + flag_builder.set("opt_level", "speed").unwrap(); + let isa_builder = cranelift_native::builder().unwrap_or_else(|msg| { + panic!("host machine is not supported: {}", msg); + }); + let isa = isa_builder + .finish(settings::Flags::new(flag_builder)) + .unwrap(); + let mut builder = JITBuilder::with_isa(isa, cranelift_module::default_libcall_names()); + + // add supported external rust functions + for func in crate::execution::functions::FUNCTIONS.iter() { + builder.symbol(func.0, func.1 as *const u8); + builder.symbol( + CraneliftCodeGen::get_function_name(func.0, true), + func.2 as *const u8, + ); + } + for func in crate::execution::functions::TWO_ARG_FUNCTIONS.iter() { + builder.symbol(func.0, func.1 as *const u8); + builder.symbol( + CraneliftCodeGen::get_function_name(func.0, true), + func.2 as *const u8, + ); + } + + let mut module = JITModule::new(builder); + + let ptr_type = match triple.pointer_width().unwrap() { + PointerWidth::U16 => types::I16, + PointerWidth::U32 => types::I32, + PointerWidth::U64 => types::I64, + }; + + let layout = DataLayout::new(model); + + // write indices data as a global data object + // convert the indices to bytes + let int_type = types::I32; + let real_type = types::F64; + let mut vec8: Vec<u8> = vec![]; + for elem in layout.indices() { + // convert indices to i64 + if int_type == types::I64 { + let elemi64 = i64::from(*elem); + let conv = match triple.endianness().unwrap() { + Endianness::Little => elemi64.to_le_bytes(),
Endianness::Big => elemi64.to_be_bytes(), + }; + vec8.extend(conv.into_iter()); + } else { + let conv = match triple.endianness().unwrap() { + Endianness::Little => elem.to_le_bytes(), + Endianness::Big => elem.to_be_bytes(), + }; + vec8.extend(conv.into_iter()); + }; + } + + // put the indices data into a DataDescription + let mut data_description = DataDescription::new(); + data_description.define(vec8.into_boxed_slice()); + let indices_id = module.declare_data("indices", Linkage::Local, false, false)?; + module.define_data(indices_id, &data_description)?; + + Ok(Self { + builder_context: FunctionBuilderContext::new(), + ctx: module.make_context(), + module, + indices_id, + int_type, + real_type, + real_ptr_type: ptr_type, + int_ptr_type: ptr_type, + layout, + }) + } + + fn compile_set_u0(&mut self, model: &DiscreteModel) -> Result<Self::FuncId> { + let arg_types = &[self.real_ptr_type, self.real_ptr_type]; + let arg_names = &["u0", "data"]; + let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); + + for a in model.time_indep_defns() { + codegen.jit_compile_tensor(a, None, false)?; + } + + codegen.jit_compile_tensor( + model.state(), + Some(*codegen.variables.get("u0").unwrap()), + false, + )?; + + // Emit the return instruction. + codegen.builder.ins().return_(&[]); + codegen.builder.finalize(); + + self.declare_function("u0") + } + + fn compile_calc_out(&mut self, model: &DiscreteModel) -> Result<Self::FuncId> { + let arg_types = &[self.real_type, self.real_ptr_type, self.real_ptr_type]; + let arg_names = &["t", "u", "data"]; + let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); + + codegen.jit_compile_tensor(model.out(), None, false)?; + codegen.builder.ins().return_(&[]); + codegen.builder.finalize(); + + self.declare_function("calc_out") + } + + fn compile_calc_stop(&mut self, model: &DiscreteModel) -> Result<Self::FuncId> { + let arg_types = &[ + self.real_type, + self.real_ptr_type, + self.real_ptr_type, + self.real_ptr_type, + ]; + let arg_names = &["t", "u", "data", "root"]; + let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); + + if let Some(stop) = model.stop() { + let root = *codegen.variables.get("root").unwrap(); + codegen.jit_compile_tensor(stop, Some(root), false)?; + } + codegen.builder.ins().return_(&[]); + codegen.builder.finalize(); + self.declare_function("calc_stop") + } + + fn compile_rhs(&mut self, model: &DiscreteModel) -> Result<Self::FuncId> { + let arg_types = &[ + self.real_type, + self.real_ptr_type, + self.real_ptr_type, + self.real_ptr_type, + ]; + let arg_names = &["t", "u", "data", "rr"]; + let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); + + // calculate time dependent definitions + for tensor in model.time_dep_defns() { + codegen.jit_compile_tensor(tensor, None, false)?; + } + + // TODO: could split state dep defns into before and after F + for a in model.state_dep_defns() { + codegen.jit_compile_tensor(a, None, false)?; + } + + // F + let res = *codegen.variables.get("rr").unwrap(); + codegen.jit_compile_tensor(model.rhs(), Some(res), false)?; + + codegen.builder.ins().return_(&[]); + codegen.builder.finalize(); + self.declare_function("rhs") + } + + fn compile_mass(&mut self, model: &DiscreteModel) -> Result<Self::FuncId> { + let arg_types = &[ + self.real_type, + self.real_ptr_type, + self.real_ptr_type, + self.real_ptr_type, + ]; + let arg_names = &["t", "dudt", "data", "rr"]; + let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); + + // only put code in this function if we have a state_dot and lhs
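+ // (note, assuming the usual DiffSL mass-matrix convention: for models written as + // M * du/dt = F(u, t), this function receives du/dt in `dudt` and writes rr = M * dudt; + // explicit models define no state_dot/lhs, so the body below is empty and rr is untouched)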
+ if model.state_dot().is_some() && model.lhs().is_some() { + // calculate time dependent definitions + for tensor in model.time_dep_defns() { + codegen.jit_compile_tensor(tensor, None, false)?; + } + + for a in model.dstate_dep_defns() { + codegen.jit_compile_tensor(a, None, false)?; + } + + // mass + let lhs = model.lhs().unwrap(); + let res = codegen.variables.get("rr").unwrap(); + codegen.jit_compile_tensor(lhs, Some(*res), false)?; + } + + codegen.builder.ins().return_(&[]); + codegen.builder.finalize(); + self.declare_function("mass") + } + + fn compile_get_dims(&mut self, model: &DiscreteModel) -> Result<Self::FuncId> { + let arg_types = &[ + self.int_ptr_type, + self.int_ptr_type, + self.int_ptr_type, + self.int_ptr_type, + self.int_ptr_type, + ]; + let arg_names = &["states", "inputs", "outputs", "data", "stop"]; + let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); + + let number_of_states = i64::try_from(model.state().nnz()).unwrap(); + let number_of_inputs = + i64::try_from(model.inputs().iter().fold(0, |acc, x| acc + x.nnz())).unwrap(); + let number_of_outputs = i64::try_from(model.out().nnz()).unwrap(); + let number_of_stop = if let Some(stop) = model.stop() { + i64::try_from(stop.nnz()).unwrap() + } else { + 0 + }; + let data_len = i64::try_from(codegen.layout.data().len()).unwrap(); + + for (val, name) in [ + (number_of_states, "states"), + (number_of_inputs, "inputs"), + (number_of_outputs, "outputs"), + (data_len, "data"), + (number_of_stop, "stop"), + ] { + let val = codegen.builder.ins().iconst(codegen.int_type, val); + let ptr = codegen.variables.get(name).unwrap(); + let ptr = codegen.builder.use_var(*ptr); + codegen.builder.ins().store(codegen.mem_flags, val, ptr, 0); + } + + codegen.builder.ins().return_(&[]); + codegen.builder.finalize(); + self.declare_function("gen_dims") + } + + fn compile_get_tensor(&mut self, model: &DiscreteModel, name: &str) -> Result<Self::FuncId> { + let arg_types = &[self.real_ptr_type, self.real_ptr_type, self.int_ptr_type]; + let arg_names = &["data", "tensor_data", "tensor_size"]; + let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); + + let tensor_ptr = codegen.variables.get(name).unwrap(); + let tensor_ptr = codegen.builder.use_var(*tensor_ptr); + + let tensor_size = i64::try_from(codegen.layout.get_layout(name).unwrap().nnz()).unwrap(); + let tensor_size = codegen.builder.ins().iconst(codegen.int_type, tensor_size); + + for (val, name) in [(tensor_ptr, "tensor_data"), (tensor_size, "tensor_size")] { + let ptr = codegen.variables.get(name).unwrap(); + let ptr = codegen.builder.use_var(*ptr); + codegen.builder.ins().store(codegen.mem_flags, val, ptr, 0); + } + + codegen.builder.ins().return_(&[]); + codegen.builder.finalize(); + self.declare_function("get_tensor") + } + + fn compile_set_inputs(&mut self, model: &DiscreteModel) -> Result<Self::FuncId> { + let arg_types = &[self.real_ptr_type, self.real_ptr_type]; + let arg_names = &["inputs", "data"]; + let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); + + let base_data_ptr = codegen.variables.get("data").unwrap(); + let base_data_ptr = codegen.builder.use_var(*base_data_ptr); + codegen.jit_compile_set_inputs(model, base_data_ptr, false); + + codegen.builder.ins().return_(&[]); + codegen.builder.finalize(); + self.declare_function("set_inputs") + } + + fn compile_set_id(&mut self, model: &DiscreteModel) -> Result<Self::FuncId> { + let arg_types = &[self.real_ptr_type]; + let arg_names = &["id"]; + let mut codegen = CraneliftCodeGen::new(self, model, arg_names,
arg_types); + + let mut id_index = 0usize; + for (blk, is_algebraic) in zip(model.state().elmts(), model.is_algebraic()) { + // loop thru the elements of this state blk and set the corresponding elements of id + let id_start_index = codegen + .builder + .ins() + .iconst(codegen.int_type, i64::try_from(id_index).unwrap()); + let blk_start_index = codegen.builder.ins().iconst(codegen.int_type, 0); + + let blk_block = codegen.builder.create_block(); + let curr_blk_index = codegen + .builder + .append_block_param(blk_block, codegen.int_type); + codegen.builder.ins().jump(blk_block, &[blk_start_index]); + + codegen.builder.switch_to_block(blk_block); + + // loop body - store the algebraic flag for this state element into id + let input_id_ptr = codegen.variables.get("id").unwrap(); + let input_id_ptr = codegen.builder.use_var(*input_id_ptr); + let curr_id_index = codegen.builder.ins().iadd(id_start_index, curr_blk_index); + let indexed_id_ptr = + codegen.ptr_add_offset(codegen.real_type, input_id_ptr, curr_id_index); + + let is_algebraic_float = if *is_algebraic { 0.0 } else { 1.0 }; + let is_algebraic_value = codegen.fconst(is_algebraic_float); + codegen + .builder + .ins() + .store(codegen.mem_flags, is_algebraic_value, indexed_id_ptr, 0); + + // increment loop index + let one = codegen.builder.ins().iconst(codegen.int_type, 1); + let next_index = codegen.builder.ins().iadd(curr_blk_index, one); + + let loop_while = codegen.builder.ins().icmp_imm( + IntCC::UnsignedLessThan, + next_index, + i64::try_from(blk.nnz()).unwrap(), + ); + let post_block = codegen.builder.create_block(); + codegen + .builder + .ins() + .brif(loop_while, blk_block, &[next_index], post_block, &[]); + codegen.builder.seal_block(blk_block); + codegen.builder.seal_block(post_block); + codegen.builder.switch_to_block(post_block); + + // get ready for next blk + id_index += blk.nnz(); + } + + codegen.builder.ins().return_(&[]); + codegen.builder.finalize(); + self.declare_function("set_id") + } +} + +/// A collection of state used for translating from DiffSL AST nodes +/// into Cranelift IR.
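+/// +/// Each compiled function gets its own `CraneliftCodeGen`: `variables` maps +/// argument and tensor names to Cranelift SSA variables, `functions` caches the +/// external math functions imported so far (see `get_function`), `tensor_ptr` +/// points at the data of the tensor currently being written, and `indices` +/// refers to the global index array defined in `CraneliftModule::new`.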
+struct CraneliftCodeGen<'a> { + int_type: types::Type, + real_type: types::Type, + real_ptr_type: types::Type, + int_ptr_type: types::Type, + builder: FunctionBuilder<'a>, + module: &'a mut JITModule, + tensor_ptr: Option<Value>, + variables: HashMap<String, Variable>, + mem_flags: MemFlags, + functions: HashMap<String, FuncRef>, + layout: &'a DataLayout, + indices: GlobalValue, +} + +impl<'ctx> CraneliftCodeGen<'ctx> { + fn fconst(&mut self, value: f64) -> Value { + match self.real_type { + types::F32 => self.builder.ins().f32const(value as f32), + types::F64 => self.builder.ins().f64const(value), + _ => panic!("unexpected real type"), + } + } + fn ptr_add_offset_i64(&mut self, elmt_ty: types::Type, ptr: Value, offset: i64) -> Value { + // both ptr types are the same, so just use real_ptr_type + let ptr_ty = self.real_ptr_type; + let width = elmt_ty.bytes() as i64; + let offset_bytes = self.builder.ins().iconst(ptr_ty, offset * width); + self.builder.ins().iadd(ptr, offset_bytes) + } + + fn ptr_add_offset(&mut self, elmt_ty: types::Type, ptr: Value, offset: Value) -> Value { + let width = elmt_ty.bytes() as i64; + // both ptr types are the same, so just use real_ptr_type + let ptr_ty = self.real_ptr_type; + + let width_value = self.builder.ins().iconst(ptr_ty, width); + let offset_ptr = if self.int_type != ptr_ty { + self.builder.ins().sextend(ptr_ty, offset) + } else { + offset + }; + let offset_bytes = self.builder.ins().imul(offset_ptr, width_value); + self.builder.ins().iadd(ptr, offset_bytes) + } + fn jit_compile_expr( + &mut self, + name: &str, + expr: &Ast, + index: &[Value], + elmt: &TensorBlock, + expr_index: Option<Value>, + ) -> Result<Value> { + let name = elmt.name().unwrap_or(name); + match &expr.kind { + AstKind::Binop(binop) => { + let lhs = + self.jit_compile_expr(name, binop.left.as_ref(), index, elmt, expr_index)?; + let rhs = + self.jit_compile_expr(name, binop.right.as_ref(), index, elmt, expr_index)?; + match binop.op { + '*' => Ok(self.builder.ins().fmul(lhs, rhs)), + '/' => Ok(self.builder.ins().fdiv(lhs, rhs)), + '-' => Ok(self.builder.ins().fsub(lhs, rhs)), + '+' => Ok(self.builder.ins().fadd(lhs, rhs)), + unknown => Err(anyhow!("unknown binop op '{}'", unknown)), + } + } + AstKind::Monop(monop) => { + let child = + self.jit_compile_expr(name, monop.child.as_ref(), index, elmt, expr_index)?; + match monop.op { + '-' => Ok(self.builder.ins().fneg(child)), + unknown => Err(anyhow!("unknown monop op '{}'", unknown)), + } + } + AstKind::Call(call) => match self.get_function(call.fn_name, call.is_tangent) { + Some(function) => { + let mut args = Vec::new(); + for arg in call.args.iter() { + let arg_val = + self.jit_compile_expr(name, arg.as_ref(), index, elmt, expr_index)?; + args.push(arg_val); + } + let call = self.builder.ins().call(function, &args); + let ret_value = self.builder.inst_results(call)[0]; + Ok(ret_value) + } + None => Err(anyhow!("unknown function call '{}'", call.fn_name)), + }, + AstKind::CallArg(arg) => { + self.jit_compile_expr(name, &arg.expression, index, elmt, expr_index) + } + AstKind::Number(value) => Ok(self.fconst(*value)), + AstKind::Name(iname) => { + let ptr = if iname.is_tangent { + let name = self.get_tangent_tensor_name(iname.name); + self.builder + .use_var(*self.variables.get(name.as_str()).unwrap()) + } else { + self.builder + .use_var(*self.variables.get(iname.name).unwrap()) + }; + // arg t is a special case (not a ptr) + if iname.name == "t" { + return Ok(ptr); + } + let layout = self.layout.get_layout(iname.name).unwrap(); + let iname_elmt_index = if layout.is_dense() { + //
permute indices based on the index chars of this tensor + let mut no_transform = true; + let mut iname_index = Vec::new(); + for (i, c) in iname.indices.iter().enumerate() { + // find the position index of this index char in the tensor's index chars, + // if it's not found then it must be a contraction index so is at the end + let pi = elmt + .indices() + .iter() + .position(|x| x == c) + .unwrap_or(elmt.indices().len()); + iname_index.push(index[pi]); + no_transform = no_transform && pi == i; + } + // calculate the element index using iname_index and the shape of the tensor + // TODO: can we optimise this by using expr_index, and also including elmt_index? + if !iname_index.is_empty() { + let mut iname_elmt_index = *iname_index.last().unwrap(); + let mut stride = 1u64; + for i in (0..iname_index.len() - 1).rev() { + let iname_i = iname_index[i]; + let shapei: u64 = layout.shape()[i + 1].try_into().unwrap(); + stride *= shapei; + let stride_intval = self + .builder + .ins() + .iconst(self.int_type, i64::try_from(stride).unwrap()); + let stride_mul_i = self.builder.ins().imul(stride_intval, iname_i); + iname_elmt_index = + self.builder.ins().iadd(iname_elmt_index, stride_mul_i); + } + Some(iname_elmt_index) + } else { + None + } + } else if layout.is_sparse() || layout.is_diagonal() { + // must have come from jit_compile_sparse_block or jit_compile_diagonal_block, + // so we can just use the elmt_index + expr_index + } else { + panic!("unexpected layout"); + }; + let value_ptr = match iname_elmt_index { + Some(offset) => self.ptr_add_offset(self.real_type, ptr, offset), + None => ptr, + }; + Ok(self + .builder + .ins() + .load(self.real_type, self.mem_flags, value_ptr, 0)) + } + AstKind::NamedGradient(name) => { + let name_str = name.to_string(); + let ptr = self + .builder + .use_var(*self.variables.get(name_str.as_str()).unwrap()); + Ok(self + .builder + .ins() + .load(self.real_type, self.mem_flags, ptr, 0)) + } + AstKind::Index(_) => todo!(), + AstKind::Slice(_) => todo!(), + AstKind::Integer(_) => todo!(), + _ => panic!("unexpected AstKind"), + } + } + + fn get_function_name(name: &str, is_tangent: bool) -> String { + if is_tangent { + format!("{}__tangent__", name) + } else { + name.to_owned() + } + } + + fn get_function(&mut self, base_name: &str, is_tangent: bool) -> Option<FuncRef> { + let name = Self::get_function_name(base_name, is_tangent); + match self.functions.get(name.as_str()) { + Some(&func) => Some(func), + None => match crate::execution::functions::function_num_args(base_name, is_tangent) { + Some(num_args) => { + let mut sig = self.module.make_signature(); + for _ in 0..num_args { + sig.params.push(AbiParam::new(self.real_type)); + } + sig.returns.push(AbiParam::new(self.real_type)); + let callee = self + .module + .declare_function(name.as_str(), Linkage::Import, &sig) + .expect("problem declaring function"); + let function = self.module.declare_func_in_func(callee, self.builder.func); + self.functions.insert(name, function); + Some(function) + } + None => None, + }, + } + } + + fn jit_compile_tensor( + &mut self, + a: &Tensor, + var: Option<Variable>, + is_tangent: bool, + ) -> Result<Value> { + // set up the tensor storage pointer and index into this data + if let Some(var) = var { + self.tensor_ptr = Some(self.builder.use_var(var)); + } else { + let name = if is_tangent { + self.get_tangent_tensor_name(a.name()) + } else { + a.name().to_owned() + }; + let res_ptr_var = *self + .variables + .get(name.as_str()) + .unwrap_or_else(||
panic!("tensor {} not defined", a.name())); + let res_ptr = self.builder.use_var(res_ptr_var); + self.tensor_ptr = Some(res_ptr); + } + + // treat scalar as a special case + if a.rank() == 0 { + let elmt = a.elmts().first().unwrap(); + let expr = if is_tangent { + elmt.tangent_expr() + } else { + elmt.expr() + }; + let float_value = self.jit_compile_expr(a.name(), expr, &[], elmt, None)?; + self.builder + .ins() + .store(self.mem_flags, float_value, self.tensor_ptr.unwrap(), 0); + return Ok(self.tensor_ptr.unwrap()); + } + + for (i, blk) in a.elmts().iter().enumerate() { + let default = format!("{}-{}", a.name(), i); + let name = blk.name().unwrap_or(default.as_str()); + self.jit_compile_block(name, a, blk, is_tangent)?; + } + Ok(self.tensor_ptr.unwrap()) + } + + fn jit_compile_block( + &mut self, + name: &str, + tensor: &Tensor, + elmt: &TensorBlock, + is_tangent: bool, + ) -> Result<()> { + let translation = Translation::new( + elmt.expr_layout(), + elmt.layout(), + elmt.start(), + tensor.layout_ptr(), + ); + + if elmt.expr_layout().is_dense() { + self.jit_compile_dense_block(name, elmt, &translation, is_tangent) + } else if elmt.expr_layout().is_diagonal() { + self.jit_compile_diagonal_block(name, elmt, &translation, is_tangent) + } else if elmt.expr_layout().is_sparse() { + match translation.source { + TranslationFrom::SparseContraction { .. } => { + self.jit_compile_sparse_contraction_block(name, elmt, &translation, is_tangent) + } + _ => self.jit_compile_sparse_block(name, elmt, &translation, is_tangent), + } + } else { + return Err(anyhow!( + "unsupported block layout: {:?}", + elmt.expr_layout() + )); + } + } + + fn decl_stack_slot(&mut self, ty: Type, val: Option) -> StackSlot { + let data = StackSlotData::new(StackSlotKind::ExplicitSlot, ty.bytes(), 0); + let ss = self.builder.create_sized_stack_slot(data); + if let Some(val) = val { + self.builder.ins().stack_store(val, ss, 0); + } + ss + } + + // for dense blocks we can loop through the nested loops to calculate the index, then we compile the expression passing in this index + fn jit_compile_dense_block( + &mut self, + name: &str, + elmt: &TensorBlock, + translation: &Translation, + is_tangent: bool, + ) -> Result<()> { + let int_type = self.int_type; + + let mut preblock = self.builder.current_block().unwrap(); + let expr_rank = elmt.expr_layout().rank(); + let expr_shape = elmt + .expr_layout() + .shape() + .mapv(|n| i64::try_from(n).unwrap()); + let one = self.builder.ins().iconst(int_type, 1); + let zero = self.builder.ins().iconst(int_type, 0); + + let expr_index_var = self.decl_stack_slot(self.int_type, Some(zero)); + let elmt_index_var = self.decl_stack_slot(self.int_type, Some(zero)); + + // setup indices, loop through the nested loops + let mut indices = Vec::new(); + let mut blocks = Vec::new(); + + // allocate the contract sum if needed + let (contract_sum, contract_by) = if let TranslationFrom::DenseContraction { + contract_by, + contract_len: _, + } = translation.source + { + ( + Some(self.decl_stack_slot(self.real_type, None)), + contract_by, + ) + } else { + (None, 0) + }; + + for i in 0..expr_rank { + let block = self.builder.create_block(); + let curr_index = self.builder.append_block_param(block, self.int_type); + self.builder.ins().jump(block, &[zero]); + self.builder.switch_to_block(block); + + if i == expr_rank - contract_by - 1 && contract_sum.is_some() { + let fzero = self.fconst(0.0); + self.builder + .ins() + .stack_store(fzero, contract_sum.unwrap(), 0); + } + + indices.push(curr_index); + 
blocks.push(block); + preblock = block; + } + + let elmt_index = self + .builder + .ins() + .stack_load(self.int_type, elmt_index_var, 0); + + // load and increment the expression index + let expr_index = self + .builder + .ins() + .stack_load(self.int_type, expr_index_var, 0); + let next_expr_index = self.builder.ins().iadd(expr_index, one); + self.builder + .ins() + .stack_store(next_expr_index, expr_index_var, 0); + + let expr = if is_tangent { + elmt.tangent_expr() + } else { + elmt.expr() + }; + let float_value = + self.jit_compile_expr(name, expr, indices.as_slice(), elmt, Some(expr_index))?; + + if contract_sum.is_some() { + let contract_sum_value = + self.builder + .ins() + .stack_load(self.real_type, contract_sum.unwrap(), 0); + let new_contract_sum_value = self.builder.ins().fadd(contract_sum_value, float_value); + self.builder + .ins() + .stack_store(new_contract_sum_value, contract_sum.unwrap(), 0); + } else { + self.jit_compile_broadcast_and_store( + name, + elmt, + float_value, + expr_index, + translation, + preblock, + )?; + let next_elmt_index = self.builder.ins().iadd(elmt_index, one); + self.builder + .ins() + .stack_store(next_elmt_index, elmt_index_var, 0); + } + + // unwind the nested loops + for i in (0..expr_rank).rev() { + // update and store contract sum + if i == expr_rank - contract_by - 1 && contract_sum.is_some() { + let next_elmt_index = self.builder.ins().iadd(elmt_index, one); + self.builder + .ins() + .stack_store(next_elmt_index, elmt_index_var, 0); + + let contract_sum_value = + self.builder + .ins() + .stack_load(self.real_type, contract_sum.unwrap(), 0); + + self.jit_compile_store(name, elmt, elmt_index, contract_sum_value, translation)?; + } + + // increment index + let next_index = self.builder.ins().iadd(indices[i], one); + let block = self.builder.create_block(); + let loop_cond = + self.builder + .ins() + .icmp_imm(IntCC::UnsignedLessThan, next_index, expr_shape[i]); + self.builder + .ins() + .brif(loop_cond, blocks[i], &[next_index], block, &[]); + self.builder.seal_block(blocks[i]); + self.builder.seal_block(block); + self.builder.switch_to_block(block); + } + Ok(()) + } + + fn jit_compile_sparse_contraction_block( + &mut self, + name: &str, + elmt: &TensorBlock, + translation: &Translation, + is_tangent: bool, + ) -> Result<()> { + match translation.source { + TranslationFrom::SparseContraction { .. 
} => {} + _ => { + panic!("expected sparse contraction") + } + } + let int_type = self.int_type; + let zero = self.builder.ins().iconst(int_type, 0); + let one = self.builder.ins().iconst(int_type, 1); + let two = self.builder.ins().iconst(int_type, 2); + + let layout_index = self.layout.get_layout_index(elmt.expr_layout()).unwrap(); + let translation_index = self + .layout + .get_translation_index(elmt.expr_layout(), elmt.layout()) + .unwrap(); + let translation_index = translation_index + translation.get_from_index_in_data_layout(); + + // initialise the contract sum + let contract_sum_var = self.decl_stack_slot(self.real_type, None); + + // loop through each contraction + let block = self.builder.create_block(); + let contract_index = self.builder.append_block_param(block, self.int_type); + let initial_contract_index = zero; + let final_contract_index = self + .builder + .ins() + .iconst(int_type, i64::try_from(elmt.layout().nnz()).unwrap()); + self.builder.ins().jump(block, &[initial_contract_index]); + self.builder.switch_to_block(block); + + // start and end indices stored next to each other in the indices array + // start_index = translation_index + 2 * contract_index + let translation_index_val = self + .builder + .ins() + .iconst(int_type, i64::try_from(translation_index).unwrap()); + let double_contract_index = self.builder.ins().imul(two, contract_index); + let start_index = self + .builder + .ins() + .iadd(translation_index_val, double_contract_index); + // end_index = start_index + 1 + let end_index = self.builder.ins().iadd(start_index, one); + + // index into the indices array to get the start and end indices + // start_contract = indices[translation_index + 2 * contract_index] + // end_contract = indices[translation_index + 2 * contract_index + 1] + let indices_array = self + .builder + .ins() + .global_value(self.int_ptr_type, self.indices); + let ptr = self.ptr_add_offset(self.int_type, indices_array, start_index); + let start_contract = self + .builder + .ins() + .load(self.int_type, self.mem_flags, ptr, 0); + let ptr = self.ptr_add_offset(self.int_type, indices_array, end_index); + let end_contract = self + .builder + .ins() + .load(self.int_type, self.mem_flags, ptr, 0); + + // init sum + let fzero = self.fconst(0.0); + self.builder.ins().stack_store(fzero, contract_sum_var, 0); + + // loop through each element in the contraction + let contract_block = self.builder.create_block(); + let expr_index = self + .builder + .append_block_param(contract_block, self.int_type); + self.builder.ins().jump(contract_block, &[start_contract]); + self.builder.switch_to_block(contract_block); + + // loop body - load index from layout + let rank_val = self.builder.ins().iconst( + self.int_type, + i64::try_from(elmt.expr_layout().rank()).unwrap(), + ); + let elmt_index_mult_rank = self.builder.ins().imul(expr_index, rank_val); + let indices_int = (0..elmt.expr_layout().rank()) + // index = indices[layout_index + i + elmt_index * rank] + .map(|i| { + let layout_index_plus_offset = self + .builder + .ins() + .iconst(self.int_type, i64::try_from(layout_index + i).unwrap()); + let curr_index = self + .builder + .ins() + .iadd(elmt_index_mult_rank, layout_index_plus_offset); + let ptr = self.ptr_add_offset(self.int_type, indices_array, curr_index); + let index = self + .builder + .ins() + .load(self.int_type, self.mem_flags, ptr, 0); + Ok(index) + }) + .collect::<Result<Vec<Value>, anyhow::Error>>()?; + + // loop body - eval expression and increment sum + let expr = if is_tangent { + elmt.tangent_expr() + }
else { + elmt.expr() + }; + let float_value = + self.jit_compile_expr(name, expr, indices_int.as_slice(), elmt, Some(expr_index))?; + let contract_sum_value = self + .builder + .ins() + .stack_load(self.real_type, contract_sum_var, 0); + let new_contract_sum_value = self.builder.ins().fadd(contract_sum_value, float_value); + self.builder + .ins() + .stack_store(new_contract_sum_value, contract_sum_var, 0); + + // increment contract loop index + let next_elmt_index = self.builder.ins().iadd(expr_index, one); + + // contract loop condition + let loop_while = + self.builder + .ins() + .icmp(IntCC::UnsignedLessThan, next_elmt_index, end_contract); + let post_contract_block = self.builder.create_block(); + self.builder.ins().brif( + loop_while, + contract_block, + &[next_elmt_index], + post_contract_block, + &[], + ); + self.builder.seal_block(contract_block); + self.builder.seal_block(post_contract_block); + + self.builder.switch_to_block(post_contract_block); + + // store the result + self.jit_compile_store( + name, + elmt, + contract_index, + new_contract_sum_value, + translation, + )?; + + // increment outer loop index + let next_contract_index = self.builder.ins().iadd(contract_index, one); + + // outer loop condition + let loop_while = self.builder.ins().icmp( + IntCC::UnsignedLessThan, + next_contract_index, + final_contract_index, + ); + let post_block = self.builder.create_block(); + self.builder + .ins() + .brif(loop_while, block, &[next_contract_index], post_block, &[]); + self.builder.seal_block(block); + self.builder.switch_to_block(post_block); + self.builder.seal_block(post_block); + + Ok(()) + } + + // for sparse blocks we can loop through the non-zero elements and extract the index from the layout, then we compile the expression passing in this index + // TODO: haven't implemented contractions yet + fn jit_compile_sparse_block( + &mut self, + name: &str, + elmt: &TensorBlock, + translation: &Translation, + is_tangent: bool, + ) -> Result<()> { + let int_type = self.int_type; + + let layout_index = self.layout.get_layout_index(elmt.expr_layout()).unwrap(); + + // loop through the non-zero elements + let zero = self.builder.ins().iconst(int_type, 0); + let one = self.builder.ins().iconst(int_type, 1); + let start_index = zero; + let end_index = self + .builder + .ins() + .iconst(int_type, i64::try_from(elmt.layout().nnz()).unwrap()); + + let block = self.builder.create_block(); + let curr_index = self.builder.append_block_param(block, int_type); + self.builder.ins().jump(block, &[start_index]); + self.builder.switch_to_block(block); + + // loop body - load index from layout + let elmt_index = curr_index; + let rank_val = self + .builder + .ins() + .iconst(int_type, i64::try_from(elmt.expr_layout().rank()).unwrap()); + let elmt_index_mult_rank = self.builder.ins().imul(elmt_index, rank_val); + let indices_int = (0..elmt.expr_layout().rank()) + // index = indices[layout_index + i + elmt_index * rank] + .map(|i| { + let layout_index_plus_offset = self + .builder + .ins() + .iconst(int_type, i64::try_from(layout_index + i).unwrap()); + let curr_index = self + .builder + .ins() + .iadd(elmt_index_mult_rank, layout_index_plus_offset); + let indices_ptr = self + .builder + .ins() + .global_value(self.int_ptr_type, self.indices); + let ptr = self.ptr_add_offset(self.int_type, indices_ptr, curr_index); + let index = self + .builder + .ins() + .load(self.int_type, self.mem_flags, ptr, 0); + Ok(index) + }) + .collect::<Result<Vec<Value>, anyhow::Error>>()?; + + // loop body - eval expression + let expr = if
is_tangent { + elmt.tangent_expr() + } else { + elmt.expr() + }; + let float_value = + self.jit_compile_expr(name, expr, indices_int.as_slice(), elmt, Some(elmt_index))?; + + self.jit_compile_broadcast_and_store( + name, + elmt, + float_value, + elmt_index, + translation, + block, + )?; + + // increment loop index + let next_index = self.builder.ins().iadd(elmt_index, one); + + // loop condition + let loop_while = self + .builder + .ins() + .icmp(IntCC::UnsignedLessThan, next_index, end_index); + let post_block = self.builder.create_block(); + + self.builder + .ins() + .brif(loop_while, block, &[next_index], post_block, &[]); + self.builder.seal_block(block); + self.builder.switch_to_block(post_block); + self.builder.seal_block(post_block); + Ok(()) + } + + // for diagonal blocks we can loop through the diagonal elements and the index is just the same for each element, then we compile the expression passing in this index + fn jit_compile_diagonal_block( + &mut self, + name: &str, + elmt: &TensorBlock, + translation: &Translation, + is_tangent: bool, + ) -> Result<()> { + let int_type = self.int_type; + + // loop through the non-zero elements + let zero = self.builder.ins().iconst(int_type, 0); + let one = self.builder.ins().iconst(int_type, 1); + let block = self.builder.create_block(); + let start_index = zero; + let end_index = self + .builder + .ins() + .iconst(int_type, i64::try_from(elmt.expr_layout().nnz()).unwrap()); + let curr_index = self.builder.append_block_param(block, int_type); + self.builder.ins().jump(block, &[start_index]); + self.builder.switch_to_block(block); + + // loop body - index is just the same for each element + let elmt_index = curr_index; + let indices_int = vec![elmt_index; elmt.expr_layout().rank()]; + + // loop body - eval expression + let expr = if is_tangent { + elmt.tangent_expr() + } else { + elmt.expr() + }; + let float_value = + self.jit_compile_expr(name, expr, indices_int.as_slice(), elmt, Some(elmt_index))?; + + // loop body - store result + self.jit_compile_broadcast_and_store( + name, + elmt, + float_value, + elmt_index, + translation, + block, + )?; + + // increment loop index + let next_index = self.builder.ins().iadd(elmt_index, one); + let loop_while = self + .builder + .ins() + .icmp(IntCC::UnsignedLessThan, next_index, end_index); + let post_block = self.builder.create_block(); + self.builder + .ins() + .brif(loop_while, block, &[next_index], post_block, &[]); + self.builder.seal_block(block); + self.builder.switch_to_block(post_block); + self.builder.seal_block(post_block); + + Ok(()) + } + + fn jit_compile_broadcast_and_store( + &mut self, + name: &str, + elmt: &TensorBlock, + float_value: Value, + expr_index: Value, + translation: &Translation, + pre_block: Block, + ) -> Result<Block> { + let int_type = self.int_type; + let one = self.builder.ins().iconst(int_type, 1); + let zero = self.builder.ins().iconst(int_type, 0); + match translation.source { + TranslationFrom::Broadcast { + broadcast_by: _, + broadcast_len, + } => { + let bcast_block = self.builder.create_block(); + let bcast_start_index = zero; + let bcast_end_index = self + .builder + .ins() + .iconst(int_type, i64::try_from(broadcast_len).unwrap()); + let bcast_index = self.builder.append_block_param(bcast_block, self.int_type); + + // setup loop block + self.builder.ins().jump(bcast_block, &[bcast_start_index]); + self.builder.switch_to_block(bcast_block); + + // store value at index = expr_index * broadcast_len + bcast_index + let tmp = self.builder.ins().imul(expr_index,
bcast_end_index); + let store_index = self.builder.ins().iadd(tmp, bcast_index); + self.jit_compile_store(name, elmt, store_index, float_value, translation)?; + + // increment index + let bcast_next_index = self.builder.ins().iadd(bcast_index, one); + let bcast_cond = self.builder.ins().icmp( + IntCC::UnsignedLessThan, + bcast_next_index, + bcast_end_index, + ); + let post_bcast_block = self.builder.create_block(); + self.builder.ins().brif( + bcast_cond, + bcast_block, + &[bcast_next_index], + post_bcast_block, + &[], + ); + self.builder.seal_block(bcast_block); + self.builder.seal_block(post_bcast_block); + self.builder.switch_to_block(post_bcast_block); + + // return the current block for later + Ok(post_bcast_block) + } + TranslationFrom::ElementWise | TranslationFrom::DiagonalContraction { .. } => { + self.jit_compile_store(name, elmt, expr_index, float_value, translation)?; + Ok(pre_block) + } + _ => Err(anyhow!("Invalid translation")), + } + } + + fn jit_compile_store( + &mut self, + _name: &str, + elmt: &TensorBlock, + store_index: Value, + float_value: Value, + translation: &Translation, + ) -> Result<()> { + let int_type = self.int_type; + let rank = elmt.layout().rank(); + let res_index = match &translation.target { + TranslationTo::Contiguous { start, end: _ } => { + let start_const = self + .builder + .ins() + .iconst(int_type, i64::try_from(*start).unwrap()); + self.builder.ins().iadd(start_const, store_index) + } + TranslationTo::Sparse { indices: _ } => { + // load store index from layout + let translate_index = self + .layout + .get_translation_index(elmt.expr_layout(), elmt.layout()) + .unwrap(); + let translate_store_index = + translate_index + translation.get_to_index_in_data_layout(); + let translate_store_index = self + .builder + .ins() + .iconst(int_type, i64::try_from(translate_store_index).unwrap()); + let rank_const = self + .builder + .ins() + .iconst(int_type, i64::try_from(rank).unwrap()); + let elmt_index_strided = self.builder.ins().imul(store_index, rank_const); + let curr_index = self + .builder + .ins() + .iadd(elmt_index_strided, translate_store_index); + let indices_ptr = self + .builder + .ins() + .global_value(self.int_ptr_type, self.indices); + let ptr = self.ptr_add_offset(self.int_type, indices_ptr, curr_index); + self.builder + .ins() + .load(self.int_type, self.mem_flags, ptr, 0) + } + }; + + let ptr = self.ptr_add_offset(self.real_type, self.tensor_ptr.unwrap(), res_index); + self.builder + .ins() + .store(self.mem_flags, float_value, ptr, 0); + + Ok(()) + } + + fn declare_variable(&mut self, ty: types::Type, name: &str, val: Value) -> Variable { + let index = self.variables.len(); + let var = Variable::new(index); + if !self.variables.contains_key(name) { + self.variables.insert(name.into(), var); + self.builder.declare_var(var, ty); + self.builder.def_var(var, val); + } + var + } + + fn get_tangent_tensor_name(&self, name: &str) -> String { + format!("{}__tangent__", name) + } + + fn insert_tensor(&mut self, tensor: &Tensor, ptr: Value, data_index: i64, is_tangent: bool) { + let mut tensor_data_index = data_index; + let tensor_data_ptr = self.ptr_add_offset_i64(self.real_type, ptr, tensor_data_index); + let tensor_name = if is_tangent { + self.get_tangent_tensor_name(tensor.name()) + } else { + tensor.name().to_owned() + }; + self.declare_variable(self.real_ptr_type, tensor_name.as_str(), tensor_data_ptr); + + //insert any named blocks + for blk in tensor.elmts() { + if let Some(name) = blk.name() { + let blk_name = if is_tangent { + 
self.get_tangent_tensor_name(name) + } else { + name.to_owned() + }; + let tensor_data_ptr = + self.ptr_add_offset_i64(self.real_type, ptr, tensor_data_index); + self.declare_variable(self.real_ptr_type, blk_name.as_str(), tensor_data_ptr); + } + // named blocks only supported for rank <= 1, so we can just add the nnz to get the next data index + tensor_data_index += i64::try_from(blk.nnz()).unwrap(); + } + } + + pub fn new( + module: &'ctx mut CraneliftModule, + model: &DiscreteModel, + arg_names: &[&str], + arg_types: &[Type], + ) -> Self { + module.ctx.func.signature.params.clear(); + module.ctx.func.signature.returns.clear(); + + for ty in arg_types { + module.ctx.func.signature.params.push(AbiParam::new(*ty)); + } + + // Create the builder to build a function. + let mut builder = FunctionBuilder::new(&mut module.ctx.func, &mut module.builder_context); + + let indices = module + .module + .declare_data_in_func(module.indices_id, builder.func); + + // Create the entry block, to start emitting code in. + let entry_block = builder.create_block(); + + // Since this is the entry block, add block parameters corresponding to + // the function's parameters. + // + // TODO: Streamline the API here. + builder.append_block_params_for_function_params(entry_block); + + // Tell the builder to emit code in this block. + builder.switch_to_block(entry_block); + + // And, tell the builder that this block will have no further + // predecessors. Since it's the entry block, it won't have any + // predecessors. + builder.seal_block(entry_block); + + let mut codegen = Self { + int_type: module.int_type, + real_type: module.real_type, + real_ptr_type: module.real_ptr_type, + int_ptr_type: module.int_ptr_type, + builder, + module: &mut module.module, + tensor_ptr: None, + indices, + variables: HashMap::new(), + mem_flags: MemFlags::new(), + functions: HashMap::new(), + layout: &module.layout, + }; + + // insert arg vars + for (i, (arg_name, arg_type)) in arg_names.iter().zip(arg_types.iter()).enumerate() { + let val = codegen.builder.block_params(entry_block)[i]; + codegen.declare_variable(*arg_type, arg_name, val); + } + + // insert u if it exists in args + if let Some(u) = codegen.variables.get("u") { + let u_ptr = codegen.builder.use_var(*u); + codegen.insert_tensor(model.state(), u_ptr, 0, false); + } + + if let Some(du) = codegen.variables.get("du") { + let du_ptr = codegen.builder.use_var(*du); + codegen.insert_tensor(model.state(), du_ptr, 0, true); + } + + // insert dudt if it exists in args and is used in the model + if let Some(dudt) = codegen.variables.get("dudt") { + if let Some(state_dot) = model.state_dot() { + let statedot_ptr = codegen.builder.use_var(*dudt); + codegen.insert_tensor(state_dot, statedot_ptr, 0, false); + } + } + + // insert all tensors in data if it exists in args + let tensors = model.inputs().iter(); + let tensors = tensors.chain(model.time_indep_defns().iter()); + let tensors = tensors.chain(model.time_dep_defns().iter()); + let tensors = tensors.chain(model.state_dep_defns().iter()); + let mut others = Vec::new(); + others.push(model.out()); + others.push(model.rhs()); + if let Some(lhs) = model.lhs() { + others.push(lhs); + } + let tensors = tensors.chain(others); + + if let Some(data) = codegen.variables.get("data") { + let data_ptr = codegen.builder.use_var(*data); + + for tensor in tensors.clone() { + let data_index = + i64::try_from(codegen.layout.get_data_index(tensor.name()).unwrap()).unwrap(); + codegen.insert_tensor(tensor, data_ptr, data_index, false); + } + } + + 
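+ // note: `ddata` is assumed to mirror the layout of `data`, so each tangent + // tensor is registered at the same `get_data_index` offset in `ddata` as its + // primal tensor has in `data`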
// insert all tangent tensors in tangent_data if it exists in args + if let Some(data) = codegen.variables.get("ddata") { + let data_ptr = codegen.builder.use_var(*data); + + for tensor in tensors { + let data_index = + i64::try_from(codegen.layout.get_data_index(tensor.name()).unwrap()).unwrap(); + codegen.insert_tensor(tensor, data_ptr, data_index, true); + } + } + codegen + } + + fn jit_compile_set_inputs( + &mut self, + model: &DiscreteModel, + base_data_ptr: Value, + is_tangent: bool, + ) { + let mut inputs_index = 0; + for input in model.inputs() { + let data_index = + i64::try_from(self.layout.get_data_index(input.name()).unwrap()).unwrap(); + self.insert_tensor(input, base_data_ptr, data_index, is_tangent); + let tensor_name = if is_tangent { + self.get_tangent_tensor_name(input.name()) + } else { + input.name().to_owned() + }; + let data_ptr = self.variables.get(tensor_name.as_str()).unwrap(); + let data_ptr = self.builder.use_var(*data_ptr); + let input_name = if is_tangent { "dinputs" } else { "inputs" }; + let input_ptr = self.variables.get(input_name).unwrap(); + let input_ptr = self.builder.use_var(*input_ptr); + let inputs_start_index = self + .builder + .ins() + .iconst(self.int_type, i64::try_from(inputs_index).unwrap()); + + // loop thru the elements of this input and set them using the inputs ptr + let start_index = self.builder.ins().iconst(self.int_type, 0); + + let input_block = self.builder.create_block(); + let curr_input_index = self.builder.append_block_param(input_block, self.int_type); + self.builder.ins().jump(input_block, &[start_index]); + self.builder.switch_to_block(input_block); + + // loop body - copy value from inputs to data + let curr_input_index_plus_start_index = self + .builder + .ins() + .iadd(curr_input_index, inputs_start_index); + let indexed_input_ptr = + self.ptr_add_offset(self.real_type, input_ptr, curr_input_index_plus_start_index); + let indexed_data_ptr = self.ptr_add_offset(self.real_type, data_ptr, curr_input_index); + let input_value = + self.builder + .ins() + .load(self.real_type, self.mem_flags, indexed_input_ptr, 0); + self.builder + .ins() + .store(self.mem_flags, input_value, indexed_data_ptr, 0); + + // increment loop index + let one = self.builder.ins().iconst(self.int_type, 1); + let next_index = self.builder.ins().iadd(curr_input_index, one); + + let loop_while = self.builder.ins().icmp_imm( + IntCC::UnsignedLessThan, + next_index, + i64::try_from(input.nnz()).unwrap(), + ); + let post_block = self.builder.create_block(); + self.builder + .ins() + .brif(loop_while, input_block, &[next_index], post_block, &[]); + self.builder.seal_block(input_block); + self.builder.seal_block(post_block); + self.builder.switch_to_block(post_block); + + // get ready for next input + inputs_index += input.nnz(); + } + } +} diff --git a/src/execution/cranelift/mod.rs b/src/execution/cranelift/mod.rs new file mode 100644 index 0000000..24ccbdd --- /dev/null +++ b/src/execution/cranelift/mod.rs @@ -0,0 +1 @@ +pub mod codegen; diff --git a/src/execution/functions.rs b/src/execution/functions.rs new file mode 100644 index 0000000..b106f2a --- /dev/null +++ b/src/execution/functions.rs @@ -0,0 +1,226 @@ +#![allow(clippy::type_complexity)] +pub const FUNCTIONS: &[( + &str, + extern "C" fn(f64) -> f64, + extern "C" fn(f64, f64) -> f64, +)] = &[ + ("sin", sin, dsin), + ("cos", cos, dcos), + ("tan", tan, dtan), + ("exp", exp, dexp), + ("log", log, dlog), + ("log10", log10, dlog10), + ("sqrt", sqrt, dsqrt), + ("abs", abs, dabs), + ("sigmoid", sigmoid, 
dsigmoid), + ("arcsinh", arcsinh, darcsinh), + ("arccosh", arccosh, darccosh), + ("heaviside", heaviside, dheaviside), + ("tanh", tanh, dtanh), + ("sinh", sinh, dsinh), + ("cosh", cosh, dcosh), +]; + +pub const TWO_ARG_FUNCTIONS: &[( + &str, + extern "C" fn(f64, f64) -> f64, + extern "C" fn(f64, f64, f64, f64) -> f64, +)] = &[ + ("copysign", copysign, dcopysign), + ("pow", pow, dpow), + ("min", min, dmin), + ("max", max, dmax), +]; + +pub fn function_num_args(name: &str, is_tangent: bool) -> Option<usize> { + let one = FUNCTIONS.iter().find(|(n, _, _)| n == &name); + let multiplier = if is_tangent { 2 } else { 1 }; + if one.is_some() { + return Some(multiplier); + } + let two = TWO_ARG_FUNCTIONS.iter().find(|(n, _, _)| n == &name); + if two.is_some() { + return Some(2 * multiplier); + } + None +} + +extern "C" fn sin(x: f64) -> f64 { + x.sin() +} + +extern "C" fn dsin(x: f64, dx: f64) -> f64 { + x.cos() * dx +} + +extern "C" fn cos(x: f64) -> f64 { + x.cos() +} + +extern "C" fn dcos(x: f64, dx: f64) -> f64 { + -x.sin() * dx +} + +extern "C" fn tan(x: f64) -> f64 { + x.tan() +} + +extern "C" fn dtan(x: f64, dx: f64) -> f64 { + let sec = x.cos().powi(-2); + sec * dx +} + +extern "C" fn exp(x: f64) -> f64 { + x.exp() +} + +extern "C" fn dexp(x: f64, dx: f64) -> f64 { + x.exp() * dx +} + +extern "C" fn log(x: f64) -> f64 { + x.ln() +} + +extern "C" fn dlog(x: f64, dx: f64) -> f64 { + dx / x +} + +extern "C" fn log10(x: f64) -> f64 { + x.log10() +} + +extern "C" fn dlog10(x: f64, dx: f64) -> f64 { + dx / (x * 10.0_f64.ln()) +} + +extern "C" fn sqrt(x: f64) -> f64 { + x.sqrt() +} + +extern "C" fn dsqrt(x: f64, dx: f64) -> f64 { + 0.5 * dx / x.sqrt() +} + +extern "C" fn abs(x: f64) -> f64 { + x.abs() +} + +// todo: not correct at x == 0 +extern "C" fn dabs(x: f64, dx: f64) -> f64 { + if x > 0.0 { + dx + } else { + -dx + } +} + +extern "C" fn copysign(x: f64, y: f64) -> f64 { + x.copysign(y) +} + +// todo: this is not correct if y == 0 +extern "C" fn dcopysign(_x: f64, dx: f64, y: f64, _dy: f64) -> f64 { + dx.copysign(y) +} + +extern "C" fn pow(x: f64, y: f64) -> f64 { + x.powf(y) +} + +// d/dx(f(x)^g(x)) = f(x)^(g(x) - 1) (g(x) f'(x) + f(x) log(f(x)) g'(x)) +extern "C" fn dpow(x: f64, dx: f64, y: f64, dy: f64) -> f64 { + x.powf(y - 1.0) * (y * dx + x * x.ln() * dy) +} + +extern "C" fn min(x: f64, y: f64) -> f64 { + x.min(y) +} + +extern "C" fn dmin(x: f64, dx: f64, y: f64, dy: f64) -> f64 { + if x < y { + dx + } else { + dy + } +} + +extern "C" fn max(x: f64, y: f64) -> f64 { + x.max(y) +} + +extern "C" fn dmax(x: f64, dx: f64, y: f64, dy: f64) -> f64 { + if x > y { + dx + } else { + dy + } +} + +extern "C" fn sigmoid(x: f64) -> f64 { + 1.0 / (1.0 + (-x).exp()) +} + +// (f'(x))/(2 cosh(f(x)) + 2) +extern "C" fn dsigmoid(x: f64, dx: f64) -> f64 { + let cosh = x.cosh(); + dx / (2.0 * cosh + 2.0) +} + +extern "C" fn arcsinh(x: f64) -> f64 { + x.asinh() +} + +// d/dx(sinh^(-1)(f(x))) = (f'(x))/sqrt(f(x)^2 + 1) +extern "C" fn darcsinh(x: f64, dx: f64) -> f64 { + dx / (x.powi(2) + 1.0).sqrt() +} + +extern "C" fn arccosh(x: f64) -> f64 { + x.acosh() +} + +// d/dx(cosh^(-1)(f(x))) = (f'(x))/(sqrt(f(x) - 1) sqrt(f(x) + 1)) +extern "C" fn darccosh(x: f64, dx: f64) -> f64 { + dx / ((x - 1.0).sqrt() * (x + 1.0).sqrt()) +} + +extern "C" fn heaviside(x: f64) -> f64 { + if x >= 0.0 { + 1.0 + } else { + 0.0 + } +} + +// todo: not correct at x == 0 +extern "C" fn dheaviside(_x: f64, _dx: f64) -> f64 { + 0.0 +} + +extern "C" fn tanh(x: f64) -> f64 { + x.tanh() +} + +// (f'(x))/(cosh^2(f(x))) +extern "C" fn dtanh(x: f64, dx: f64) -> f64 { + let
cosh = x.cosh(); + dx / cosh.powi(2) +} + +extern "C" fn sinh(x: f64) -> f64 { + x.sinh() +} + +// d/dx(sinh(f(x))) = f'(x) cosh(f(x)) +extern "C" fn dsinh(x: f64, dx: f64) -> f64 { + dx * x.cosh() +} + +extern "C" fn cosh(x: f64) -> f64 { + x.cosh() +} + +// d/dx(cosh(f(x))) = f'(x) sinh(f(x)) +extern "C" fn dcosh(x: f64, dx: f64) -> f64 { + dx * x.sinh() +} diff --git a/src/execution/interface.rs b/src/execution/interface.rs new file mode 100644 index 0000000..6b5fde2 --- /dev/null +++ b/src/execution/interface.rs @@ -0,0 +1,65 @@ +type RealType = f64; + +pub type StopFunc = unsafe extern "C" fn( + time: RealType, + u: *const RealType, + data: *mut RealType, + root: *mut RealType, +); +pub type RhsFunc = unsafe extern "C" fn( + time: RealType, + u: *const RealType, + data: *mut RealType, + rr: *mut RealType, +); +pub type RhsGradientFunc = unsafe extern "C" fn( + time: RealType, + u: *const RealType, + du: *const RealType, + data: *mut RealType, + ddata: *mut RealType, + rr: *mut RealType, + drr: *mut RealType, +); +pub type MassFunc = unsafe extern "C" fn( + time: RealType, + v: *const RealType, + data: *mut RealType, + mv: *mut RealType, +); +pub type U0Func = unsafe extern "C" fn(data: *mut RealType, u: *mut RealType); +pub type U0GradientFunc = unsafe extern "C" fn( + data: *mut RealType, + ddata: *mut RealType, + u: *mut RealType, + du: *mut RealType, +); +pub type CalcOutFunc = + unsafe extern "C" fn(time: RealType, u: *const RealType, data: *mut RealType); +pub type CalcOutGradientFunc = unsafe extern "C" fn( + time: RealType, + u: *const RealType, + du: *const RealType, + data: *mut RealType, + ddata: *mut RealType, +); +pub type GetDimsFunc = unsafe extern "C" fn( + states: *mut u32, + inputs: *mut u32, + outputs: *mut u32, + data: *mut u32, + stop: *mut u32, +); +pub type SetInputsFunc = unsafe extern "C" fn(inputs: *const RealType, data: *mut RealType); +pub type SetInputsGradientFunc = unsafe extern "C" fn( + inputs: *const RealType, + dinputs: *const RealType, + data: *mut RealType, + ddata: *mut RealType, +); +pub type SetIdFunc = unsafe extern "C" fn(id: *mut RealType); +pub type GetOutFunc = unsafe extern "C" fn( + data: *const RealType, + tensor_data: *mut *mut RealType, + tensor_size: *mut u32, +); diff --git a/src/execution/codegen.rs b/src/execution/llvm/codegen.rs similarity index 89% rename from src/execution/codegen.rs rename to src/execution/llvm/codegen.rs index cb3fb78..0b60e07 100644 --- a/src/execution/codegen.rs +++ b/src/execution/llvm/codegen.rs @@ -1,20 +1,28 @@ +use aliasable::boxed::AliasableBox; use anyhow::{anyhow, Result}; use inkwell::attributes::{Attribute, AttributeLoc}; use inkwell::basic_block::BasicBlock; use inkwell::builder::Builder; -use inkwell::context::AsContextRef; +use inkwell::context::{AsContextRef, Context}; +use inkwell::execution_engine::{ExecutionEngine, JitFunction, UnsafeFunctionPointer}; use inkwell::intrinsics::Intrinsic; use inkwell::module::Module; -use inkwell::types::{BasicMetadataTypeEnum, BasicType, BasicTypeEnum, FloatType, IntType}; +use inkwell::passes::PassBuilderOptions; +use inkwell::targets::{InitializationConfig, Target, TargetTriple}; +use inkwell::types::{ + BasicMetadataTypeEnum, BasicType, BasicTypeEnum, FloatType, FunctionType, IntType, PointerType, +}; use inkwell::values::{ AsValueRef, BasicMetadataValueEnum, BasicValue, BasicValueEnum, FloatValue, FunctionValue, GlobalValue, IntValue, PointerValue, }; -use inkwell::{AddressSpace, FloatPredicate, IntPredicate}; +use inkwell::{AddressSpace, FloatPredicate, 
IntPredicate, OptimizationLevel}; use inkwell_internals::llvm_versions; use llvm_sys::prelude::LLVMValueRef; use std::collections::HashMap; use std::iter::zip; +use std::pin::Pin; +use target_lexicon::Triple; type RealType = f64; @@ -28,75 +36,217 @@ use crate::enzyme::{ FreeTypeAnalysis, IntList, LLVMOpaqueContext, LLVMOpaqueValue, CDIFFE_TYPE_DFT_CONSTANT, CDIFFE_TYPE_DFT_DUP_ARG, CDIFFE_TYPE_DFT_DUP_NONEED, }; +use crate::execution::module::CodegenModule; use crate::execution::{DataLayout, Translation, TranslationFrom, TranslationTo}; -/// Convenience type alias for the `sum` function. -/// -/// Calling this is innately `unsafe` because there's no guarantee it doesn't -/// do `unsafe` operations internally. -pub type StopFunc = unsafe extern "C" fn( - time: RealType, - u: *const RealType, - data: *mut RealType, - root: *mut RealType, -); -pub type RhsFunc = unsafe extern "C" fn( - time: RealType, - u: *const RealType, - data: *mut RealType, - rr: *mut RealType, -); -pub type RhsGradientFunc = unsafe extern "C" fn( - time: RealType, - u: *const RealType, - du: *const RealType, - data: *mut RealType, - ddata: *mut RealType, - rr: *mut RealType, - drr: *mut RealType, -); -pub type MassFunc = unsafe extern "C" fn( - time: RealType, - v: *const RealType, - data: *mut RealType, - mv: *mut RealType, -); -pub type U0Func = unsafe extern "C" fn(data: *mut RealType, u: *mut RealType); -pub type U0GradientFunc = unsafe extern "C" fn( - data: *mut RealType, - ddata: *mut RealType, - u: *mut RealType, - du: *mut RealType, -); -pub type CalcOutFunc = - unsafe extern "C" fn(time: RealType, u: *const RealType, data: *mut RealType); -pub type CalcOutGradientFunc = unsafe extern "C" fn( - time: RealType, - u: *const RealType, - du: *const RealType, - data: *mut RealType, - ddata: *mut RealType, -); -pub type GetDimsFunc = unsafe extern "C" fn( - states: *mut u32, - inputs: *mut u32, - outputs: *mut u32, - data: *mut u32, - stop: *mut u32, -); -pub type SetInputsFunc = unsafe extern "C" fn(inputs: *const RealType, data: *mut RealType); -pub type SetInputsGradientFunc = unsafe extern "C" fn( - inputs: *const RealType, - dinputs: *const RealType, - data: *mut RealType, - ddata: *mut RealType, -); -pub type SetIdFunc = unsafe extern "C" fn(id: *mut RealType); -pub type GetOutFunc = unsafe extern "C" fn( - data: *const RealType, - tensor_data: *mut *mut RealType, - tensor_size: *mut u32, -); +struct ImmovableLlvmModule { + // actually has lifetime of `context` + // declared first so it's dropped before `context` + codegen: Option<CodeGen<'static>>, + // safety: we must never move out of this box as long as codegen is alive + context: AliasableBox<Context>, + triple: Triple, + _pin: std::marker::PhantomPinned, +} + +pub struct LlvmModule(Pin<Box<ImmovableLlvmModule>>); + +impl LlvmModule { + pub fn print(&self) { + self.codegen().module().print_to_stderr(); + } + fn codegen_mut(&mut self) -> &mut CodeGen<'static> { + unsafe { + self.0 + .as_mut() + .get_unchecked_mut() + .codegen + .as_mut() + .unwrap() + } + } + fn codegen(&self) -> &CodeGen<'static> { + self.0.as_ref().get_ref().codegen.as_ref().unwrap() + } + pub fn jit2<T: UnsafeFunctionPointer>( + &mut self, + name: &str, + ) -> Result<JitFunction<'static, T>> { + let maybe_fn = unsafe { self.codegen_mut().ee.get_function::<T>(name) }; + match maybe_fn { + Ok(f) => Ok(f), + Err(err) => Err(anyhow!("Error during jit for {}: {}", name, err)), + } + } +} + +impl CodegenModule for LlvmModule { + type FuncId = FunctionValue<'static>; + fn new(triple: Triple, model: &DiscreteModel) -> Result<Self> { + let context = AliasableBox::from_unique(Box::new(Context::create())); 
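+ // note: `codegen` borrows from `context`, so this is a self-referential struct: + // the context lives in an `AliasableBox` and the whole struct is pinned via + // `PhantomPinned`, which is what makes erasing the borrow to 'static below sound.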
+ let mut pinned = Self(Box::pin(ImmovableLlvmModule { + codegen: None, + context, + triple, + _pin: std::marker::PhantomPinned, + })); + + let context_ref = pinned.0.context.as_ref(); + let real_type_str = "f64"; + let codegen = CodeGen::new( + model, + context_ref, + context_ref.f64_type(), + context_ref.i32_type(), + real_type_str, + )?; + let codegen = unsafe { std::mem::transmute::<CodeGen<'_>, CodeGen<'static>>(codegen) }; + unsafe { pinned.0.as_mut().get_unchecked_mut().codegen = Some(codegen) }; + Ok(pinned) + } + + fn layout(&self) -> &DataLayout { + &self.codegen().layout + } + + fn jit(&mut self, func_id: Self::FuncId) -> Result<*const u8> { + let name = func_id.get_name().to_str().unwrap(); + let maybe_fn = self.codegen_mut().ee.get_function_address(name); + match maybe_fn { + Ok(f) => Ok(f as *const u8), + Err(err) => Err(anyhow!("Error during jit for {}: {}", name, err)), + } + } + + fn compile_set_u0(&mut self, model: &DiscreteModel) -> Result<Self::FuncId> { + self.codegen_mut().compile_set_u0(model) + } + + fn compile_calc_out(&mut self, model: &DiscreteModel) -> Result<Self::FuncId> { + self.codegen_mut().compile_calc_out(model) + } + + fn compile_calc_stop(&mut self, model: &DiscreteModel) -> Result<Self::FuncId> { + self.codegen_mut().compile_calc_stop(model) + } + + fn compile_rhs(&mut self, model: &DiscreteModel) -> Result<Self::FuncId> { + self.codegen_mut().compile_rhs(model) + } + + fn compile_mass(&mut self, model: &DiscreteModel) -> Result<Self::FuncId> { + self.codegen_mut().compile_mass(model) + } + + fn compile_get_dims(&mut self, model: &DiscreteModel) -> Result<Self::FuncId> { + self.codegen_mut().compile_get_dims(model) + } + + fn compile_get_tensor(&mut self, model: &DiscreteModel, name: &str) -> Result<Self::FuncId> { + self.codegen_mut().compile_get_tensor(model, name) + } + + fn compile_set_inputs(&mut self, model: &DiscreteModel) -> Result<Self::FuncId> { + self.codegen_mut().compile_set_inputs(model) + } + + fn compile_set_id(&mut self, model: &DiscreteModel) -> Result<Self::FuncId> { + self.codegen_mut().compile_set_id(model) + } + + // Enzyme activity annotations for the primal arguments: `Dup` marks an active + // argument with a shadow, `DupNoNeed` additionally tells Enzyme the primal result + // is not needed, and `Const` marks an inactive argument. + fn compile_set_u0_grad( + &mut self, + func_id: &Self::FuncId, + _model: &DiscreteModel, + ) -> Result<Self::FuncId> { + self.codegen_mut().compile_gradient( + *func_id, + &[CompileGradientArgType::Dup, CompileGradientArgType::Dup], + ) + } + + fn compile_rhs_grad( + &mut self, + func_id: &Self::FuncId, + _model: &DiscreteModel, + ) -> Result<Self::FuncId> { + self.codegen_mut().compile_gradient( + *func_id, + &[ + CompileGradientArgType::Const, + CompileGradientArgType::Dup, + CompileGradientArgType::Dup, + CompileGradientArgType::DupNoNeed, + ], + ) + } + + fn compile_calc_out_grad( + &mut self, + func_id: &Self::FuncId, + _model: &DiscreteModel, + ) -> Result<Self::FuncId> { + self.codegen_mut().compile_gradient( + *func_id, + &[ + CompileGradientArgType::Const, + CompileGradientArgType::Dup, + CompileGradientArgType::Dup, + ], + ) + } + + fn compile_set_inputs_grad( + &mut self, + func_id: &Self::FuncId, + _model: &DiscreteModel, + ) -> Result<Self::FuncId> { + self.codegen_mut().compile_gradient( + *func_id, + &[CompileGradientArgType::Dup, CompileGradientArgType::Dup], + ) + } + + fn pre_autodiff_optimisation(&mut self) -> Result<()> { + // optimise at -O2 no unrolling before giving to enzyme
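+ // vectorised or unrolled IR is (presumably) harder for Enzyme to differentiate + // and inflates the generated gradients, so those transforms are switched off here + // and left to the JIT's own optimisation after autodiff.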
+ let pass_options = PassBuilderOptions::create(); + //pass_options.set_verify_each(true); + //pass_options.set_debug_logging(true); + //pass_options.set_loop_interleaving(true); + pass_options.set_loop_vectorization(false); + pass_options.set_loop_slp_vectorization(false); + pass_options.set_loop_unrolling(false); + //pass_options.set_forget_all_scev_in_loop_unroll(true); + //pass_options.set_licm_mssa_opt_cap(1); + //pass_options.set_licm_mssa_no_acc_for_promotion_cap(10); + //pass_options.set_call_graph_profile(true); + //pass_options.set_merge_functions(true); + + let initialization_config = &InitializationConfig::default(); + Target::initialize_all(initialization_config); + let triple = TargetTriple::create(self.0.triple.to_string().as_str()); + let target = Target::from_triple(&triple).unwrap(); + let machine = target + .create_target_machine( + &triple, + "generic", //TargetMachine::get_host_cpu_name().to_string().as_str(), + "", //TargetMachine::get_host_cpu_features().to_string().as_str(), + inkwell::OptimizationLevel::Default, + inkwell::targets::RelocMode::Default, + inkwell::targets::CodeModel::Default, + ) + .unwrap(); + + self.codegen_mut() + .module() + .run_passes("default", &machine, pass_options) + .map_err(|e| anyhow!("Failed to run passes: {:?}", e)) + } + + fn post_autodiff_optimisation(&mut self) -> Result<()> { + Ok(()) + } +} struct Globals<'ctx> { indices: Option<GlobalValue<'ctx>>, @@ -120,7 +270,6 @@ impl<'ctx> Globals<'ctx> { .map(|&i| int_type.const_int(i.try_into().unwrap(), false)) .collect::<Vec<_>>(); let indices_value = int_type.const_array(indices_array_values.as_slice()); - let _int_ptr_type = int_type.ptr_type(AddressSpace::default()); let globals = Self { indices: Some(module.add_global( indices_array_type, @@ -148,37 +297,49 @@ pub struct CodeGen<'ctx> { fn_value_opt: Option<FunctionValue<'ctx>>, tensor_ptr_opt: Option<PointerValue<'ctx>>, real_type: FloatType<'ctx>, + real_ptr_type: PointerType<'ctx>, real_type_str: String, int_type: IntType<'ctx>, + int_ptr_type: PointerType<'ctx>, layout: DataLayout, globals: Globals<'ctx>, + ee: ExecutionEngine<'ctx>, } impl<'ctx> CodeGen<'ctx> { pub fn new( model: &DiscreteModel, context: &'ctx inkwell::context::Context, - module: Module<'ctx>, real_type: FloatType<'ctx>, + int_type: IntType<'ctx>, real_type_str: &str, - ) -> Self { + ) -> Result<Self> { let builder = context.create_builder(); let layout = DataLayout::new(model); + let module = context.create_module(model.name()); let globals = Globals::new(&layout, context, &module); - Self { + let ee = module + .create_jit_execution_engine(OptimizationLevel::Aggressive) + .map_err(|e| anyhow::anyhow!("Error creating execution engine: {:?}", e))?; + let real_ptr_type = Self::pointer_type(context, real_type.into()); + let int_ptr_type = Self::pointer_type(context, int_type.into()); + Ok(Self { + context, module, builder, real_type, + real_ptr_type, real_type_str: real_type_str.to_owned(), variables: HashMap::new(), functions: HashMap::new(), fn_value_opt: None, tensor_ptr_opt: None, layout, - int_type: context.i32_type(), + int_type, + int_ptr_type, globals, - } + ee, + }) } pub fn write_bitcode_to_path(&self, path: &std::path::Path) { @@ -205,6 +366,26 @@ impl<'ctx> CodeGen<'ctx> { self.insert_tensor(model.rhs()); } + #[llvm_versions(4.0..=14.0)] + fn pointer_type(_context: &'ctx Context, ty: BasicTypeEnum<'ctx>) -> PointerType<'ctx> { + ty.ptr_type(AddressSpace::default()) + } + + #[llvm_versions(15.0..=latest)] + fn pointer_type(context: &'ctx Context, _ty: BasicTypeEnum<'ctx>) -> PointerType<'ctx> { + context.ptr_type(AddressSpace::default()) + } + + #[llvm_versions(4.0..=14.0)] + fn fn_pointer_type(_context: &'ctx Context, ty: FunctionType<'ctx>) -> PointerType<'ctx> { + ty.ptr_type(AddressSpace::default()) + } + + #[llvm_versions(15.0..=latest)] + fn fn_pointer_type(context: &'ctx Context, _ty: FunctionType<'ctx>) -> PointerType<'ctx> { + context.ptr_type(AddressSpace::default()) + }
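+ // note: these helpers absorb the opaque-pointer change in LLVM 15+, where pointers + // are no longer typed and inkwell exposes `Context::ptr_type` instead of + // `ty.ptr_type(...)`; the `#[llvm_versions]` attribute selects the right variant.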
+ #[llvm_versions(4.0..=14.0)] fn insert_indices(&mut self) { if let Some(indices) = self.globals.indices.as_ref() { @@ -1395,7 +1576,7 @@ impl<'ctx> CodeGen<'ctx> { self.jit_compile_expr(name, &arg.expression, index, elmt, expr_index) } AstKind::Number(value) => Ok(self.real_type.const_float(*value)), - AstKind::IndexedName(iname) => { + AstKind::Name(iname) => { let ptr = self.get_param(iname.name); let layout = self.layout.get_layout(iname.name).unwrap(); let iname_elmt_index = if layout.is_dense() { @@ -1451,13 +1632,6 @@ impl<'ctx> CodeGen<'ctx> { .build_load(self.real_type, value_ptr, name)? .into_float_value()) } - AstKind::Name(name) => { - // must be a scalar, just load the value - let ptr = self.get_param(name); - Ok(self - .build_load(self.real_type, *ptr, name)? - .into_float_value()) - } AstKind::NamedGradient(name) => { let name_str = name.to_string(); let ptr = self.get_param(name_str.as_str()); @@ -1496,10 +1670,12 @@ impl<'ctx> CodeGen<'ctx> { pub fn compile_set_u0<'m>(&mut self, model: &'m DiscreteModel) -> Result<FunctionValue<'ctx>> { self.clear(); - let real_ptr_type = self.real_type.ptr_type(AddressSpace::default()); let void_type = self.context.void_type(); - let fn_type = void_type.fn_type(&[real_ptr_type.into(), real_ptr_type.into()], false); - let fn_arg_names = &["data", "u0"]; + let fn_type = void_type.fn_type( + &[self.real_ptr_type.into(), self.real_ptr_type.into()], + false, + ); + let fn_arg_names = &["u0", "data"]; let function = self.module.add_function("set_u0", fn_type, None); // add noalias @@ -1546,13 +1722,12 @@ impl<'ctx> CodeGen<'ctx> { model: &'m DiscreteModel, ) -> Result<FunctionValue<'ctx>> { self.clear(); - let real_ptr_type = self.real_type.ptr_type(AddressSpace::default()); let void_type = self.context.void_type(); let fn_type = void_type.fn_type( &[ self.real_type.into(), - real_ptr_type.into(), - real_ptr_type.into(), + self.real_ptr_type.into(), + self.real_ptr_type.into(), ], false, ); @@ -1601,14 +1776,13 @@ impl<'ctx> CodeGen<'ctx> { model: &'m DiscreteModel, ) -> Result<FunctionValue<'ctx>> { self.clear(); - let real_ptr_type = self.real_type.ptr_type(AddressSpace::default()); let void_type = self.context.void_type(); let fn_type = void_type.fn_type( + &[ + self.real_type.into(), - real_ptr_type.into(), - real_ptr_type.into(), - real_ptr_type.into(), + self.real_ptr_type.into(), + self.real_ptr_type.into(), + self.real_ptr_type.into(), ], false, ); @@ -1655,14 +1829,13 @@ impl<'ctx> CodeGen<'ctx> { pub fn compile_rhs<'m>(&mut self, model: &'m DiscreteModel) -> Result<FunctionValue<'ctx>> { self.clear(); - let real_ptr_type = self.real_type.ptr_type(AddressSpace::default()); let void_type = self.context.void_type(); let fn_type = void_type.fn_type( &[ self.real_type.into(), - real_ptr_type.into(), - real_ptr_type.into(), - real_ptr_type.into(), + self.real_ptr_type.into(), + self.real_ptr_type.into(), + self.real_ptr_type.into(), ], false, ); @@ -1719,14 +1892,13 @@ impl<'ctx> CodeGen<'ctx> { pub fn compile_mass<'m>(&mut self, model: &'m DiscreteModel) -> Result<FunctionValue<'ctx>> { self.clear(); - let real_ptr_type = self.real_type.ptr_type(AddressSpace::default()); let void_type = self.context.void_type(); let fn_type = void_type.fn_type( &[ self.real_type.into(), - real_ptr_type.into(), - real_ptr_type.into(), - real_ptr_type.into(), + self.real_ptr_type.into(), + self.real_ptr_type.into(), + self.real_ptr_type.into(), ], false, ); @@ -1795,9 +1967,9 @@ impl<'ctx> CodeGen<'ctx> { // construct the gradient function let mut fn_type: Vec<BasicMetadataTypeEnum> = Vec::new(); - let orig_fn_type_ptr = original_function - .get_type() - .ptr_type(AddressSpace::default()); + + let orig_fn_type_ptr = Self::fn_pointer_type(self.context, original_function.get_type()); + + let mut enzyme_fn_type: Vec<BasicMetadataTypeEnum> = vec![orig_fn_type_ptr.into()]; + let mut start_param_index: Vec<u32> = Vec::new(); + let mut ptr_arg_indices: Vec<u32> = Vec::new();
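+ // roughly: `fn_type` gathers the parameter types of the gradient wrapper being + // built, `enzyme_fn_type` the argument types for the `__enzyme_autodiff` call + // (starting with a pointer to the primal function), and the index vectors track + // where each original parameter starts and which parameters are pointers.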
@@ -1995,14 +2167,13 @@ impl<'ctx> CodeGen<'ctx> { pub fn compile_get_dims(&mut self, model: &DiscreteModel) -> Result<FunctionValue<'ctx>> { self.clear(); - let int_ptr_type = self.context.i32_type().ptr_type(AddressSpace::default()); let fn_type = self.context.void_type().fn_type( &[ - int_ptr_type.into(), - int_ptr_type.into(), - int_ptr_type.into(), - int_ptr_type.into(), - int_ptr_type.into(), + self.int_ptr_type.into(), + self.int_ptr_type.into(), + self.int_ptr_type.into(), + self.int_ptr_type.into(), + self.int_ptr_type.into(), ], false, ); @@ -2068,17 +2239,12 @@ impl<'ctx> CodeGen<'ctx> { name: &str, ) -> Result<FunctionValue<'ctx>> { self.clear(); - let real_ptr_ptr_type = self - .real_type - .ptr_type(AddressSpace::default()) - .ptr_type(AddressSpace::default()); - let real_ptr_type = self.real_type.ptr_type(AddressSpace::default()); - let int_ptr_type = self.context.i32_type().ptr_type(AddressSpace::default()); + let real_ptr_ptr_type = Self::pointer_type(self.context, self.real_ptr_type.into()); let fn_type = self.context.void_type().fn_type( &[ - real_ptr_type.into(), + self.real_ptr_type.into(), real_ptr_ptr_type.into(), - int_ptr_type.into(), + self.int_ptr_type.into(), ], false, ); @@ -2121,9 +2287,11 @@ impl<'ctx> CodeGen<'ctx> { pub fn compile_set_inputs(&mut self, model: &DiscreteModel) -> Result<FunctionValue<'ctx>> { self.clear(); - let real_ptr_type = self.real_type.ptr_type(AddressSpace::default()); let void_type = self.context.void_type(); - let fn_type = void_type.fn_type(&[real_ptr_type.into(), real_ptr_type.into()], false); + let fn_type = void_type.fn_type( + &[self.real_ptr_type.into(), self.real_ptr_type.into()], + false, + ); let function = self.module.add_function("set_inputs", fn_type, None); let mut block = self.context.append_basic_block(function, "entry"); self.fn_value_opt = Some(function); @@ -2217,9 +2385,8 @@ impl<'ctx> CodeGen<'ctx> { pub fn compile_set_id(&mut self, model: &DiscreteModel) -> Result<FunctionValue<'ctx>> { self.clear(); - let real_ptr_type = self.real_type.ptr_type(AddressSpace::default()); let void_type = self.context.void_type(); - let fn_type = void_type.fn_type(&[real_ptr_type.into()], false); + let fn_type = void_type.fn_type(&[self.real_ptr_type.into()], false); let function = self.module.add_function("set_id", fn_type, None); let mut block = self.context.append_basic_block(function, "entry"); diff --git a/src/execution/llvm/compiler.rs b/src/execution/llvm/compiler.rs new file mode 100644 index 0000000..ac4454a --- /dev/null +++ b/src/execution/llvm/compiler.rs @@ -0,0 +1,1350 @@ +use anyhow::anyhow; +use inkwell::{ + passes::PassBuilderOptions, + targets::{CodeModel, InitializationConfig, RelocMode, Target, TargetMachine}, +}; +use std::env; +use std::path::Path; +use uid::Id; + +use crate::discretise::DiscreteModel; +use crate::parser::parse_ds_string; +use crate::utils::find_executable; +use crate::utils::find_runtime_path; +use anyhow::Result; +use inkwell::{ + context::Context, + execution_engine::{ExecutionEngine, JitFunction, UnsafeFunctionPointer}, + targets::{FileType, TargetTriple}, + OptimizationLevel, +}; +use ouroboros::self_referencing; +use std::process::Command; + +use super::codegen::CompileGradientArgType; +use super::codegen::GetDimsFunc; +use super::codegen::GetOutFunc; +use super::codegen::SetIdFunc; +use super::codegen::SetInputsFunc; +use super::codegen::SetInputsGradientFunc; +use super::codegen::U0GradientFunc; +use super::codegen::{CalcOutGradientFunc, MassFunc, RhsFunc, RhsGradientFunc}; +use super::{ + codegen::{CalcOutFunc, StopFunc, U0Func, CodeGen}, + super::data_layout::DataLayout, +}; + +struct JitFunctions<'ctx> { + set_u0: JitFunction<'ctx, U0Func>, + rhs: JitFunction<'ctx, RhsFunc>, + mass: JitFunction<'ctx, MassFunc>, + calc_out: JitFunction<'ctx, CalcOutFunc>, + calc_stop: JitFunction<'ctx, StopFunc>, + set_id: JitFunction<'ctx, SetIdFunc>, + get_dims: JitFunction<'ctx, GetDimsFunc>, + set_inputs: JitFunction<'ctx, SetInputsFunc>, + get_out: JitFunction<'ctx, GetOutFunc>, +} + +struct JitGradFunctions<'ctx> { + set_u0_grad: JitFunction<'ctx, U0GradientFunc>, + rhs_grad: JitFunction<'ctx, RhsGradientFunc>, + calc_out_grad: JitFunction<'ctx, CalcOutGradientFunc>, + set_inputs_grad: JitFunction<'ctx, SetInputsGradientFunc>, +} + +struct CompilerData<'ctx> { + codegen: CodeGen<'ctx>, + jit_functions: JitFunctions<'ctx>, + jit_grad_functions: JitGradFunctions<'ctx>, +} +
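+// note: ouroboros generates the same self-referential shape that `ImmovableLlvmModule` +// builds by hand in codegen.rs: `data` borrows from the `context` field owned by the +// same struct.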
+#[self_referencing] +pub struct LlvmCompiler { + context: Context, + + #[borrows(context)] + #[not_covariant] + data: CompilerData<'this>, + + number_of_states: usize, + number_of_parameters: usize, + number_of_outputs: usize, + has_mass: bool, + data_layout: DataLayout, + output_base_filename: String, +} + +impl LlvmCompiler { + const OPT_VARIANTS: [&'static str; 2] = ["opt-14", "opt"]; + const CLANG_VARIANTS: [&'static str; 2] = ["clang", "clang-14"]; + fn find_opt() -> Result<&'static str> { + find_executable(&LlvmCompiler::OPT_VARIANTS) + } + fn find_clang() -> Result<&'static str> { + find_executable(&LlvmCompiler::CLANG_VARIANTS) + } + /// search for the enzyme library in the environment variables + fn find_enzyme_lib() -> Result<String> { + let env_vars = ["LD_LIBRARY_PATH", "DYLD_LIBRARY_PATH", "PATH"]; + for var in env_vars.iter() { + if let Ok(val) = env::var(var) { + for path in val.split(':') { + // check that LLVMEnzyme*.so exists in this directory + if let Ok(entries) = std::fs::read_dir(path) { + for entry in entries.flatten() { + if let Some(filename) = entry.file_name().to_str() { + if filename.starts_with("LLVMEnzyme") && filename.ends_with(".so") { + return Ok(entry.path().to_str().unwrap().to_owned()); + } + } + } + } + } + } + } + Err(anyhow!( + "LLVMEnzyme*.so not found in any of: {:?}", + env_vars + )) + } + pub fn from_discrete_str(code: &str) -> Result<Self> { + let uid = Id::::new(); + let name = format!("diffsl_{}", uid); + let model = parse_ds_string(code).unwrap(); + let model = DiscreteModel::build(name.as_str(), &model) + .unwrap_or_else(|e| panic!("{}", e.as_error_message(code))); + let dir = env::temp_dir(); + let path = dir.join(name.clone()); + LlvmCompiler::from_discrete_model(&model, path.to_str().unwrap()) + } + + pub fn from_discrete_model(model: &DiscreteModel, out: &str) -> Result<Self> { + let number_of_states = *model.state().shape().first().unwrap_or(&1); + let input_names = model + .inputs() + .iter() + .map(|input| input.name().to_owned()) + .collect::<Vec<_>>(); + let data_layout = DataLayout::new(model); + let context = Context::create(); + let number_of_parameters = input_names.iter().fold(0, |acc, name| { + acc + data_layout.get_data_length(name).unwrap() + }); + let number_of_outputs = data_layout.get_data_length("out").unwrap(); + let has_mass = model.lhs().is_some(); + LlvmCompilerTryBuilder { + data_layout, + number_of_states, + number_of_parameters, + number_of_outputs, + context, + has_mass, + output_base_filename: 
out.to_owned(), + data_builder: |context| { + let module = context.create_module(model.name()); + let real_type = context.f64_type(); + let real_type_str = "f64"; + + let mut codegen = CodeGen::new(model, context, module, real_type, real_type_str); + + let _set_u0 = codegen.compile_set_u0(model)?; + let _calc_stop = codegen.compile_calc_stop(model)?; + let _rhs = codegen.compile_rhs(model)?; + let _mass = codegen.compile_mass(model)?; + let _calc_out = codegen.compile_calc_out(model)?; + let _set_id = codegen.compile_set_id(model)?; + let _get_dims = codegen.compile_get_dims(model)?; + let _set_inputs = codegen.compile_set_inputs(model)?; + let _get_output = codegen.compile_get_tensor(model, "out")?; + + // optimise at -O2 no unrolling before giving to enzyme + let pass_options = PassBuilderOptions::create(); + //pass_options.set_verify_each(true); + //pass_options.set_debug_logging(true); + //pass_options.set_loop_interleaving(true); + pass_options.set_loop_vectorization(false); + pass_options.set_loop_slp_vectorization(false); + pass_options.set_loop_unrolling(false); + //pass_options.set_forget_all_scev_in_loop_unroll(true); + //pass_options.set_licm_mssa_opt_cap(1); + //pass_options.set_licm_mssa_no_acc_for_promotion_cap(10); + //pass_options.set_call_graph_profile(true); + //pass_options.set_merge_functions(true); + + let initialization_config = &InitializationConfig::default(); + Target::initialize_all(initialization_config); + let triple = TargetMachine::get_default_triple(); + let target = Target::from_triple(&triple).unwrap(); + let machine = target + .create_target_machine( + &triple, + "generic", //TargetMachine::get_host_cpu_name().to_string().as_str(), + "", //TargetMachine::get_host_cpu_features().to_string().as_str(), + inkwell::OptimizationLevel::Default, + inkwell::targets::RelocMode::Default, + inkwell::targets::CodeModel::Default, + ) + .unwrap(); + + codegen + .module() + .run_passes("default", &machine, pass_options) + .unwrap(); + + let _rhs_grad = codegen.compile_gradient( + _rhs, + &[ + CompileGradientArgType::Const, + CompileGradientArgType::Dup, + CompileGradientArgType::Dup, + CompileGradientArgType::DupNoNeed, + ], + )?; + let _set_inputs_grad = codegen.compile_gradient( + _set_inputs, + &[CompileGradientArgType::Dup, CompileGradientArgType::Dup], + )?; + let _calc_out_grad = codegen.compile_gradient( + _calc_out, + &[ + CompileGradientArgType::Const, + CompileGradientArgType::Dup, + CompileGradientArgType::Dup, + ], + )?; + let _set_u0_grad = codegen.compile_gradient( + _set_u0, + &[CompileGradientArgType::Dup, CompileGradientArgType::Dup], + )?; + + let ee = codegen + .module() + .create_jit_execution_engine(OptimizationLevel::Aggressive) + .map_err(|e| anyhow::anyhow!("Error creating execution engine: {:?}", e))?; + + let set_u0 = LlvmCompiler::jit("set_u0", &ee)?; + let rhs = LlvmCompiler::jit("rhs", &ee)?; + let mass = LlvmCompiler::jit("mass", &ee)?; + let calc_stop = LlvmCompiler::jit("calc_stop", &ee)?; + let calc_out = LlvmCompiler::jit("calc_out", &ee)?; + let set_id = LlvmCompiler::jit("set_id", &ee)?; + let get_dims = LlvmCompiler::jit("get_dims", &ee)?; + let set_inputs = LlvmCompiler::jit("set_inputs", &ee)?; + let get_out = LlvmCompiler::jit("get_out", &ee)?; + + let set_inputs_grad = LlvmCompiler::jit("set_inputs_grad", &ee)?; + let calc_out_grad = LlvmCompiler::jit("calc_out_grad", &ee)?; + let rhs_grad = LlvmCompiler::jit("rhs_grad", &ee)?; + let set_u0_grad = LlvmCompiler::jit("set_u0_grad", &ee)?; + + let data = CompilerData { + codegen, + 
jit_functions: JitFunctions { + set_u0, + rhs, + mass, + calc_out, + set_id, + get_dims, + set_inputs, + get_out, + calc_stop, + }, + jit_grad_functions: JitGradFunctions { + set_u0_grad, + rhs_grad, + calc_out_grad, + set_inputs_grad, + }, + }; + Ok(data) + }, + } + .try_build() + } + + pub fn compile(&self, standalone: bool, wasm: bool) -> Result<()> { + let opt_name = LlvmCompiler::find_opt()?; + let clang_name = LlvmCompiler::find_clang()?; + let enzyme_lib = LlvmCompiler::find_enzyme_lib()?; + let out = self.borrow_output_base_filename(); + let object_filename = LlvmCompiler::get_object_filename(out); + let bitcodefilename = LlvmCompiler::get_bitcode_filename(out); + let mut command = Command::new(clang_name); + command + .arg(bitcodefilename.as_str()) + .arg("-c") + .arg(format!("-fplugin={}", enzyme_lib)) + .arg("-o") + .arg(object_filename.as_str()); + + if wasm { + command.arg("-target").arg("wasm32-unknown-emscripten"); + } + + let output = command.output().unwrap(); + + if let Some(code) = output.status.code() { + if code != 0 { + println!("{}", String::from_utf8_lossy(&output.stderr)); + return Err(anyhow!("{} returned error code {}", clang_name, code)); + } + } + + // link the object file and our runtime library + let mut command = if wasm { + let emcc_variants = ["emcc"]; + let command_name = find_executable(&emcc_variants)?; + let exported_functions = vec![ + "Vector_destroy", + "Vector_create", + "Vector_create_with_capacity", + "Vector_push", + "Options_destroy", + "Options_create", + "Sundials_destroy", + "Sundials_create", + "Sundials_init", + "Sundials_solve", + ]; + let mut linked_files = vec![ + "libdiffeq_runtime_lib.a", + "libsundials_idas.a", + "libsundials_sunlinsolklu.a", + "libklu.a", + "libamd.a", + "libcolamd.a", + "libbtf.a", + "libsuitesparseconfig.a", + "libsundials_sunmatrixsparse.a", + "libargparse.a", + ]; + if standalone { + linked_files.push("libdiffeq_runtime_wasm.a"); + } + let linked_files = linked_files; + let runtime_path = find_runtime_path(&linked_files)?; + let mut command = Command::new(command_name); + command.arg("-o").arg(out).arg(object_filename.as_str()); + for file in linked_files { + command.arg(Path::new(runtime_path.as_str()).join(file)); + } + if !standalone { + // emscripten expects exported C symbols prefixed with '_' in EXPORTED_FUNCTIONS + let exported_functions = exported_functions + .into_iter() + .map(|s| format!("_{}", s)) + .collect::<Vec<_>>() + .join(","); + command + .arg("-s") + .arg(format!("EXPORTED_FUNCTIONS={}", exported_functions)); + command.arg("--no-entry"); + } + command + } else { + let mut command = Command::new(clang_name); + command.arg("-o").arg(out).arg(object_filename.as_str()); + if standalone { + command.arg("-ldiffeq_runtime"); + } else { + command.arg("-ldiffeq_runtime_lib"); + } + command + }; + + let output = command.output(); + + let output = match output { + Ok(output) => output, + Err(e) => { + let args = command + .get_args() + .map(|s| s.to_str().unwrap()) + .collect::<Vec<_>>() + .join(" "); + println!( + "{} {}", + command.get_program().to_os_string().to_str().unwrap(), + args + ); + return Err(anyhow!("Error linking in runtime: {}", e)); + } + }; + + if let Some(code) = output.status.code() { + if code != 0 { + let args = command + .get_args() + .map(|s| s.to_str().unwrap()) + .collect::<Vec<_>>() + .join(" "); + println!( + "{} {}", + command.get_program().to_os_string().to_str().unwrap(), + args + ); + println!("{}", String::from_utf8_lossy(&output.stderr)); + return Err(anyhow!( + "Error linking in runtime, returned error code {}", + code + )); + } + } + Ok(()) + }
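+ + // pipeline sketch: compile() takes the bitcode written next to `out`, runs clang + // with the Enzyme plugin to produce an object file, then links it against the + // diffeq runtime (or through emcc for wasm targets).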
+ fn get_bitcode_filename(out: &str) -> String { + format!("{}.bc", out) + } + + fn get_object_filename(out: &str) -> String { + format!("{}.o", out) + } + + fn jit<'ctx, T>(name: &str, ee: &ExecutionEngine<'ctx>) -> Result<JitFunction<'ctx, T>> + where + T: UnsafeFunctionPointer, + { + let maybe_fn = unsafe { ee.get_function::<T>(name) }; + match maybe_fn { + Ok(f) => Ok(f), + Err(err) => Err(anyhow!("Error during jit for {}: {}", name, err)), + } + } + + pub fn get_tensor_data<'a>(&self, name: &str, data: &'a [f64]) -> Option<&'a [f64]> { + let index = self.borrow_data_layout().get_data_index(name)?; + let nnz = self.borrow_data_layout().get_data_length(name)?; + Some(&data[index..index + nnz]) + } + + pub fn set_u0(&self, yy: &mut [f64], data: &mut [f64]) { + let number_of_states = *self.borrow_number_of_states(); + if yy.len() != number_of_states { + panic!("Expected {} states, got {}", number_of_states, yy.len()); + } + self.with_data(|compiler| { + let yy_ptr = yy.as_mut_ptr(); + let data_ptr = data.as_mut_ptr(); + unsafe { + compiler.jit_functions.set_u0.call(data_ptr, yy_ptr); + } + }); + } + + pub fn set_u0_grad( + &self, + yy: &mut [f64], + dyy: &mut [f64], + data: &mut [f64], + ddata: &mut [f64], + ) { + let number_of_states = *self.borrow_number_of_states(); + if yy.len() != number_of_states { + panic!("Expected {} states, got {}", number_of_states, yy.len()); + } + if dyy.len() != number_of_states { + panic!( + "Expected {} states for dyy, got {}", + number_of_states, + dyy.len() + ); + } + if data.len() != self.data_len() { + panic!("Expected {} data, got {}", self.data_len(), data.len()); + } + if ddata.len() != self.data_len() { + panic!( + "Expected {} data for ddata, got {}", + self.data_len(), + ddata.len() + ); + } + self.with_data(|compiler| { + let yy_ptr = yy.as_mut_ptr(); + let data_ptr = data.as_mut_ptr(); + let dyy_ptr = dyy.as_mut_ptr(); + let ddata_ptr = ddata.as_mut_ptr(); + unsafe { + compiler + .jit_grad_functions + .set_u0_grad + .call(data_ptr, ddata_ptr, yy_ptr, dyy_ptr); + } + }); + } + + pub fn calc_stop(&self, t: f64, yy: &[f64], data: &mut [f64], stop: &mut [f64]) { + let (n_states, _, _, n_data, n_stop) = self.get_dims(); + if yy.len() != n_states { + panic!("Expected {} states, got {}", n_states, yy.len()); + } + if data.len() != n_data { + panic!("Expected {} data, got {}", n_data, data.len()); + } + if stop.len() != n_stop { + panic!("Expected {} stop, got {}", n_stop, stop.len()); + } + self.with_data(|compiler| { + let yy_ptr = yy.as_ptr(); + let data_ptr = data.as_mut_ptr(); + let stop_ptr = stop.as_mut_ptr(); + unsafe { + compiler + .jit_functions + .calc_stop + .call(t, yy_ptr, data_ptr, stop_ptr); + } + }); + } + + pub fn rhs(&self, t: f64, yy: &[f64], data: &mut [f64], rr: &mut [f64]) { + let number_of_states = *self.borrow_number_of_states(); + if yy.len() != number_of_states { + panic!("Expected {} states, got {}", number_of_states, yy.len()); + } + if rr.len() != number_of_states { + panic!( + "Expected {} residual states, got {}", + number_of_states, + rr.len() + ); + } + if data.len() != self.data_len() { + panic!("Expected {} data, got {}", self.data_len(), data.len()); + } + self.with_data(|compiler| { + let yy_ptr = yy.as_ptr(); + let rr_ptr = rr.as_mut_ptr(); + let data_ptr = data.as_mut_ptr(); + unsafe { + compiler.jit_functions.rhs.call(t, yy_ptr, data_ptr, rr_ptr); + } + }); + } + + pub fn has_mass(&self) -> bool { + *self.borrow_has_mass() + } + + pub fn mass(&self, t: f64, yp: &[f64], data: &mut [f64], rr: &mut [f64]) { + if !self.borrow_has_mass() { 
panic!("Model does not have a mass function"); + } + let number_of_states = *self.borrow_number_of_states(); + if yp.len() != number_of_states { + panic!("Expected {} states, got {}", number_of_states, yp.len()); + } + if rr.len() != number_of_states { + panic!( + "Expected {} residual states, got {}", + number_of_states, + rr.len() + ); + } + if data.len() != self.data_len() { + panic!("Expected {} data, got {}", self.data_len(), data.len()); + } + self.with_data(|compiler| { + let yp_ptr = yp.as_ptr(); + let rr_ptr = rr.as_mut_ptr(); + let data_ptr = data.as_mut_ptr(); + unsafe { + compiler + .jit_functions + .mass + .call(t, yp_ptr, data_ptr, rr_ptr); + } + }); + } + + pub fn data_len(&self) -> usize { + self.with(|compiler| compiler.data_layout.data().len()) + } + + pub fn get_new_data(&self) -> Vec { + vec![0.; self.data_len()] + } + + #[allow(clippy::too_many_arguments)] + pub fn rhs_grad( + &self, + t: f64, + yy: &[f64], + dyy: &[f64], + data: &mut [f64], + ddata: &mut [f64], + rr: &mut [f64], + drr: &mut [f64], + ) { + let number_of_states = *self.borrow_number_of_states(); + if yy.len() != number_of_states { + panic!("Expected {} states, got {}", number_of_states, yy.len()); + } + if rr.len() != number_of_states { + panic!( + "Expected {} residual states, got {}", + number_of_states, + rr.len() + ); + } + if dyy.len() != number_of_states { + panic!( + "Expected {} states for dyy, got {}", + number_of_states, + dyy.len() + ); + } + if drr.len() != number_of_states { + panic!( + "Expected {} residual states for drr, got {}", + number_of_states, + drr.len() + ); + } + if data.len() != self.data_len() { + panic!("Expected {} data, got {}", self.data_len(), data.len()); + } + if ddata.len() != self.data_len() { + panic!( + "Expected {} data for ddata, got {}", + self.data_len(), + ddata.len() + ); + } + self.with_data(|compiler| { + let yy_ptr = yy.as_ptr(); + let rr_ptr = rr.as_mut_ptr(); + let dyy_ptr = dyy.as_ptr(); + let drr_ptr = drr.as_mut_ptr(); + let data_ptr = data.as_mut_ptr(); + let ddata_ptr = ddata.as_mut_ptr(); + unsafe { + compiler + .jit_grad_functions + .rhs_grad + .call(t, yy_ptr, dyy_ptr, data_ptr, ddata_ptr, rr_ptr, drr_ptr); + } + }); + } + + pub fn calc_out(&self, t: f64, yy: &[f64], data: &mut [f64]) { + let number_of_states = *self.borrow_number_of_states(); + if yy.len() != *self.borrow_number_of_states() { + panic!("Expected {} states, got {}", number_of_states, yy.len()); + } + if data.len() != self.data_len() { + panic!("Expected {} data, got {}", self.data_len(), data.len()); + } + self.with_data(|compiler| { + let yy_ptr = yy.as_ptr(); + let data_ptr = data.as_mut_ptr(); + unsafe { + compiler.jit_functions.calc_out.call(t, yy_ptr, data_ptr); + } + }); + } + + pub fn calc_out_grad( + &self, + t: f64, + yy: &[f64], + dyy: &[f64], + data: &mut [f64], + ddata: &mut [f64], + ) { + let number_of_states = *self.borrow_number_of_states(); + if yy.len() != *self.borrow_number_of_states() { + panic!("Expected {} states, got {}", number_of_states, yy.len()); + } + if data.len() != self.data_len() { + panic!("Expected {} data, got {}", self.data_len(), data.len()); + } + if dyy.len() != *self.borrow_number_of_states() { + panic!( + "Expected {} states for dyy, got {}", + number_of_states, + dyy.len() + ); + } + if ddata.len() != self.data_len() { + panic!( + "Expected {} data for ddata, got {}", + self.data_len(), + ddata.len() + ); + } + self.with_data(|compiler| { + let yy_ptr = yy.as_ptr(); + let data_ptr = data.as_mut_ptr(); + let dyy_ptr = dyy.as_ptr(); + let 
ddata_ptr = ddata.as_mut_ptr(); + unsafe { + compiler + .jit_grad_functions + .calc_out_grad + .call(t, yy_ptr, dyy_ptr, data_ptr, ddata_ptr); + } + }); + } + + /// Get various dimensions of the model + /// + /// # Returns + /// + /// A tuple of the form `(n_states, n_inputs, n_outputs, n_data, n_stop)` + pub fn get_dims(&self) -> (usize, usize, usize, usize, usize) { + let mut n_states = 0u32; + let mut n_inputs = 0u32; + let mut n_outputs = 0u32; + let mut n_data = 0u32; + let mut n_stop = 0u32; + self.with(|compiler| unsafe { + compiler.data.jit_functions.get_dims.call( + &mut n_states, + &mut n_inputs, + &mut n_outputs, + &mut n_data, + &mut n_stop, + ); + }); + ( + n_states as usize, + n_inputs as usize, + n_outputs as usize, + n_data as usize, + n_stop as usize, + ) + } + + pub fn set_inputs(&self, inputs: &[f64], data: &mut [f64]) { + let (_, n_inputs, _, _, _) = self.get_dims(); + if n_inputs != inputs.len() { + panic!("Expected {} inputs, got {}", n_inputs, inputs.len()); + } + if data.len() != self.data_len() { + panic!("Expected {} data, got {}", self.data_len(), data.len()); + } + self.with_data(|compiler| { + let data_ptr = data.as_mut_ptr(); + unsafe { + compiler + .jit_functions + .set_inputs + .call(inputs.as_ptr(), data_ptr); + } + }); + } + + pub fn set_inputs_grad( + &self, + inputs: &[f64], + dinputs: &[f64], + data: &mut [f64], + ddata: &mut [f64], + ) { + let (_, n_inputs, _, _, _) = self.get_dims(); + if n_inputs != inputs.len() { + panic!("Expected {} inputs, got {}", n_inputs, inputs.len()); + } + if data.len() != self.data_len() { + panic!("Expected {} data, got {}", self.data_len(), data.len()); + } + if dinputs.len() != n_inputs { + panic!( + "Expected {} inputs for dinputs, got {}", + n_inputs, + dinputs.len() + ); + } + if ddata.len() != self.data_len() { + panic!( + "Expected {} data for ddata, got {}", + self.data_len(), + ddata.len() + ); + } + self.with_data(|compiler| { + let data_ptr = data.as_mut_ptr(); + let ddata_ptr = ddata.as_mut_ptr(); + let dinputs_ptr = dinputs.as_ptr(); + unsafe { + compiler.jit_grad_functions.set_inputs_grad.call( + inputs.as_ptr(), + dinputs_ptr, + data_ptr, + ddata_ptr, + ); + } + }); + } + + pub fn get_out(&self, data: &[f64]) -> &[f64] { + if data.len() != self.data_len() { + panic!("Expected {} data, got {}", self.data_len(), data.len()); + } + let (_, _, n_outputs, _, _) = self.get_dims(); + let mut tensor_data_ptr: *mut f64 = std::ptr::null_mut(); + let mut tensor_data_len = 0u32; + let tensor_data_ptr_ptr: *mut *mut f64 = &mut tensor_data_ptr; + let tensor_data_len_ptr: *mut u32 = &mut tensor_data_len; + self.with(|compiler| { + let data_ptr = data.as_ptr(); + unsafe { + compiler.data.jit_functions.get_out.call( + data_ptr, + tensor_data_ptr_ptr, + tensor_data_len_ptr, + ); + } + }); + assert!(tensor_data_len as usize == n_outputs); + unsafe { std::slice::from_raw_parts(tensor_data_ptr, tensor_data_len as usize) } + } + + pub fn set_id(&self, id: &mut [f64]) { + let (n_states, _, _, _, _) = self.get_dims(); + if n_states != id.len() { + panic!("Expected {} states, got {}", n_states, id.len()); + } + self.with_data(|compiler| { + unsafe { + compiler.jit_functions.set_id.call(id.as_mut_ptr()); + }; + }); + }
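+ + // note: unlike the pass pipeline in codegen, which targets the portable "generic" + // cpu, object-file output below uses the host CPU name and feature set.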
+ fn get_native_machine() -> Result<TargetMachine> { + Target::initialize_native(&InitializationConfig::default()) + .map_err(|e| anyhow!("{}", e))?; + let opt = OptimizationLevel::Default; + let reloc = RelocMode::Default; + let model = CodeModel::Default; + let target_triple = TargetMachine::get_default_triple(); + let target = Target::from_triple(&target_triple).unwrap(); + let target_machine = target + .create_target_machine( + &target_triple, + TargetMachine::get_host_cpu_name().to_str().unwrap(), + TargetMachine::get_host_cpu_features().to_str().unwrap(), + opt, + reloc, + model, + ) + .unwrap(); + Ok(target_machine) + } + + fn get_wasm_machine() -> Result<TargetMachine> { + Target::initialize_webassembly(&InitializationConfig::default()); + let opt = OptimizationLevel::Default; + let reloc = RelocMode::Default; + let model = CodeModel::Default; + let target_triple = TargetTriple::create("wasm32-unknown-emscripten"); + let target = Target::from_triple(&target_triple).unwrap(); + let target_machine = target + .create_target_machine(&target_triple, "generic", "", opt, reloc, model) + .unwrap(); + Ok(target_machine) + } + + pub fn write_bitcode_to_path(&self, path: &Path) -> Result<()> { + self.with_data(|data| { + let result = data.codegen.module().write_bitcode_to_path(path); + if result { + Ok(()) + } else { + Err(anyhow!("Error writing bitcode to path")) + } + }) + } + + pub fn write_object_file(&self, path: &Path) -> Result<()> { + let target_machine = LlvmCompiler::get_native_machine()?; + self.with_data(|data| { + target_machine + .write_to_file(data.codegen.module(), FileType::Object, path) + .map_err(|e| anyhow::anyhow!("Error writing object file: {:?}", e)) + }) + } + + pub fn write_wasm_object_file(&self, path: &Path) -> Result<()> { + let target_machine = LlvmCompiler::get_wasm_machine()?; + self.with_data(|data| { + target_machine + .write_to_file(data.codegen.module(), FileType::Object, path) + .map_err(|e| anyhow::anyhow!("Error writing object file: {:?}", e)) + }) + } + + pub fn number_of_states(&self) -> usize { + *self.borrow_number_of_states() + } + pub fn number_of_parameters(&self) -> usize { + *self.borrow_number_of_parameters() + } + + pub fn number_of_outputs(&self) -> usize { + *self.borrow_number_of_outputs() + } +} + +#[cfg(test)] +mod tests { + use crate::{ + continuous::ModelInfo, + parser::{parse_ds_string, parse_ms_string}, + }; + use approx::assert_relative_eq; + + use super::*; + + #[test] + fn test_object_file() { + let text = " + model logistic_growth(r -> NonNegative, k -> NonNegative, y(t), z(t)) { + dot(y) = r * y * (1 - y / k) + y(0) = 1.0 + z = 2 * y + } + "; + let models = parse_ms_string(text).unwrap(); + let model_info = ModelInfo::build("logistic_growth", &models).unwrap(); + assert_eq!(model_info.errors.len(), 0); + let discrete_model = DiscreteModel::from(&model_info); + let object = + LlvmCompiler::from_discrete_model(&discrete_model, "test_output/compiler_test_object_file") + .unwrap(); + let path = Path::new("main.o"); + object.write_object_file(path).unwrap(); + } + + #[test] + fn test_from_discrete_str() { + let text = " + u { y = 1 } + F { -y } + out { y } + "; + let compiler = LlvmCompiler::from_discrete_str(text).unwrap(); + let mut u0 = vec![0.]; + let mut res = vec![0.]; + let mut data = compiler.get_new_data(); + compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); + assert_relative_eq!(u0.as_slice(), vec![1.].as_slice()); + compiler.rhs(0., u0.as_slice(), data.as_mut_slice(), res.as_mut_slice()); + assert_relative_eq!(res.as_slice(), vec![-1.].as_slice()); + } + + #[test] + fn test_stop() { + let full_text = " + u_i { + y = 1, + } + dudt_i { + dydt = 0, + } + M_i { + dydt, + } + F_i { + y * (1 - y), + } + stop_i { + y - 0.5, + } + out { + y, + } + "; + let model = parse_ds_string(full_text).unwrap(); + let discrete_model = DiscreteModel::build("$name", &model).unwrap(); + let compiler = + LlvmCompiler::from_discrete_model(&discrete_model, "test_output/compiler_test_stop") + .unwrap(); + let mut u0 = vec![1.]; + let mut res = vec![0.]; + let mut stop = vec![0.]; + let mut data = compiler.get_new_data(); + compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); + compiler.rhs(0., u0.as_slice(), data.as_mut_slice(), res.as_mut_slice()); + compiler.calc_stop(0., u0.as_slice(), data.as_mut_slice(), stop.as_mut_slice()); + assert_relative_eq!(stop[0], 0.5); + assert_eq!(stop.len(), 1); + }
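+ + // test harness: compiles the model, records the primal value of `tensor_name` in + // results[0], then seeds each input in turn with a unit tangent and records the + // forward-mode gradients in results[1..].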
+ fn tensor_test_common(text: &str, tmp_loc: &str, tensor_name: &str) -> Vec<Vec<f64>> { + let full_text = format!( + " + {} + ", + text + ); + let model = parse_ds_string(full_text.as_str()).unwrap(); + let discrete_model = match DiscreteModel::build("$name", &model) { + Ok(model) => model, + Err(e) => { + panic!("{}", e.as_error_message(full_text.as_str())); + } + }; + let compiler = LlvmCompiler::from_discrete_model(&discrete_model, tmp_loc).unwrap(); + let mut u0 = vec![1.]; + let mut res = vec![0.]; + let mut data = compiler.get_new_data(); + let mut grad_data = Vec::new(); + let (_n_states, n_inputs, _n_outputs, _n_data, _n_stop) = compiler.get_dims(); + for _ in 0..n_inputs { + grad_data.push(compiler.get_new_data()); + } + let mut results = Vec::new(); + let inputs = vec![1.; n_inputs]; + compiler.set_inputs(inputs.as_slice(), data.as_mut_slice()); + compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); + compiler.rhs(0., u0.as_slice(), data.as_mut_slice(), res.as_mut_slice()); + compiler.calc_out(0., u0.as_slice(), data.as_mut_slice()); + results.push( + compiler + .get_tensor_data(tensor_name, data.as_slice()) + .unwrap() + .to_vec(), + ); + for i in 0..n_inputs { + let mut dinputs = vec![0.; n_inputs]; + dinputs[i] = 1.0; + let mut ddata = compiler.get_new_data(); + let mut du0 = vec![0.]; + let mut dres = vec![0.]; + compiler.set_inputs_grad( + inputs.as_slice(), + dinputs.as_slice(), + grad_data[i].as_mut_slice(), + ddata.as_mut_slice(), + ); + compiler.set_u0_grad( + u0.as_mut_slice(), + du0.as_mut_slice(), + grad_data[i].as_mut_slice(), + ddata.as_mut_slice(), + ); + compiler.rhs_grad( + 0., + u0.as_slice(), + du0.as_slice(), + grad_data[i].as_mut_slice(), + ddata.as_mut_slice(), + res.as_mut_slice(), + dres.as_mut_slice(), + ); + compiler.calc_out_grad( + 0., + u0.as_slice(), + du0.as_slice(), + grad_data[i].as_mut_slice(), + ddata.as_mut_slice(), + ); + results.push( + compiler + .get_tensor_data(tensor_name, ddata.as_slice()) + .unwrap() + .to_vec(), + ); + } + results + } + + macro_rules! tensor_test { + ($($name:ident: $text:literal expect $tensor_name:literal $expected_value:expr,)*) => { + $( + #[test] + fn $name() { + let full_text = format!(" + {} + u_i {{ + y = 1, + }} + F_i {{ + y, + }} + out_i {{ + y, + }} + ", $text); + let tmp_loc = format!("test_output/compiler_tensor_test_{}", stringify!($name)); + let results = tensor_test_common(full_text.as_str(), tmp_loc.as_str(), $tensor_name); + assert_relative_eq!(results[0].as_slice(), $expected_value.as_slice()); + } + )* + } + } + + tensor_test! 
{ + heaviside_function0: "r { heaviside(-0.1) }" expect "r" vec![0.0], + heaviside_function1: "r { heaviside(0.0) }" expect "r" vec![1.0], + exp_function: "r { exp(2) }" expect "r" vec![f64::exp(2.0)], + pow_function: "r { pow(4.3245, 0.5) }" expect "r" vec![f64::powf(4.3245, 0.5)], + arcsinh_function: "r { arcsinh(0.5) }" expect "r" vec![f64::asinh(0.5)], + arccosh_function: "r { arccosh(2) }" expect "r" vec![f64::acosh(2.0)], + tanh_function: "r { tanh(0.5) }" expect "r" vec![f64::tanh(0.5)], + sinh_function: "r { sinh(0.5) }" expect "r" vec![f64::sinh(0.5)], + cosh_function: "r { cosh(0.5) }" expect "r" vec![f64::cosh(0.5)], + exp_function_time: "r { exp(t) }" expect "r" vec![f64::exp(0.0)], + min_function: "r { min(2, 3) }" expect "r" vec![2.0], + max_function: "r { max(2, 3) }" expect "r" vec![3.0], + sigmoid_function: "r { sigmoid(0.1) }" expect "r" vec![1.0 / (1.0 + f64::exp(-0.1))], + scalar: "r {2}" expect "r" vec![2.0,], + constant: "r_i {2, 3}" expect "r" vec![2., 3.], + expression: "r_i {2 + 3, 3 * 2, arcsinh(1.2 + 1.0 / max(1.2, 1.0) * 2.0 + tanh(2.0))}" expect "r" vec![5., 6., f64::asinh(1.2 + 1.0 / f64::max(1.2, 1.0) * 2.0 + f64::tanh(2.0))], + pybamm_expression: " + constant0_i { (0:19): 0.0, (19:20): 0.0006810238128045524,} + constant1_i { (0:19): 0.0, (19:20): -0.0011634665332403958,} + constant2_ij { (0,18): -25608.96286546366, (0,19): 76826.88859639116,} + constant3_ij {(0,18): -0.4999999999999983, (0,19): 1.4999999999999984,} + constant4_ij {(0,18): -0.4999999999999983, (0,19): 1.4999999999999982,} + constant7_ij { (0,18): -12491.630996921805, (0,19): 37474.892990765504,} + xaveragednegativeparticleconcentrationmolm3_i { 0.245049, 0.244694, 0.243985, 0.242921, 0.241503, 0.239730, 0.237603, 0.235121, 0.232284, 0.229093, 0.225547, 0.221647, 0.217392, 0.212783, 0.207819, 0.202500, 0.196827, 0.190799, 0.184417, 0.177680, } + xaveragedpositiveparticleconcentrationmolm3_i { 0.939986, 0.940066, 0.940228, 0.940471, 0.940795, 0.941200, 0.941685, 0.942252, 0.942899, 0.943628, 0.944437, 0.945328, 0.946299, 0.947351, 0.948485, 0.949699, 0.950994, 0.952370, 0.953827, 0.955365, } + varying2_i {(constant2_ij * xaveragedpositiveparticleconcentrationmolm3_j),} + varying3_i {(constant4_ij * xaveragedpositiveparticleconcentrationmolm3_j),} + varying4_i {(constant7_ij * xaveragednegativeparticleconcentrationmolm3_j),} + varying5_i {(constant3_ij * xaveragednegativeparticleconcentrationmolm3_j),} + r_i {(((0.05138515824298745 * arcsinh((-0.7999999999999998 / ((1.8973665961010275e-05 * pow(max(min(varying2_i, 51217.92521874824), 0.000512179257309275), 0.5)) * pow((51217.9257309275 - max(min(varying2_i, 51217.92521874824), 0.000512179257309275)), 0.5))))) + (((((((2.16216 + (0.07645 * tanh((30.834 - (57.858397200000006 * max(min(varying3_i, 0.9999999999), 1e-10)))))) + (2.1581 * tanh((52.294 - (53.412228 * max(min(varying3_i, 0.9999999999), 1e-10)))))) - (0.14169 * tanh((11.0923 - (21.0852666 * max(min(varying3_i, 0.9999999999), 1e-10)))))) + (0.2051 * tanh((1.4684 - (5.829105600000001 * max(min(varying3_i, 0.9999999999), 1e-10)))))) + (0.2531 * tanh((4.291641337386018 - (8.069908814589667 * max(min(varying3_i, 0.9999999999), 1e-10)))))) - (0.02167 * tanh((-87.5 + (177.0 * max(min(varying3_i, 0.9999999999), 1e-10)))))) + (1e-06 * ((1.0 / max(min(varying3_i, 0.9999999999), 1e-10)) + (1.0 / (-1.0 + max(min(varying3_i, 0.9999999999), 1e-10))))))) - ((0.05138515824298745 * arcsinh((0.6666666666666666 / ((0.0006324555320336759 * pow(max(min(varying4_i, 24983.261744011077), 0.000249832619938437), 
0.5)) * pow((24983.2619938437 - max(min(varying4_i, 24983.261744011077), 0.000249832619938437)), 0.5))))) + ((((((((((0.194 + (1.5 * exp((-120.0 * max(min(varying5_i, 0.9999999999), 1e-10))))) + (0.0351 * tanh((-3.44578313253012 + (12.048192771084336 * max(min(varying5_i, 0.9999999999), 1e-10)))))) - (0.0045 * tanh((-7.1344537815126055 + (8.403361344537815 * max(min(varying5_i, 0.9999999999), 1e-10)))))) - (0.035 * tanh((-18.466 + (20.0 * max(min(varying5_i, 0.9999999999), 1e-10)))))) - (0.0147 * tanh((-14.705882352941176 + (29.41176470588235 * max(min(varying5_i, 0.9999999999), 1e-10)))))) - (0.102 * tanh((-1.3661971830985917 + (7.042253521126761 * max(min(varying5_i, 0.9999999999), 1e-10)))))) - (0.022 * tanh((-54.8780487804878 + (60.975609756097555 * max(min(varying5_i, 0.9999999999), 1e-10)))))) - (0.011 * tanh((-5.486725663716814 + (44.24778761061947 * max(min(varying5_i, 0.9999999999), 1e-10)))))) + (0.0155 * tanh((-3.6206896551724133 + (34.48275862068965 * max(min(varying5_i, 0.9999999999), 1e-10)))))) + (1e-06 * ((1.0 / max(min(varying5_i, 0.9999999999), 1e-10)) + (1.0 / (-1.0 + max(min(varying5_i, 0.9999999999), 1e-10)))))))),} + " expect "r" vec![3.191533267340602], + pybamm_subexpression: " + constant2_ij { (0,18): -25608.96286546366, (0,19): 76826.88859639116,} + st_i { (0:20): xaveragednegativeparticleconcentrationmolm3 = 0.8000000000000016, (20:40): xaveragedpositiveparticleconcentrationmolm3 = 0.6000000000000001, } + varying2_i {(constant2_ij * xaveragedpositiveparticleconcentrationmolm3_j),} + " expect "varying2" vec![-25608.96286546366 * 0.6000000000000001 + 76826.88859639116 * 0.6000000000000001], + pybamm_subexpression2: " + constant4_ij {(0,18): -0.4999999999999983, (0,19): 1.4999999999999982,} + st_i { (0:20): xaveragednegativeparticleconcentrationmolm3 = 0.8000000000000016, (20:40): xaveragedpositiveparticleconcentrationmolm3 = 0.6000000000000001, } + varying3_i {(constant4_ij * xaveragedpositiveparticleconcentrationmolm3_j),} + " expect "varying3" vec![-0.4999999999999983 * 0.6000000000000001 + 1.4999999999999982 * 0.6000000000000001], + pybamm_subexpression3: " + constant7_ij { (0,18): -12491.630996921805, (0,19): 37474.892990765504,} + st_i { (0:20): xaveragednegativeparticleconcentrationmolm3 = 0.8000000000000016, (20:40): xaveragedpositiveparticleconcentrationmolm3 = 0.6000000000000001, } + varying4_i {(constant7_ij * xaveragednegativeparticleconcentrationmolm3_j),} + " expect "varying4" vec![-12491.630996921805 * 0.8000000000000016 + 37474.892990765504 * 0.8000000000000016], + pybamm_subexpression4: " + varying2_i {30730.7554386,} + varying3_i {0.6,} + varying4_i {19986.6095951,} + varying5_i {0.8,} + r_i {(((0.05138515824298745 * arcsinh((-0.7999999999999998 / ((1.8973665961010275e-05 * pow(max(min(varying2_i, 51217.92521874824), 0.000512179257309275), 0.5)) * pow((51217.9257309275 - max(min(varying2_i, 51217.92521874824), 0.000512179257309275)), 0.5))))) + (((((((2.16216 + (0.07645 * tanh((30.834 - (57.858397200000006 * max(min(varying3_i, 0.9999999999), 1e-10)))))) + (2.1581 * tanh((52.294 - (53.412228 * max(min(varying3_i, 0.9999999999), 1e-10)))))) - (0.14169 * tanh((11.0923 - (21.0852666 * max(min(varying3_i, 0.9999999999), 1e-10)))))) + (0.2051 * tanh((1.4684 - (5.829105600000001 * max(min(varying3_i, 0.9999999999), 1e-10)))))) + (0.2531 * tanh((4.291641337386018 - (8.069908814589667 * max(min(varying3_i, 0.9999999999), 1e-10)))))) - (0.02167 * tanh((-87.5 + (177.0 * max(min(varying3_i, 0.9999999999), 1e-10)))))) + (1e-06 * ((1.0 / max(min(varying3_i, 
0.9999999999), 1e-10)) + (1.0 / (-1.0 + max(min(varying3_i, 0.9999999999), 1e-10))))))) - ((0.05138515824298745 * arcsinh((0.6666666666666666 / ((0.0006324555320336759 * pow(max(min(varying4_i, 24983.261744011077), 0.000249832619938437), 0.5)) * pow((24983.2619938437 - max(min(varying4_i, 24983.261744011077), 0.000249832619938437)), 0.5))))) + ((((((((((0.194 + (1.5 * exp((-120.0 * max(min(varying5_i, 0.9999999999), 1e-10))))) + (0.0351 * tanh((-3.44578313253012 + (12.048192771084336 * max(min(varying5_i, 0.9999999999), 1e-10)))))) - (0.0045 * tanh((-7.1344537815126055 + (8.403361344537815 * max(min(varying5_i, 0.9999999999), 1e-10)))))) - (0.035 * tanh((-18.466 + (20.0 * max(min(varying5_i, 0.9999999999), 1e-10)))))) - (0.0147 * tanh((-14.705882352941176 + (29.41176470588235 * max(min(varying5_i, 0.9999999999), 1e-10)))))) - (0.102 * tanh((-1.3661971830985917 + (7.042253521126761 * max(min(varying5_i, 0.9999999999), 1e-10)))))) - (0.022 * tanh((-54.8780487804878 + (60.975609756097555 * max(min(varying5_i, 0.9999999999), 1e-10)))))) - (0.011 * tanh((-5.486725663716814 + (44.24778761061947 * max(min(varying5_i, 0.9999999999), 1e-10)))))) + (0.0155 * tanh((-3.6206896551724133 + (34.48275862068965 * max(min(varying5_i, 0.9999999999), 1e-10)))))) + (1e-06 * ((1.0 / max(min(varying5_i, 0.9999999999), 1e-10)) + (1.0 / (-1.0 + max(min(varying5_i, 0.9999999999), 1e-10)))))))),} + " expect "r" vec![(((0.05138515824298745 * f64::asinh(-0.7999999999999998 / ((1.897_366_596_101_027_5e-5 * f64::powf(f64::max(f64::min(30730.7554386, 51217.92521874824), 0.000512179257309275), 0.5)) * f64::powf(51217.9257309275 - f64::max(f64::min(30730.7554386, 51217.92521874824), 0.000512179257309275), 0.5)))) + (((((((2.16216 + (0.07645 * f64::tanh(30.834 - (57.858397200000006 * f64::max(f64::min(0.6, 0.9999999999), 1e-10))))) + (2.1581 * f64::tanh(52.294 - (53.412228 * f64::max(f64::min(0.6, 0.9999999999), 1e-10))))) - (0.14169 * f64::tanh(11.0923 - (21.0852666 * f64::max(f64::min(0.6, 0.9999999999), 1e-10))))) + (0.2051 * f64::tanh(1.4684 - (5.829105600000001 * f64::max(f64::min(0.6, 0.9999999999), 1e-10))))) + (0.2531 * f64::tanh(4.291641337386018 - (8.069908814589667 * f64::max(f64::min(0.6, 0.9999999999), 1e-10))))) - (0.02167 * f64::tanh(-87.5 + (177.0 * f64::max(f64::min(0.6, 0.9999999999), 1e-10))))) + (1e-06 * ((1.0 / f64::max(f64::min(0.6, 0.9999999999), 1e-10)) + (1.0 / (-1.0 + f64::max(f64::min(0.6, 0.9999999999), 1e-10))))))) - ((0.05138515824298745 * f64::asinh(0.6666666666666666 / ((0.0006324555320336759 * f64::powf(f64::max(f64::min(19986.6095951, 24983.261744011077), 0.000249832619938437), 0.5)) * f64::powf(24983.2619938437 - f64::max(f64::min(19986.6095951, 24983.261744011077), 0.000249832619938437), 0.5)))) + ((((((((((0.194 + (1.5 * f64::exp(-120.0 * f64::max(f64::min(0.8, 0.9999999999), 1e-10)))) + (0.0351 * f64::tanh(-3.44578313253012 + (12.048192771084336 * f64::max(f64::min(0.8, 0.9999999999), 1e-10))))) - (0.0045 * f64::tanh(-7.1344537815126055 + (8.403361344537815 * f64::max(f64::min(0.8, 0.9999999999), 1e-10))))) - (0.035 * f64::tanh(-18.466 + (20.0 * f64::max(f64::min(0.8, 0.9999999999), 1e-10))))) - (0.0147 * f64::tanh(-14.705882352941176 + (29.41176470588235 * f64::max(f64::min(0.8, 0.9999999999), 1e-10))))) - (0.102 * f64::tanh(-1.3661971830985917 + (7.042253521126761 * f64::max(f64::min(0.8, 0.9999999999), 1e-10))))) - (0.022 * f64::tanh(-54.8780487804878 + (60.975609756097555 * f64::max(f64::min(0.8, 0.9999999999), 1e-10))))) - (0.011 * f64::tanh(-5.486725663716814 + (44.24778761061947 
* f64::max(f64::min(0.8, 0.9999999999), 1e-10))))) + (0.0155 * f64::tanh(-3.6206896551724133 + (34.48275862068965 * f64::max(f64::min(0.8, 0.9999999999), 1e-10))))) + (1e-06 * ((1.0 / f64::max(f64::min(0.8, 0.9999999999), 1e-10)) + (1.0 / (-1.0 + f64::max(f64::min(0.8, 0.9999999999), 1e-10))))))))], + pybamm_subexpression5: "r_i { (1.0 / max(min(0.6, 0.9999999999), 1e-10)),}" expect "r" vec![1.0 / f64::max(f64::min(0.6, 0.9999999999), 1e-10)], + pybamm_subexpression6: "r_i { arcsinh(1.8973665961010275e-05), }" expect "r" vec![f64::asinh(1.897_366_596_101_027_5e-5)], + pybamm_subexpression7: "r_i { (1.5 * exp(-120.0 * max(min(0.8, 0.9999999999), 1e-10))), }" expect "r" vec![1.5 * f64::exp(-120.0 * f64::max(f64::min(0.8, 0.9999999999), 1e-10))], + pybamm_subexpression8: "r_i { (0.07645 * tanh(30.834 - (57.858397200000006 * max(min(0.6, 0.9999999999), 1e-10)))), }" expect "r" vec![0.07645 * f64::tanh(30.834 - (57.858397200000006 * f64::max(f64::min(0.6, 0.9999999999), 1e-10)))], + pybamm_subexpression9: "r_i { (1e-06 * ((1.0 / max(min(0.8, 0.9999999999), 1e-10)) + (1.0 / (-1.0 + max(min(0.8, 0.9999999999), 1e-10))))), }" expect "r" vec![1e-06 * ((1.0 / f64::max(f64::min(0.8, 0.9999999999), 1e-10)) + (1.0 / (-1.0 + f64::max(f64::min(0.8, 0.9999999999), 1e-10))))], + pybamm_subexpression10: "r_i { (1.0 / (-1.0 + max(min(0.8, 0.9999999999), 1e-10))), }" expect "r" vec![1.0 / (-1.0 + f64::max(f64::min(0.8, 0.9999999999), 1e-10))], + unary_negate_in_expr: "r_i { 1.0 / (-1.0 + 1.1) }" expect "r" vec![1.0 / (-1.0 + 1.1)], + derived: "r_i {2, 3} k_i { 2 * r_i }" expect "k" vec![4., 6.], + concatenate: "r_i {2, 3} k_i { r_i, 2 * r_i }" expect "k" vec![2., 3., 4., 6.], + ones_matrix_dense: "I_ij { (0:2, 0:2): 1 }" expect "I" vec![1., 1., 1., 1.], + dense_matrix: "A_ij { (0, 0): 1, (0, 1): 2, (1, 0): 3, (1, 1): 4 }" expect "A" vec![1., 2., 3., 4.], + dense_vector: "x_i { (0:4): 1, (4:5): 2 }" expect "x" vec![1., 1., 1., 1., 2.], + identity_matrix_diagonal: "I_ij { (0..2, 0..2): 1 }" expect "I" vec![1., 1.], + concatenate_diagonal: "A_ij { (0..2, 0..2): 1 } B_ij { (0:2, 0:2): A_ij, (2:4, 2:4): A_ij }" expect "B" vec![1., 1., 1., 1.], + identity_matrix_sparse: "I_ij { (0, 0): 1, (1, 1): 2 }" expect "I" vec![1., 2.], + concatenate_sparse: "A_ij { (0, 0): 1, (1, 1): 2 } B_ij { (0:2, 0:2): A_ij, (2:4, 2:4): A_ij }" expect "B" vec![1., 2., 1., 2.], + sparse_rearrange: "A_ij { (0, 0): 1, (1, 1): 2, (0, 1): 3 }" expect "A" vec![1., 3., 2.], + sparse_rearrange2: "A_ij { (0, 1): 1, (1, 1): 2, (1, 0): 3, (2, 2): 4, (2, 1): 5 }" expect "A" vec![1., 3., 2., 5., 4.], + sparse_expression: "A_ij { (0, 0): 1, (0, 1): 2, (1, 1): 3 } B_ij { 2 * A_ij }" expect "B" vec![2., 4., 6.], + sparse_matrix_vect_multiply: "A_ij { (0, 0): 1, (1, 0): 2, (1, 1): 3 } x_i { 1, 2 } b_i { A_ij * x_j }" expect "b" vec![1., 8.], + sparse_rearrange_matrix_vect_multiply: "A_ij { (0, 1): 1, (1, 1): 2, (1, 0): 3, (2, 2): 4, (2, 1): 5 } x_i { 1, 2, 3 } b_i { A_ij * x_j }" expect "b" vec![2., 7., 22.], + diag_matrix_vect_multiply: "A_ij { (0, 0): 1, (1, 1): 3 } x_i { 1, 2 } b_i { A_ij * x_j }" expect "b" vec![1., 6.], + dense_matrix_vect_multiply: "A_ij { (0, 0): 1, (0, 1): 2, (1, 0): 3, (1, 1): 4 } x_i { 1, 2 } b_i { A_ij * x_j }" expect "b" vec![5., 11.], + sparse_matrix_vect_multiply_zero_row: "A_ij { (0, 1): 2 } x_i { 1, 2 } b_i { A_ij * x_j }" expect "b" vec![4.], + } + + macro_rules! 
+    macro_rules! tensor_grad_test {
+        ($($name:ident: $text:literal expect $tensor_name:literal $expected_value:expr,)*) => {
+            $(
+                #[test]
+                fn $name() {
+                    let full_text = format!("
+                        in = [p]
+                        p {{
+                            1,
+                        }}
+                        u_i {{
+                            y = p,
+                        }}
+                        dudt_i {{
+                            dydt = p,
+                        }}
+                        {}
+                        M_i {{
+                            dydt,
+                        }}
+                        F_i {{
+                            y,
+                        }}
+                        out_i {{
+                            y,
+                        }}
+                    ", $text);
+                    let tmp_loc = format!("test_output/compiler_tensor_grad_test_{}", stringify!($name));
+                    let results = tensor_test_common(full_text.as_str(), tmp_loc.as_str(), $tensor_name);
+                    assert_relative_eq!(results[1].as_slice(), $expected_value.as_slice());
+                }
+            )*
+        }
+    }
+
+    tensor_grad_test! {
+        const_grad: "r { 3 }" expect "r" vec![0.],
+        const_vec_grad: "r_i { 3, 4 }" expect "r" vec![0., 0.],
+        input_grad: "r { 2 * p * p }" expect "r" vec![4.],
+        input_vec_grad: "r_i { 2 * p * p, 3 * p }" expect "r" vec![4., 3.],
+        state_grad: "r { 2 * y }" expect "r" vec![2.],
+        input_and_state_grad: "r { 2 * y * p }" expect "r" vec![4.],
+    }
+
+    #[test]
+    fn test_repeated_grad() {
+        let full_text = "
+            in = [p]
+            p {
+                1,
+            }
+            u_i {
+                y = p,
+            }
+            dudt_i {
+                dydt = 1,
+            }
+            r {
+                2 * y * p,
+            }
+            M_i {
+                dydt,
+            }
+            F_i {
+                r,
+            }
+            out_i {
+                y,
+            }
+        ";
+        let model = parse_ds_string(full_text).unwrap();
+        let discrete_model = match DiscreteModel::build("test_repeated_grad", &model) {
+            Ok(model) => model,
+            Err(e) => {
+                panic!("{}", e.as_error_message(full_text));
+            }
+        };
+        let compiler = LlvmCompiler::from_discrete_model(
+            &discrete_model,
+            "test_output/compiler_test_repeated_grad",
+        )
+        .unwrap();
+        let mut u0 = vec![1.];
+        let mut du0 = vec![1.];
+        let mut res = vec![0.];
+        let mut dres = vec![0.];
+        let mut data = compiler.get_new_data();
+        let mut ddata = compiler.get_new_data();
+        let (_n_states, n_inputs, _n_outputs, _n_data, _n_stop) = compiler.get_dims();
+
+        for _ in 0..3 {
+            let inputs = vec![2.; n_inputs];
+            let dinputs = vec![1.; n_inputs];
+            compiler.set_inputs_grad(
+                inputs.as_slice(),
+                dinputs.as_slice(),
+                data.as_mut_slice(),
+                ddata.as_mut_slice(),
+            );
+            compiler.set_u0_grad(
+                u0.as_mut_slice(),
+                du0.as_mut_slice(),
+                data.as_mut_slice(),
+                ddata.as_mut_slice(),
+            );
+            compiler.rhs_grad(
+                0.,
+                u0.as_slice(),
+                du0.as_slice(),
+                data.as_mut_slice(),
+                ddata.as_mut_slice(),
+                res.as_mut_slice(),
+                dres.as_mut_slice(),
+            );
+            assert_relative_eq!(dres.as_slice(), vec![8.].as_slice());
+        }
+    }
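As a cross-check on the asserted tangent: the residual being differentiated is F = 2·y·p with y initialised to p. Seeding dinputs = [1.] sets ṗ = 1, and set_u0_grad propagates ẏ = ṗ = 1, so Ḟ = 2(y·ṗ + p·ẏ) = 2(2 + 2) = 8 at p = 2, which is the [8.] the loop asserts. Running the sequence three times guards against stale tangent data leaking between repeated set_inputs_grad/set_u0_grad/rhs_grad calls.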
+
+    #[test]
+    fn test_additional_functions() {
+        let full_text = "
+            in = [k]
+            k {
+                1,
+            }
+            u_i {
+                y = 1,
+                x = 2,
+            }
+            dudt_i {
+                dydt = 0,
+                0,
+            }
+            M_i {
+                dydt,
+                0,
+            }
+            F_i {
+                y - 1,
+                x - 2,
+            }
+            out_i {
+                y,
+                x,
+                2*x,
+            }
+        ";
+        let model = parse_ds_string(full_text).unwrap();
+        let discrete_model = DiscreteModel::build("test_additional_functions", &model).unwrap();
+        let compiler = LlvmCompiler::from_discrete_model(
+            &discrete_model,
+            "test_output/compiler_test_additional_functions",
+        )
+        .unwrap();
+        let (n_states, n_inputs, n_outputs, n_data, _n_stop) = compiler.get_dims();
+        assert_eq!(n_states, 2);
+        assert_eq!(n_inputs, 1);
+        assert_eq!(n_outputs, 3);
+        assert_eq!(n_data, compiler.data_len());
+
+        let mut data = compiler.get_new_data();
+        let inputs = vec![1.1];
+        compiler.set_inputs(inputs.as_slice(), data.as_mut_slice());
+
+        let inputs = compiler.get_tensor_data("k", data.as_slice()).unwrap();
+        assert_relative_eq!(inputs, vec![1.1].as_slice());
+
+        let mut id = vec![0.0, 0.0];
+        compiler.set_id(id.as_mut_slice());
+        assert_eq!(id, vec![1.0, 0.0]);
+
+        let mut u = vec![0., 0.];
+        compiler.set_u0(u.as_mut_slice(), data.as_mut_slice());
+        assert_relative_eq!(u.as_slice(), vec![1., 2.].as_slice());
+
+        let mut rr = vec![1., 1.];
+        compiler.rhs(0., u.as_slice(), data.as_mut_slice(), rr.as_mut_slice());
+        assert_relative_eq!(rr.as_slice(), vec![0., 0.].as_slice());
+
+        let up = vec![2., 3.];
+        rr = vec![1., 1.];
+        compiler.mass(0., up.as_slice(), data.as_mut_slice(), rr.as_mut_slice());
+        assert_relative_eq!(rr.as_slice(), vec![2., 0.].as_slice());
+
+        compiler.calc_out(0., u.as_slice(), data.as_mut_slice());
+        let out = compiler.get_out(data.as_slice());
+        assert_relative_eq!(out, vec![1., 2., 4.].as_slice());
+    }
+}
diff --git a/src/execution/llvm/mod.rs b/src/execution/llvm/mod.rs
new file mode 100644
index 0000000..24ccbdd
--- /dev/null
+++ b/src/execution/llvm/mod.rs
@@ -0,0 +1 @@
+pub mod codegen;
diff --git a/src/execution/mod.rs b/src/execution/mod.rs
index 3d8c7e9..01b91d6 100644
--- a/src/execution/mod.rs
+++ b/src/execution/mod.rs
@@ -1,14 +1,14 @@
-pub mod codegen;
-pub use codegen::CodeGen;
+#[cfg(feature = "llvm")]
+pub mod llvm;
+
+pub mod compiler;
+pub mod cranelift;
+pub mod functions;
+pub mod interface;
+pub mod module;
 pub mod data_layout;
 pub use data_layout::DataLayout;
 pub mod translation;
 pub use translation::{Translation, TranslationFrom, TranslationTo};
-
-// todo: this warning is coming from the ourbouros crate,
-// remove this when the ourbouros crate is updated
-#[allow(clippy::too_many_arguments)]
-pub mod compiler;
-pub use compiler::Compiler;
diff --git a/src/execution/module.rs b/src/execution/module.rs
new file mode 100644
index 0000000..b7d10c6
--- /dev/null
+++ b/src/execution/module.rs
@@ -0,0 +1,49 @@
+use anyhow::Result;
+use target_lexicon::Triple;
+
+use crate::discretise::DiscreteModel;
+
+use super::DataLayout;
+
+pub trait CodegenModule: Sized {
+    type FuncId;
+
+    fn new(triple: Triple, model: &DiscreteModel) -> Result<Self>;
+    fn compile_set_u0(&mut self, model: &DiscreteModel) -> Result<Self::FuncId>;
+    fn compile_calc_out(&mut self, model: &DiscreteModel) -> Result<Self::FuncId>;
+    fn compile_calc_stop(&mut self, model: &DiscreteModel) -> Result<Self::FuncId>;
+    fn compile_rhs(&mut self, model: &DiscreteModel) -> Result<Self::FuncId>;
+    fn compile_mass(&mut self, model: &DiscreteModel) -> Result<Self::FuncId>;
+    fn compile_get_dims(&mut self, model: &DiscreteModel) -> Result<Self::FuncId>;
+    fn compile_get_tensor(&mut self, model: &DiscreteModel, name: &str) -> Result<Self::FuncId>;
+    fn compile_set_inputs(&mut self, model: &DiscreteModel) -> Result<Self::FuncId>;
+    fn compile_set_id(&mut self, model: &DiscreteModel) -> Result<Self::FuncId>;
+
+    fn compile_set_u0_grad(
+        &mut self,
+        func_id: &Self::FuncId,
+        model: &DiscreteModel,
+    ) -> Result<Self::FuncId>;
+    fn compile_rhs_grad(
+        &mut self,
+        func_id: &Self::FuncId,
+        model: &DiscreteModel,
+    ) -> Result<Self::FuncId>;
+    fn compile_calc_out_grad(
+        &mut self,
+        func_id: &Self::FuncId,
+        model: &DiscreteModel,
+    ) -> Result<Self::FuncId>;
+    fn compile_set_inputs_grad(
+        &mut self,
+        func_id: &Self::FuncId,
+        model: &DiscreteModel,
+    ) -> Result<Self::FuncId>;
+
+    fn jit(&mut self, func_id: Self::FuncId) -> Result<*const u8>;
+
+    fn pre_autodiff_optimisation(&mut self) -> Result<()>;
+    fn post_autodiff_optimisation(&mut self) -> Result<()>;
+
+    fn layout(&self) -> &DataLayout;
+}
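CodegenModule is the seam between the backend-agnostic Compiler and the two code generators (CraneliftModule, and LlvmModule behind the llvm feature): each compile_* method lowers one entry point of the model, the *_grad variants take an already-compiled FuncId and produce its forward-mode derivative, and jit finalises a FuncId to a raw code pointer. The real driver lives in src/execution/compiler.rs, which is not shown in this hunk, so the following is only a sketch of the intended call sequence under that reading:

    use anyhow::Result;
    use diffsl::discretise::DiscreteModel;
    use diffsl::execution::module::CodegenModule;
    use target_lexicon::Triple;

    // Illustrative only: compile the rhs and its tangent, running the
    // optimisation hooks that bracket autodiff, then JIT the entry point.
    fn jit_rhs<M: CodegenModule>(model: &DiscreteModel) -> Result<*const u8> {
        let mut module = M::new(Triple::host(), model)?;
        let rhs = module.compile_rhs(model)?;
        module.pre_autodiff_optimisation()?;
        let _rhs_grad = module.compile_rhs_grad(&rhs, model)?;
        module.post_autodiff_optimisation()?;
        // The caller casts the returned address to the matching
        // extern "C" fn type; that unsafe step is omitted here.
        module.jit(rhs)
    }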
diff --git a/src/lib.rs b/src/lib.rs
index 9eb6923..d750de9 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,10 +1,3 @@
-use anyhow::{anyhow, Result};
-use continuous::ModelInfo;
-use discretise::DiscreteModel;
-use execution::Compiler;
-use parser::{parse_ds_string, parse_ms_string};
-use std::{ffi::OsStr, path::Path};
-
 extern crate pest;
 #[macro_use]
 extern crate pest_derive;
@@ -12,11 +5,17 @@ extern crate pest_derive;
 pub mod ast;
 pub mod continuous;
 pub mod discretise;
+#[cfg(feature = "enzyme")]
 pub mod enzyme;
 pub mod execution;
 pub mod parser;
 pub mod utils;
 
+pub use execution::compiler::Compiler;
+pub use execution::cranelift::codegen::CraneliftModule;
+#[cfg(feature = "llvm")]
+pub use execution::llvm::codegen::LlvmModule;
+
 #[cfg(feature = "inkwell-130")]
 extern crate inkwell_130 as inkwell;
 #[cfg(feature = "inkwell-140")]
@@ -27,6 +26,8 @@ extern crate inkwell_150 as inkwell;
 extern crate inkwell_160 as inkwell;
 #[cfg(feature = "inkwell-170")]
 extern crate inkwell_170 as inkwell;
+#[cfg(feature = "inkwell-180")]
+extern crate inkwell_180 as inkwell;
 
 #[cfg(feature = "inkwell-130")]
 extern crate llvm_sys_130 as llvm_sys;
@@ -38,171 +39,5 @@ extern crate llvm_sys_150 as llvm_sys;
 extern crate llvm_sys_160 as llvm_sys;
 #[cfg(feature = "inkwell-170")]
 extern crate llvm_sys_170 as llvm_sys;
-
-pub struct CompilerOptions {
-    pub bitcode_only: bool,
-    pub wasm: bool,
-    pub standalone: bool,
-}
-
-pub fn compile(
-    input: &str,
-    out: Option<&str>,
-    model: Option<&str>,
-    options: CompilerOptions,
-) -> Result<()> {
-    let inputfile = Path::new(input);
-    let is_discrete = inputfile
-        .extension()
-        .unwrap_or(OsStr::new(""))
-        .to_str()
-        .unwrap()
-        == "ds";
-    let is_continuous = inputfile
-        .extension()
-        .unwrap_or(OsStr::new(""))
-        .to_str()
-        .unwrap()
-        == "cs";
-    if !is_discrete && !is_continuous {
-        panic!("Input file must have extension .ds or .cs");
-    }
-    let model_name = if is_continuous {
-        if let Some(model_name) = model {
-            model_name
-        } else {
-            return Err(anyhow!(
-                "Model name must be specified for continuous models"
-            ));
-        }
-    } else {
-        inputfile.file_stem().unwrap().to_str().unwrap()
-    };
-    let out = out.unwrap_or("out");
-    let text = std::fs::read_to_string(inputfile)?;
-    compile_text(text.as_str(), out, model_name, options, is_discrete)
-}
-
-pub fn compile_text(
-    text: &str,
-    out: &str,
-    model_name: &str,
-    options: CompilerOptions,
-    is_discrete: bool,
-) -> Result<()> {
-    let is_continuous = !is_discrete;
-
-    let continuous_ast = if is_continuous {
-        Some(parse_ms_string(text)?)
-    } else {
-        None
-    };
-
-    let discrete_ast = if is_discrete {
-        Some(parse_ds_string(text)?)
-    } else {
-        None
-    };
-
-    let continuous_model_info = if let Some(ast) = &continuous_ast {
-        let model_info = ModelInfo::build(model_name, ast).map_err(|e| anyhow!("{}", e))?;
-        if !model_info.errors.is_empty() {
-            let error_text = model_info.errors.iter().fold(String::new(), |acc, error| {
-                format!("{}\n{}", acc, error.as_error_message(text))
-            });
-            return Err(anyhow!(error_text));
-        }
-        Some(model_info)
-    } else {
-        None
-    };
-
-    let discrete_model = if let Some(model_info) = &continuous_model_info {
-        let model = DiscreteModel::from(model_info);
-        model
-    } else if let Some(ast) = &discrete_ast {
-        match DiscreteModel::build(model_name, ast) {
-            Ok(model) => model,
-            Err(e) => {
-                return Err(anyhow!(e.as_error_message(text)));
-            }
-        }
-    } else {
-        panic!("No model found");
-    };
-    let compiler = Compiler::from_discrete_model(&discrete_model, out)?;
-
-    if options.bitcode_only {
-        return Ok(());
-    }
-
-    compiler.compile(options.standalone, options.wasm)
-}
-
-#[cfg(test)]
-mod tests {
-    use crate::{
-        continuous::ModelInfo,
-        parser::{parse_ds_string, parse_ms_string},
-    };
-    use approx::assert_relative_eq;
-
-    use super::*;
-
-    fn ds_example_compiler(example: &str) -> Compiler {
-        let text = std::fs::read_to_string(format!("examples/{}.ds", example)).unwrap();
-        let model = parse_ds_string(text.as_str()).unwrap();
-        let model = DiscreteModel::build(example, &model)
-            .unwrap_or_else(|e| panic!("{}", e.as_error_message(text.as_str())));
-        let out = format!("test_output/lib_examples_{}", example);
-        Compiler::from_discrete_model(&model, out.as_str()).unwrap()
-    }
-
-    #[test]
-    fn test_logistic_ds_example() {
-        let compiler = ds_example_compiler("logistic");
-        let r = 0.5;
-        let k = 0.5;
-        let y = 0.5;
-        let dydt = r * y * (1. - y / k);
-        let z = 2. * y;
-        let dzdt = 2. * dydt;
-        let inputs = vec![r, k];
-        let mut u0 = vec![y, z];
-        let mut data = compiler.get_new_data();
-        compiler.set_inputs(inputs.as_slice(), data.as_mut_slice());
-        compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice());
-
-        u0 = vec![y, z];
-        let up0 = vec![dydt, dzdt];
-        let mut res = vec![1., 1.];
-
-        compiler.rhs(0., u0.as_slice(), data.as_mut_slice(), res.as_mut_slice());
-        let expected_value = vec![dydt, 2.0 * y - z];
-        assert_relative_eq!(res.as_slice(), expected_value.as_slice());
-
-        compiler.mass(0., up0.as_slice(), data.as_mut_slice(), res.as_mut_slice());
-        let expected_value = vec![dydt, 0.];
-        assert_relative_eq!(res.as_slice(), expected_value.as_slice());
-    }
-
-    #[test]
-    fn test_object_file() {
-        let text = "
-        model logistic_growth(r -> NonNegative, k -> NonNegative, y(t), z(t)) {
-            dot(y) = r * y * (1 - y / k)
-            y(0) = 1.0
-            z = 2 * y
-        }
-        ";
-        let models = parse_ms_string(text).unwrap();
-        let model_info = ModelInfo::build("logistic_growth", &models).unwrap();
-        assert_eq!(model_info.errors.len(), 0);
-        let discrete_model = DiscreteModel::from(&model_info);
-        let object =
-            Compiler::from_discrete_model(&discrete_model, "test_output/lib_test_object_file")
-                .unwrap();
-        let path = Path::new("main.o");
-        object.write_object_file(path).unwrap();
-    }
-}
+#[cfg(feature = "inkwell-180")]
+extern crate llvm_sys_180 as llvm_sys;
diff --git a/src/parser/ds_parser.rs b/src/parser/ds_parser.rs
index aac584f..287904b 100644
--- a/src/parser/ds_parser.rs
+++ b/src/parser/ds_parser.rs
@@ -39,7 +39,11 @@ fn parse_value(pair: Pair<'_, Rule>) -> Ast<'_> {
     match pair.as_rule() {
         // name = @{ 'a'..'z' ~ ("_" | 'a'..'z' | 'A'..'Z' | '0'..'9')* }
         Rule::name => Ast {
-            kind: AstKind::Name(pair.as_str()),
+            kind: AstKind::Name(ast::Name {
+                name: pair.as_str(),
+                indices: vec![],
+                is_tangent: false,
+            }),
             span,
         },
@@ -74,6 +78,7 @@ fn parse_value(pair: Pair<'_, Rule>) -> Ast<'_> {
             kind: AstKind::Call(ast::Call {
                 fn_name: parse_name(inner.next().unwrap()),
                 args: inner.map(parse_value).map(Box::new).collect(),
+                is_tangent: false,
             }),
             span,
         }
@@ -154,7 +159,11 @@ fn parse_value(pair: Pair<'_, Rule>) -> Ast<'_> {
            vec![]
        };
        Ast {
-            kind: AstKind::IndexedName(ast::IndexedName { name, indices }),
+            kind: AstKind::Name(ast::Name {
+                name,
+                indices,
+                is_tangent: false,
+            }),
            span,
        }
    }
@@ -185,7 +194,11 @@ fn parse_value(pair: Pair<'_, Rule>) -> Ast<'_> {
        let mut inner = pair.into_inner();
        let name_ij = parse_value(inner.next().unwrap());
        let (name, indices) = match name_ij.kind {
-            AstKind::IndexedName(ast::IndexedName { name, indices }) => (name, indices),
+            AstKind::Name(ast::Name {
+                name,
+                indices,
+                is_tangent: false,
+            }) => (name, indices),
            _ => unreachable!(),
        };
        let elmts = inner.map(|v| parse_value(v)).collect();
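This change and the ms_parser.rs change below are the same mechanical migration: AstKind::Name now carries a struct that absorbs the old IndexedName variant and adds an is_tangent flag for the autodiff support elsewhere in this diff. The definition lives in src/ast.rs, outside these hunks; a sketch of the shape the construction sites above imply (the indices element type is an assumption):

    // Assumed from the construction sites: a name plus its index labels.
    pub struct Name<'a> {
        pub name: &'a str,
        pub indices: Vec<char>, // element type assumed, not shown in this diff
        pub is_tangent: bool,
    }

    // Downstream matches move from comparing a &str to reading the field,
    // as the test updates below show:
    //   matches!(&kind, AstKind::Name(name) if name.name == "inputVoltage")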
diff --git a/src/parser/ms_parser.rs b/src/parser/ms_parser.rs
index ddef667..c5e882d 100644
--- a/src/parser/ms_parser.rs
+++ b/src/parser/ms_parser.rs
@@ -40,7 +40,11 @@ fn parse_value(pair: Pair<'_, Rule>) -> Ast<'_> {
        // name = @{ 'a'..'z' ~ ("_" | 'a'..'z' | 'A'..'Z' | '0'..'9')* }
        // domain_name = @{ 'A'..'Z' ~ ('a'..'z' | 'A'..'Z' | '0'..'9')* }
        Rule::name | Rule::domain_name => Ast {
-            kind: AstKind::Name(pair.as_str()),
+            kind: AstKind::Name(ast::Name {
+                name: pair.as_str(),
+                indices: vec![],
+                is_tangent: false,
+            }),
            span,
        },
@@ -150,6 +154,7 @@ fn parse_value(pair: Pair<'_, Rule>) -> Ast<'_> {
            kind: AstKind::Call(ast::Call {
                fn_name: parse_name(inner.next().unwrap()),
                args: inner.map(parse_value).map(Box::new).collect(),
+                is_tangent: false,
            }),
            span,
        }
@@ -160,7 +165,12 @@ fn parse_value(pair: Pair<'_, Rule>) -> Ast<'_> {
        // TODO: is there a better way of destructuring this?
        let mut inner = pair.into_inner();
        let (name, args) = if let Ast {
-            kind: AstKind::Call(ast::Call { fn_name, args }),
+            kind:
+                AstKind::Call(ast::Call {
+                    fn_name,
+                    args,
+                    is_tangent: _,
+                }),
            span: _,
        } = parse_value(inner.next().unwrap())
        {
@@ -356,7 +366,7 @@ mod tests {
        }
        assert_eq!(models[0].statements.len(), 1);
        if let AstKind::Equation(eqn) = &models[0].statements[0].kind {
-            assert!(matches!(eqn.lhs.kind, AstKind::Name(name) if name == "i"));
+            assert!(matches!(&eqn.lhs.kind, AstKind::Name(name) if name.name == "i"));
            assert!(matches!(&eqn.rhs.kind, AstKind::Binop(binop) if binop.op == '*'));
        } else {
            panic!("not an equation")
@@ -424,7 +434,7 @@ mod tests {
        if let AstKind::CallArg(arg) = &submodel.args[0].kind {
            assert_eq!(arg.name.unwrap(), "v");
            assert!(
-                matches!(arg.expression.kind, AstKind::Name(name) if name == "inputVoltage")
+                matches!(&arg.expression.kind, AstKind::Name(name) if name.name == "inputVoltage")
            );
        } else {
            unreachable!("not a call arg")