diff --git a/src/codegen/compiler.rs b/src/codegen/compiler.rs index c114163..346f81b 100644 --- a/src/codegen/compiler.rs +++ b/src/codegen/compiler.rs @@ -751,7 +751,7 @@ mod tests { tensor_test!{ exp_function: "r { exp(2) }" expect "r" vec![f64::exp(2.0)], - pow_function: "r { pow(4, 0.5) }" expect "r" vec![2.0], + pow_function: "r { pow(4.3245, 0.5) }" expect "r" vec![f64::powf(4.3245, 0.5)], arcsinh_function: "r { arcsinh(0.5) }" expect "r" vec![f64::asinh(0.5)], arccosh_function: "r { arccosh(2) }" expect "r" vec![f64::acosh(2.0)], tanh_function: "r { tanh(0.5) }" expect "r" vec![f64::tanh(0.5)], @@ -763,7 +763,8 @@ mod tests { sigmoid_function: "r { sigmoid(0.1) }" expect "r" vec![1.0 / (1.0 + f64::exp(-0.1))], scalar: "r {2}" expect "r" vec![2.0,], constant: "r_i {2, 3}" expect "r" vec![2., 3.], - expression: "r_i {2 + 3, 3 * 2}" expect "r" vec![5., 6.], + expression: "r_i {2 + 3, 3 * 2, arcsinh(1.2 + 1.0 / max(1.2, 1.0) * 2.0 + tanh(2.0))}" expect "r" vec![5., 6., f64::asinh(1.2 + 1.0 / f64::max(1.2, 1.0) * 2.0 + f64::tanh(2.0))], + unary_negate_in_expr: "r_i { 1.0 / (-1.0 + 1.1) }" expect "r" vec![1.0 / (-1.0 + 1.1)], derived: "r_i {2, 3} k_i { 2 * r_i }" expect "k" vec![4., 6.], concatenate: "r_i {2, 3} k_i { r_i, 2 * r_i }" expect "k" vec![2., 3., 4., 6.], ones_matrix_dense: "I_ij { (0:2, 0:2): 1 }" expect "I" vec![1., 1., 1., 1.], diff --git a/src/parser/ds_parser.rs b/src/parser/ds_parser.rs index 61a8d66..c72226e 100644 --- a/src/parser/ds_parser.rs +++ b/src/parser/ds_parser.rs @@ -86,7 +86,16 @@ fn parse_value<'a, 'b>(pair: Pair<'a, Rule>) -> Ast<'a> { } else { None }; - let mut head_term = parse_value(inner.next().unwrap()); + let mut head_term = match sign { + Some(s) => Ast { + kind: AstKind::Monop(ast::Monop { + op: s, + child: Box::new(parse_value(inner.next().unwrap())), + }), + span, + }, + None => parse_value(inner.next().unwrap()) + }; while inner.peek().is_some() { //term_op = @{ "-"|"+" } let term_op = parse_sign(inner.next().unwrap()); @@ -104,17 +113,7 @@ fn parse_value<'a, 'b>(pair: Pair<'a, Rule>) -> Ast<'a> { span: subspan, }; } - if sign.is_some() { - Ast { - kind: AstKind::Monop(ast::Monop { - op: sign.unwrap(), - child: Box::new(head_term), - }), - span, - } - } else { - head_term - } + head_term } //term = { factor ~ (factor_op ~ factor)* }