@@ -272,9 +272,9 @@ macro_rules! run_unary_op {
272272 if left_len == out_len {
273273 $func( ArrayInput ( $left) , ArrayOutput ( $out) ) ;
274274 } else if left_len == 1 {
275- $func( BroadcastInput ( $left[ 0 ] , out_len) , ArrayOutput ( $out) ) ;
275+ $func( BroadcastInput ( $left, out_len) , ArrayOutput ( $out) ) ;
276276 } else if out_len == 1 {
277- $func( ArrayInput ( $left[ 0 ] , out_len ) , BroadcastOutput ( $out: tt , left_len) ) ;
277+ $func( ArrayInput ( $left) , BroadcastOutput ( $out, left_len) ) ;
278278 } else {
279279 panic!( "Left length: {}, Output Length: {}" , left_len, out_len) ;
280280 }
@@ -403,11 +403,11 @@ impl Subtract {
403403 }
404404
405405 fn compute ( left : & ANode , right : & ANode ) -> MPVec {
406- let ( lv , rv ) = Broadcast :: from_pair ( left. value ( ) , right . value ( ) ) ;
407- let mut out = allocate_vec ( lv . len ) ;
408- out. iter_mut ( ) . zip ( lv . zip ( rv ) ) . for_each ( | ( oi , ( lvi , rvi ) ) | {
409- * oi = lvi - rvi
410- } ) ;
406+ let x = left. value ( ) ;
407+ let y = right . value ( ) ;
408+ let mut out = Broadcast :: allocate_out ( x , y ) ;
409+ let o = & mut out ;
410+ run_binary_op ! ( x , y , o , simd_sub ) ;
411411 out
412412 }
413413}
@@ -431,12 +431,12 @@ impl Node for Subtract {
431431 fn compute_grad ( & self , grad : & [ DType ] , child_grads : & mut [ & mut [ DType ] ] ) {
432432 // f(x,y) = x - y
433433 // df(x,y)/dx = 1
434- // df(x,y)/dy = -1
435- let mut out = Updater :: new ( & mut child_grads[ 0 ] , grad. len ( ) ) ;
436- grad. iter ( ) . for_each ( |gi| out. add ( * gi) ) ;
434+ let out = & mut child_grads[ 0 ] ;
435+ run_unary_op ! ( grad, out, simd_iadd) ;
437436
438- let mut out = Updater :: new ( & mut child_grads[ 1 ] , grad. len ( ) ) ;
439- grad. iter ( ) . for_each ( |gi| out. add ( -* gi) ) ;
437+ // df(x,y)/dy = -1
438+ let out = & mut child_grads[ 1 ] ;
439+ run_unary_op ! ( grad, out, grad_sub_y) ;
440440 }
441441
442442}
@@ -539,21 +539,10 @@ impl Node for Divide {
539539 let out = & mut child_grads[ 0 ] ;
540540 run_binary_op ! ( grad, y, out, grad_div_x) ;
541541
542- /*
543- let ly = Broadcast::sized(y, child_grads[0].len());
544- let mut out = Updater::new(&mut child_grads[0], grad.len());
545- grad.iter().zip(ly).for_each(|(gi, yi)| out.add(*gi / *yi));
546- */
547-
548542 let out = & mut child_grads[ 1 ] ;
543+ // df(x,y)/dy = -x / y ^ 2
549544 run_trinary_op ! ( grad, x, y, out, grad_div_y) ;
550545
551- // df(x,y)/dy = -x / y ^ 2
552- /*
553- let (lx, ly) = Broadcast::from_pair(x, y);
554- let mut out = Updater::new(&mut child_grads[1], lx.len);
555- grad.iter().zip(lx.zip(ly)).for_each(|(gi, (xi, yi))| out.add(*gi * -*xi / yi.powf(2f32)));
556- */
557546 }
558547
559548}
@@ -664,7 +653,7 @@ impl Node for SquareRoot {
664653 fn requires_grad ( & self ) -> bool { false }
665654
666655 fn compute_grad ( & self , grad : & [ DType ] , child_grads : & mut [ & mut [ DType ] ] ) {
667- let x = self . 1 [ 0 ] . value ( ) ;
656+ let x = self . value ( ) ;
668657
669658 // df(x)/dx = (1/2) / x ^ 0.5
670659 child_grads[ 0 ] . iter_mut ( ) . zip ( grad. iter ( ) . zip ( x) ) . for_each ( |( outi, ( gi, xi) ) | {
@@ -905,7 +894,9 @@ impl Exp {
905894 fn compute ( left : & ANode ) -> MPVec {
906895 let lv = left. value ( ) ;
907896 let mut out = allocate_vec ( lv. len ( ) ) ;
908- out. iter_mut ( ) . zip ( lv. iter ( ) ) . for_each ( |( oi, lvi) | * oi = lvi. exp ( ) ) ;
897+ let o = & mut out;
898+ run_unary_op ! ( lv, o, simd_exp) ;
899+ //out.iter_mut().zip(lv.iter()).for_each(|(oi, lvi)| *oi = lvi.exp());
909900 out
910901 }
911902
@@ -1314,6 +1305,21 @@ mod tests {
13141305 assert_eq ! ( y_grad, & [ 3. ] ) ;
13151306 }
13161307
1308+ #[ test]
1309+ fn test_sqrt ( ) {
1310+ let x = Variable :: new ( vec ! [ 4. , 9. ] ) ;
1311+ let res = SquareRoot :: new ( x. clone ( ) ) ;
1312+ assert_eq ! ( res. value( ) , & [ 2. , 3. ] ) ;
1313+
1314+ let mut graph = Graph :: new ( ) ;
1315+ graph. backward ( & res) ;
1316+
1317+ let x_1_g = 1f32 / ( 2f32 * 2f32 ) ;
1318+ let x_2_g = 1f32 / ( 2f32 * 3f32 ) ;
1319+ let x_grad = graph. get_grad ( & x) . unwrap ( ) ;
1320+ assert_eq ! ( x_grad, & [ x_1_g, x_2_g] ) ;
1321+ }
1322+
13171323 #[ test]
13181324 fn test_div ( ) {
13191325 let x = Variable :: new ( vec ! [ 0. , 1. ] ) ;
@@ -1466,7 +1472,7 @@ mod tests {
14661472 let x = Variable :: new ( vec ! [ 1. , 2. , 3. ] ) ;
14671473
14681474 let x_slice = x. slice ( 1 , 2 ) ;
1469- let mut out = x_slice * 2. ;
1475+ let out = x_slice * 2. ;
14701476
14711477 let mut graph = Graph :: new ( ) ;
14721478 graph. backward ( & out) ;
0 commit comments