Skip to content

Commit

Permalink
bug: fix for case if inputs come in different order (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins authored Oct 3, 2024
1 parent a026444 commit 3f40ae7
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 25 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ cmake = { version = "0.1.50", optional = true }

[dev-dependencies]
divan = "0.1.14"
env_logger = "0.11.5"

[[bench]]
name = "evaluation"
Expand Down
9 changes: 9 additions & 0 deletions src/discretise/discrete_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,15 @@ impl<'s> DiscreteModel<'s> {
}
}

// reorder inputs to match the order defined in "in = [ ... ]"
ret.inputs.sort_by_key(|t| {
model
.inputs
.iter()
.position(|&name| name == t.name())
.unwrap()
});

// set is_algebraic for every state based on equations
if ret.state_dot.is_some() && ret.lhs.is_some() {
let state_dot = ret.state_dot.as_ref().unwrap();
Expand Down
66 changes: 66 additions & 0 deletions src/execution/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,34 @@ mod tests {
assert_eq!(stop.len(), 1);
}

#[test]
fn test_vector_add_scalar_cranelift() {
let n = 1;
let u = vec![1.0; n];
let full_text = format!(
"
u_i {{
{}
}}
F_i {{
u_i + 1.0,
}}
out_i {{
u_i
}}
",
(0..n)
.map(|i| format!("x{} = {},", i, u[i]))
.collect::<Vec<_>>()
.join("\n"),
);
let model = parse_ds_string(&full_text).unwrap();
let name = "$name";
let discrete_model = DiscreteModel::build(name, &model).unwrap();
env_logger::builder().is_test(true).try_init().unwrap();
let _compiler = Compiler::<CraneliftModule>::from_discrete_model(&discrete_model).unwrap();
}

fn tensor_test_common<T: CodegenModule>(text: &str, tensor_name: &str) -> Vec<Vec<f64>> {
let full_text = format!(
"
Expand Down Expand Up @@ -981,4 +1009,42 @@ mod tests {
let out = compiler.get_out(data.as_slice());
assert_relative_eq!(out, vec![1., 2., 4.].as_slice());
}

#[test]
fn test_inputs() {
let full_text = "
in = [c, a, b]
a { 1 } b { 2 } c { 3 }
u { y = 0 }
F { y }
out { y }
";
let model = parse_ds_string(full_text).unwrap();
let discrete_model = DiscreteModel::build("test_inputs", &model).unwrap();

let compiler = Compiler::<CraneliftModule>::from_discrete_model(&discrete_model).unwrap();
let mut data = compiler.get_new_data();
let inputs = vec![1.0, 2.0, 3.0];
compiler.set_inputs(inputs.as_slice(), data.as_mut_slice());

for (name, expected_value) in vec![("a", vec![2.0]), ("b", vec![3.0]), ("c", vec![1.0])] {
let inputs = compiler.get_tensor_data(name, data.as_slice()).unwrap();
assert_relative_eq!(inputs, expected_value.as_slice());
}

#[cfg(feature = "llvm")]
{
let compiler =
Compiler::<crate::LlvmModule>::from_discrete_model(&discrete_model).unwrap();
let mut data = compiler.get_new_data();
let inputs = vec![1.0, 2.0, 3.0];
compiler.set_inputs(inputs.as_slice(), data.as_mut_slice());

for (name, expected_value) in vec![("a", vec![2.0]), ("b", vec![3.0]), ("c", vec![1.0])]
{
let inputs = compiler.get_tensor_data(name, data.as_slice()).unwrap();
assert_relative_eq!(inputs, expected_value.as_slice());
}
}
}
}
60 changes: 35 additions & 25 deletions src/execution/cranelift/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ impl CraneliftModule {
.module
.declare_function(name, Linkage::Export, &self.ctx.func.signature)?;

//println!("Declared function: {}", name);
//println!("Declared function: {} -------------------------------------------------------------------------------------", name);
//println!("IR:\n{}", self.ctx.func);

// Define the function to jit. This finishes compilation, although
Expand Down Expand Up @@ -263,7 +263,7 @@ impl CodegenModule for CraneliftModule {

// write indices data as a global data object
// convect the indices to bytes
let int_type = types::I32;
let int_type = ptr_type;
let real_type = types::F64;
let mut vec8: Vec<u8> = vec![];
for elem in layout.indices() {
Expand Down Expand Up @@ -875,9 +875,6 @@ impl<'ctx> CraneliftCodeGen<'ctx> {
let one = self.builder.ins().iconst(int_type, 1);
let zero = self.builder.ins().iconst(int_type, 0);

let expr_index_var = self.decl_stack_slot(self.int_type, Some(zero));
let elmt_index_var = self.decl_stack_slot(self.int_type, Some(zero));

// setup indices, loop through the nested loops
let mut indices = Vec::new();
let mut blocks = Vec::new();
Expand All @@ -896,6 +893,13 @@ impl<'ctx> CraneliftCodeGen<'ctx> {
(None, 0)
};

//let expr_index_var = self.decl_stack_slot(self.int_type, Some(zero));
let elmt_index_var = if contract_sum.is_some() {
Some(self.decl_stack_slot(self.int_type, Some(zero)))
} else {
None
};

for i in 0..expr_rank {
let block = self.builder.create_block();
let curr_index = self.builder.append_block_param(block, self.int_type);
Expand All @@ -914,28 +918,22 @@ impl<'ctx> CraneliftCodeGen<'ctx> {
preblock = block;
}

let elmt_index = self
.builder
.ins()
.stack_load(self.int_type, elmt_index_var, 0);

// load and increment the expression index
let expr_index = self
.builder
.ins()
.stack_load(self.int_type, expr_index_var, 0);
let next_expr_index = self.builder.ins().iadd(expr_index, one);
self.builder
.ins()
.stack_store(next_expr_index, expr_index_var, 0);
//let expr_index = self
// .builder
// .ins()
// .stack_load(self.int_type, expr_index_var, 0);
//let next_expr_index = self.builder.ins().iadd(expr_index, one);
//self.builder
// .ins()
// .stack_store(next_expr_index, expr_index_var, 0);

let expr = if is_tangent {
elmt.tangent_expr()
} else {
elmt.expr()
};
let float_value =
self.jit_compile_expr(name, expr, indices.as_slice(), elmt, Some(expr_index))?;
let float_value = self.jit_compile_expr(name, expr, indices.as_slice(), elmt, None)?;

if contract_sum.is_some() {
let contract_sum_value =
Expand All @@ -947,6 +945,14 @@ impl<'ctx> CraneliftCodeGen<'ctx> {
.ins()
.stack_store(new_contract_sum_value, contract_sum.unwrap(), 0);
} else {
let expr_index = if indices.is_empty() {
zero
} else {
indices
.iter()
.skip(1)
.fold(indices[0], |acc, x| self.builder.ins().imul(acc, *x))
};
self.jit_compile_broadcast_and_store(
name,
elmt,
Expand All @@ -955,20 +961,24 @@ impl<'ctx> CraneliftCodeGen<'ctx> {
translation,
preblock,
)?;
let next_elmt_index = self.builder.ins().iadd(elmt_index, one);
self.builder
.ins()
.stack_store(next_elmt_index, elmt_index_var, 0);
//let next_elmt_index = self.builder.ins().iadd(elmt_index, one);
//self.builder
// .ins()
// .stack_store(next_elmt_index, elmt_index_var, 0);
}

// unwind the nested loops
for i in (0..expr_rank).rev() {
// update and store contract sum
if i == expr_rank - contract_by - 1 && contract_sum.is_some() {
let elmt_index =
self.builder
.ins()
.stack_load(self.int_type, elmt_index_var.unwrap(), 0);
let next_elmt_index = self.builder.ins().iadd(elmt_index, one);
self.builder
.ins()
.stack_store(next_elmt_index, elmt_index_var, 0);
.stack_store(next_elmt_index, elmt_index_var.unwrap(), 0);

let contract_sum_value =
self.builder
Expand Down

0 comments on commit 3f40ae7

Please sign in to comment.