Skip to content

Commit

Permalink
remove sundials
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Mar 5, 2024
1 parent f1d45cc commit 5e68fda
Show file tree
Hide file tree
Showing 8 changed files with 46 additions and 469 deletions.
30 changes: 18 additions & 12 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -1,24 +1,30 @@
name: Cargo Build & Test
name: Rust

on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]

env:
env:
CARGO_TERM_COLOR: always

jobs:
build_and_test:
build:
name: Rust project - latest
runs-on: ubuntu-latest
strategy:
matrix:
toolchain:
- stable
- beta
- nightly
matrix:
toolchain:
- stable
- beta
# - nightly

steps:
- uses: actions/checkout@v3
- run: rustup update ${{ matrix.toolchain }} && rustup default ${{ matrix.toolchain }}
- run: cargo build --verbose
- run: cargo test --verbose
- uses: actions/checkout@v3
- name: Set up Rust
run: rustup update ${{ matrix.toolchain }} && rustup default ${{ matrix.toolchain }}
- name: Build
run: cargo build --verbose
- name: Run tests
run: cargo test --verbose
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ pest = ">=2.1.3"
pest_derive = ">=2.1.0"
itertools = ">=0.10.3"
inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", features = ["llvm14-0"] }
sundials-sys = { version = ">=0.3", features = ["idas", "build_libraries", "static_libraries"] }
ouroboros = ">=0.17"
clap = { version = "4.3.23", features = ["derive"] }

Expand Down
20 changes: 8 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,36 +46,32 @@ out_i {
x,
y,
z,
}
```

## Installation
## Dependencies

This package relies on the `clang` (https://clang.llvm.org/) and `opt`
executables from the [LLVM project](https://llvm.org/). The easiest way to
This package uses the `opt` executable from the [LLVM project](https://llvm.org/). The easiest way to
install these is to use the package manager for your operating system. For
example, on Ubuntu you can install these with the following command:

```bash
sudo apt-get install clang
```

In addition, DiffSL uses the [Enzyme AD](https://enzyme.mit.edu/) package for automatic differentiation. This can be installed by following the instructions on the Enzyme AD website. You will need set the `LIBRARY_PATH` environment variable to the location of the Enzyme AD library. For example, if you have installed Enzyme AD in the directory `/usr/local`, you can set the `LIBRARY_PATH` environment variable with the following command:
In addition, DiffSL uses the [Enzyme AD](https://enzyme.mit.edu/) package for automatic differentiation. This can be installed by following the instructions on the Enzyme AD website. You will need set the `ENZYME_LIB` environment variable to the location of the Enzyme AD library. Please make sure that you compile the Enzyme AD library with the version of LLVM that corresponds to the version of `opt` that you have on your path.

```bash
export LIBRARY_PATH=/usr/local/lib
export ENZYME_LIB=<path to Enzyme AD library>
```

## Building DiffSL





## Installing DiffSL

### Installing Enzyme AD
You can install DiffSL using cargo:

```bash
cmake -DCMAKE_INSTALL_PREFIX=<install> -DCMAKE_BUILD_TYPE=Release ..
cargo add diffsl
```

## DiffSL Language
Expand Down
27 changes: 14 additions & 13 deletions src/execution/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ use inkwell::module::{Module, Linkage};
use std::collections::HashMap;
use std::iter::zip;
use anyhow::{Result, anyhow};
use sundials_sys::realtype;

type RealType = f64;


use crate::ast::{Ast, AstKind};
Expand All @@ -20,18 +21,18 @@ use crate::execution::{Translation, TranslationFrom, TranslationTo, DataLayout};
///
/// Calling this is innately `unsafe` because there's no guarantee it doesn't
/// do `unsafe` operations internally.
pub type StopFunc = unsafe extern "C" fn(time: realtype, u: *const realtype, up: *const realtype, data: *mut realtype, root: *mut realtype);
pub type ResidualFunc = unsafe extern "C" fn(time: realtype, u: *const realtype, up: *const realtype, data: *mut realtype, rr: *mut realtype);
pub type ResidualGradientFunc = unsafe extern "C" fn(time: realtype, u: *const realtype, du: *const realtype, up: *const realtype, dup: *const realtype, data: *mut realtype, ddata: *mut realtype, rr: *mut realtype, drr: *mut realtype);
pub type U0Func = unsafe extern "C" fn(data: *mut realtype, u: *mut realtype, up: *mut realtype);
pub type U0GradientFunc = unsafe extern "C" fn(data: *mut realtype, ddata: *mut realtype, u: *mut realtype, du: *mut realtype, up: *mut realtype, dup: *mut realtype);
pub type CalcOutFunc = unsafe extern "C" fn(time: realtype, u: *const realtype, up: *const realtype, data: *mut realtype);
pub type CalcOutGradientFunc = unsafe extern "C" fn(time: realtype, u: *const realtype, du: *const realtype, up: *const realtype, dup: *const realtype, data: *mut realtype, ddata: *mut realtype);
pub type StopFunc = unsafe extern "C" fn(time: RealType, u: *const RealType, up: *const RealType, data: *mut RealType, root: *mut RealType);
pub type ResidualFunc = unsafe extern "C" fn(time: RealType, u: *const RealType, up: *const RealType, data: *mut RealType, rr: *mut RealType);
pub type ResidualGradientFunc = unsafe extern "C" fn(time: RealType, u: *const RealType, du: *const RealType, up: *const RealType, dup: *const RealType, data: *mut RealType, ddata: *mut RealType, rr: *mut RealType, drr: *mut RealType);
pub type U0Func = unsafe extern "C" fn(data: *mut RealType, u: *mut RealType, up: *mut RealType);
pub type U0GradientFunc = unsafe extern "C" fn(data: *mut RealType, ddata: *mut RealType, u: *mut RealType, du: *mut RealType, up: *mut RealType, dup: *mut RealType);
pub type CalcOutFunc = unsafe extern "C" fn(time: RealType, u: *const RealType, up: *const RealType, data: *mut RealType);
pub type CalcOutGradientFunc = unsafe extern "C" fn(time: RealType, u: *const RealType, du: *const RealType, up: *const RealType, dup: *const RealType, data: *mut RealType, ddata: *mut RealType);
pub type GetDimsFunc = unsafe extern "C" fn(states: *mut u32, inputs: *mut u32, outputs: *mut u32, data: *mut u32, stop: *mut u32);
pub type SetInputsFunc = unsafe extern "C" fn(inputs: *const realtype, data: *mut realtype);
pub type SetInputsGradientFunc = unsafe extern "C" fn(inputs: *const realtype, dinputs: *const realtype, data: *mut realtype, ddata: *mut realtype);
pub type SetIdFunc = unsafe extern "C" fn(id: *mut realtype);
pub type GetOutFunc = unsafe extern "C" fn(data: *const realtype, tensor_data: *mut *mut realtype, tensor_size: *mut u32);
pub type SetInputsFunc = unsafe extern "C" fn(inputs: *const RealType, data: *mut RealType);
pub type SetInputsGradientFunc = unsafe extern "C" fn(inputs: *const RealType, dinputs: *const RealType, data: *mut RealType, ddata: *mut RealType);
pub type SetIdFunc = unsafe extern "C" fn(id: *mut RealType);
pub type GetOutFunc = unsafe extern "C" fn(data: *const RealType, tensor_data: *mut *mut RealType, tensor_size: *mut u32);

struct Globals<'ctx> {
enzyme_dup: GlobalValue<'ctx>,
Expand Down Expand Up @@ -1385,7 +1386,7 @@ impl<'ctx> CodeGen<'ctx> {
let curr_blk_index = index.as_basic_value().into_int_value();
let curr_id_index = self.builder.build_int_add(id_start_index, curr_blk_index, name)?;
let id_ptr = unsafe { self.builder.build_in_bounds_gep(*self.get_param("id"), &[curr_id_index], name)? };
let is_algebraic_float = if *is_algebraic { 0.0 as realtype } else { 1.0 as realtype };
let is_algebraic_float = if *is_algebraic { 0.0 as RealType } else { 1.0 as RealType };
let is_algebraic_value = self.real_type.const_float(is_algebraic_float);
self.builder.build_store(id_ptr, is_algebraic_value)?;

Expand Down
9 changes: 6 additions & 3 deletions src/execution/compiler.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::path::Path;
use anyhow::anyhow;
use std::env;

use anyhow::Result;
use inkwell::memory_buffer::MemoryBuffer;
Expand All @@ -9,7 +10,6 @@ use inkwell::{context::Context, OptimizationLevel, targets::{TargetTriple, Initi
use ouroboros::self_referencing;
use crate::discretise::DiscreteModel;
use crate::utils::find_executable;
use crate::utils::find_library_path;
use crate::utils::find_runtime_path;
use std::process::Command;

Expand Down Expand Up @@ -77,8 +77,11 @@ impl Compiler {
find_executable(&Compiler::CLANG_VARIENTS)
}
fn find_enzyme_lib() -> Result<String> {
let enzyme_lib_varients = ["LLVMEnzyme-14.so", "LLVMEnzyme-14.dylib"];
find_library_path(&enzyme_lib_varients)
match env::var("ENZYME_LIB") {
Ok(lib) => Ok(lib),
Err(_) => Err(anyhow!("ENZYME_LIB environment variable not set")),

}
}
pub fn from_discrete_model(model: &DiscreteModel, out: &str) -> Result<Self> {
let number_of_states = *model.state().shape().first().unwrap_or(&1);
Expand Down
3 changes: 0 additions & 3 deletions src/execution/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@ pub use codegen::CodeGen;
pub mod data_layout;
pub use data_layout::DataLayout;

pub mod sundials;
pub use sundials::{Sundials, Options};

pub mod translation;
pub use translation::{Translation, TranslationFrom, TranslationTo};

Expand Down
Loading

0 comments on commit 5e68fda

Please sign in to comment.