Skip to content

Commit

Permalink
cargo fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Oct 3, 2024
1 parent 1baf6df commit 863ec13
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 22 deletions.
7 changes: 4 additions & 3 deletions src/discretise/discrete_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,15 +372,16 @@ impl<'s> DiscreteModel<'s> {
}
}
}

// reorder inputs to match the order defined in "in = [ ... ]"
ret.inputs.sort_by_key(|t| {
model.inputs
model
.inputs
.iter()
.position(|&name| name == t.name())
.unwrap()
});

// set is_algebraic for every state based on equations
if ret.state_dot.is_some() && ret.lhs.is_some() {
let state_dot = ret.state_dot.as_ref().unwrap();
Expand Down
14 changes: 7 additions & 7 deletions src/execution/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ mod tests {
assert_relative_eq!(stop[0], 0.5);
assert_eq!(stop.len(), 1);
}

#[test]
fn test_vector_add_scalar_cranelift() {
let n = 1;
Expand Down Expand Up @@ -622,8 +622,6 @@ mod tests {
}

fn tensor_test_common<T: CodegenModule>(text: &str, tensor_name: &str) -> Vec<Vec<f64>> {


let full_text = format!(
"
{}
Expand Down Expand Up @@ -1011,7 +1009,7 @@ mod tests {
let out = compiler.get_out(data.as_slice());
assert_relative_eq!(out, vec![1., 2., 4.].as_slice());
}

#[test]
fn test_inputs() {
let full_text = "
Expand All @@ -1033,15 +1031,17 @@ mod tests {
let inputs = compiler.get_tensor_data(name, data.as_slice()).unwrap();
assert_relative_eq!(inputs, expected_value.as_slice());
}

#[cfg(feature = "llvm")]
{
let compiler = Compiler::<crate::LlvmModule>::from_discrete_model(&discrete_model).unwrap();
let compiler =
Compiler::<crate::LlvmModule>::from_discrete_model(&discrete_model).unwrap();
let mut data = compiler.get_new_data();
let inputs = vec![1.0, 2.0, 3.0];
compiler.set_inputs(inputs.as_slice(), data.as_mut_slice());

for (name, expected_value) in vec![("a", vec![2.0]), ("b", vec![3.0]), ("c", vec![1.0])] {
for (name, expected_value) in vec![("a", vec![2.0]), ("b", vec![3.0]), ("c", vec![1.0])]
{
let inputs = compiler.get_tensor_data(name, data.as_slice()).unwrap();
assert_relative_eq!(inputs, expected_value.as_slice());
}
Expand Down
23 changes: 11 additions & 12 deletions src/execution/cranelift/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ impl CraneliftModule {

// Now that compilation is finished, we can clear out the context state.
self.module.clear_context(&mut self.ctx);

Ok(id)
}
}
Expand Down Expand Up @@ -894,12 +894,11 @@ impl<'ctx> CraneliftCodeGen<'ctx> {
};

//let expr_index_var = self.decl_stack_slot(self.int_type, Some(zero));
let elmt_index_var = if contract_sum.is_some() {
let elmt_index_var = if contract_sum.is_some() {
Some(self.decl_stack_slot(self.int_type, Some(zero)))
} else {
None
};


for i in 0..expr_rank {
let block = self.builder.create_block();
Expand Down Expand Up @@ -934,8 +933,7 @@ impl<'ctx> CraneliftCodeGen<'ctx> {
} else {
elmt.expr()
};
let float_value =
self.jit_compile_expr(name, expr, indices.as_slice(), elmt, None)?;
let float_value = self.jit_compile_expr(name, expr, indices.as_slice(), elmt, None)?;

if contract_sum.is_some() {
let contract_sum_value =
Expand All @@ -950,9 +948,10 @@ impl<'ctx> CraneliftCodeGen<'ctx> {
let expr_index = if indices.is_empty() {
zero
} else {
indices.iter().skip(1).fold(indices[0], |acc, x| {
self.builder.ins().imul(acc, *x)
})
indices
.iter()
.skip(1)
.fold(indices[0], |acc, x| self.builder.ins().imul(acc, *x))
};
self.jit_compile_broadcast_and_store(
name,
Expand All @@ -972,10 +971,10 @@ impl<'ctx> CraneliftCodeGen<'ctx> {
for i in (0..expr_rank).rev() {
// update and store contract sum
if i == expr_rank - contract_by - 1 && contract_sum.is_some() {
let elmt_index = self
.builder
.ins()
.stack_load(self.int_type, elmt_index_var.unwrap(), 0);
let elmt_index =
self.builder
.ins()
.stack_load(self.int_type, elmt_index_var.unwrap(), 0);
let next_elmt_index = self.builder.ins().iadd(elmt_index, one);
self.builder
.ins()
Expand Down

0 comments on commit 863ec13

Please sign in to comment.