Skip to content

Commit

Permalink
fix unary negate before plus bug
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Nov 26, 2023
1 parent e4c18fc commit 93f8e46
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
5 changes: 3 additions & 2 deletions src/codegen/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,7 @@ mod tests {

tensor_test!{
exp_function: "r { exp(2) }" expect "r" vec![f64::exp(2.0)],
pow_function: "r { pow(4, 0.5) }" expect "r" vec![2.0],
pow_function: "r { pow(4.3245, 0.5) }" expect "r" vec![f64::powf(4.3245, 0.5)],
arcsinh_function: "r { arcsinh(0.5) }" expect "r" vec![f64::asinh(0.5)],
arccosh_function: "r { arccosh(2) }" expect "r" vec![f64::acosh(2.0)],
tanh_function: "r { tanh(0.5) }" expect "r" vec![f64::tanh(0.5)],
Expand All @@ -763,7 +763,8 @@ mod tests {
sigmoid_function: "r { sigmoid(0.1) }" expect "r" vec![1.0 / (1.0 + f64::exp(-0.1))],
scalar: "r {2}" expect "r" vec![2.0,],
constant: "r_i {2, 3}" expect "r" vec![2., 3.],
expression: "r_i {2 + 3, 3 * 2}" expect "r" vec![5., 6.],
expression: "r_i {2 + 3, 3 * 2, arcsinh(1.2 + 1.0 / max(1.2, 1.0) * 2.0 + tanh(2.0))}" expect "r" vec![5., 6., f64::asinh(1.2 + 1.0 / f64::max(1.2, 1.0) * 2.0 + f64::tanh(2.0))],
unary_negate_in_expr: "r_i { 1.0 / (-1.0 + 1.1) }" expect "r" vec![1.0 / (-1.0 + 1.1)],
derived: "r_i {2, 3} k_i { 2 * r_i }" expect "k" vec![4., 6.],
concatenate: "r_i {2, 3} k_i { r_i, 2 * r_i }" expect "k" vec![2., 3., 4., 6.],
ones_matrix_dense: "I_ij { (0:2, 0:2): 1 }" expect "I" vec![1., 1., 1., 1.],
Expand Down
23 changes: 11 additions & 12 deletions src/parser/ds_parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,16 @@ fn parse_value<'a, 'b>(pair: Pair<'a, Rule>) -> Ast<'a> {
} else {
None
};
let mut head_term = parse_value(inner.next().unwrap());
let mut head_term = match sign {
Some(s) => Ast {
kind: AstKind::Monop(ast::Monop {
op: s,
child: Box::new(parse_value(inner.next().unwrap())),
}),
span,
},
None => parse_value(inner.next().unwrap())
};
while inner.peek().is_some() {
//term_op = @{ "-"|"+" }
let term_op = parse_sign(inner.next().unwrap());
Expand All @@ -104,17 +113,7 @@ fn parse_value<'a, 'b>(pair: Pair<'a, Rule>) -> Ast<'a> {
span: subspan,
};
}
if sign.is_some() {
Ast {
kind: AstKind::Monop(ast::Monop {
op: sign.unwrap(),
child: Box::new(head_term),
}),
span,
}
} else {
head_term
}
head_term
}

//term = { factor ~ (factor_op ~ factor)* }
Expand Down

0 comments on commit 93f8e46

Please sign in to comment.