@@ -735,6 +735,116 @@ impl Node for BulkSum {
735735}
736736
737737
738+ pub ( crate ) struct Maximum ( NodeIdx , Vec < ANode > , Computation ) ;
739+
740+ impl Maximum {
741+ pub ( crate ) fn new ( left : ANode , right : ANode ) -> ANode {
742+ let idx = NodeIdx :: new ( ) ;
743+ let value = Maximum :: compute ( & left, & right) ;
744+ let node = Maximum ( idx, vec ! [ left, right] , Computation :: pooled ( value) ) ;
745+ ANode :: new ( Arc :: new ( node) )
746+ }
747+
748+ fn compute ( left : & ANode , right : & ANode ) -> MPVec {
749+ let ( lv, rv) = Broadcast :: from_pair ( left. value ( ) , right. value ( ) ) ;
750+ let mut out = allocate_vec ( lv. len ) ;
751+ out. iter_mut ( ) . zip ( lv. zip ( rv) ) . for_each ( |( oi, ( lvi, rvi) ) | {
752+ * oi = lvi. max ( * rvi)
753+ } ) ;
754+ out
755+ }
756+ }
757+
758+ impl Node for Maximum {
759+ fn get_id ( & self ) -> NodeIdx { self . 0 . clone ( ) }
760+
761+ fn get_children ( & self ) -> Option < & [ ANode ] > {
762+ Some ( self . 1 . as_slice ( ) )
763+ }
764+
765+ fn is_leaf ( & self ) -> bool { false }
766+
767+ fn value ( & self ) -> & [ DType ] {
768+ & self . 2 . get ( )
769+ }
770+
771+ fn requires_grad ( & self ) -> bool { false }
772+
773+ fn compute_grad ( & self , grad : & [ DType ] , child_grads : & mut [ Vec < DType > ] ) {
774+ // f(x,y) = x.max(y)
775+ let left = self . 1 [ 0 ] . value ( ) ;
776+ let right = self . 1 [ 1 ] . value ( ) ;
777+ let ( lv, rv) = Broadcast :: from_pair ( left, right) ;
778+ let ( left_grad, right_grad) = child_grads. split_at_mut ( 1 ) ;
779+ let mut left_out = Updater :: new ( & mut left_grad[ 0 ] , grad. len ( ) ) ;
780+ let mut right_out = Updater :: new ( & mut right_grad[ 0 ] , grad. len ( ) ) ;
781+ grad. iter ( ) . zip ( lv. zip ( rv) ) . for_each ( |( gi, ( xi, yi) ) | {
782+ if xi >= yi {
783+ left_out. add ( * gi) ;
784+ right_out. add ( 0f32 ) ;
785+ } else {
786+ right_out. add ( * gi) ;
787+ left_out. add ( 0f32 ) ;
788+ }
789+ } ) ;
790+ }
791+ }
792+
793+ pub ( crate ) struct Minimum ( NodeIdx , Vec < ANode > , Computation ) ;
794+
795+ impl Minimum {
796+ pub ( crate ) fn new ( left : ANode , right : ANode ) -> ANode {
797+ let idx = NodeIdx :: new ( ) ;
798+ let value = Minimum :: compute ( & left, & right) ;
799+ let node = Minimum ( idx, vec ! [ left, right] , Computation :: pooled ( value) ) ;
800+ ANode :: new ( Arc :: new ( node) )
801+ }
802+
803+ fn compute ( left : & ANode , right : & ANode ) -> MPVec {
804+ let ( lv, rv) = Broadcast :: from_pair ( left. value ( ) , right. value ( ) ) ;
805+ let mut out = allocate_vec ( lv. len ) ;
806+ out. iter_mut ( ) . zip ( lv. zip ( rv) ) . for_each ( |( oi, ( lvi, rvi) ) | {
807+ * oi = lvi. min ( * rvi)
808+ } ) ;
809+ out
810+ }
811+ }
812+
813+ impl Node for Minimum {
814+ fn get_id ( & self ) -> NodeIdx { self . 0 . clone ( ) }
815+
816+ fn get_children ( & self ) -> Option < & [ ANode ] > {
817+ Some ( self . 1 . as_slice ( ) )
818+ }
819+
820+ fn is_leaf ( & self ) -> bool { false }
821+
822+ fn value ( & self ) -> & [ DType ] {
823+ & self . 2 . get ( )
824+ }
825+
826+ fn requires_grad ( & self ) -> bool { false }
827+
828+ fn compute_grad ( & self , grad : & [ DType ] , child_grads : & mut [ Vec < DType > ] ) {
829+ // f(x,y) = x.max(y)
830+ let left = self . 1 [ 0 ] . value ( ) ;
831+ let right = self . 1 [ 1 ] . value ( ) ;
832+ let ( lv, rv) = Broadcast :: from_pair ( left, right) ;
833+ let ( left_grad, right_grad) = child_grads. split_at_mut ( 1 ) ;
834+ let mut left_out = Updater :: new ( & mut left_grad[ 0 ] , grad. len ( ) ) ;
835+ let mut right_out = Updater :: new ( & mut right_grad[ 0 ] , grad. len ( ) ) ;
836+ grad. iter ( ) . zip ( lv. zip ( rv) ) . for_each ( |( gi, ( xi, yi) ) | {
837+ if xi >= yi {
838+ right_out. add ( * gi) ;
839+ left_out. add ( 0f32 ) ;
840+ } else {
841+ left_out. add ( * gi) ;
842+ right_out. add ( 0f32 ) ;
843+ }
844+ } ) ;
845+ }
846+ }
847+
738848#[ cfg( test) ]
739849mod tests {
740850 use super :: * ;
@@ -900,6 +1010,38 @@ mod tests {
9001010 assert_eq ! ( grad, & [ -1. , -( -1f32 ) . exp( ) , -( -2f32 ) . exp( ) ] ) ;
9011011 }
9021012
1013+ #[ test]
1014+ fn test_maximum ( ) {
1015+ let x = Variable :: new ( vec ! [ 1. , 2. ] ) ;
1016+ let y = Variable :: new ( vec ! [ 3. , 5. ] ) ;
1017+
1018+ let out = ( & x) . pow ( 4f32 ) . maximum ( 2f32 * & y) ;
1019+
1020+ let mut graph = Graph :: new ( ) ;
1021+ graph. backward ( & out) ;
1022+
1023+ let x_grad = graph. get_grad ( & x) . unwrap ( ) ;
1024+ let y_grad = graph. get_grad ( & y) . unwrap ( ) ;
1025+ assert_eq ! ( x_grad, & [ 0f32 , 32f32 ] ) ;
1026+ assert_eq ! ( y_grad, & [ 2f32 , 0f32 ] ) ;
1027+ }
1028+
1029+ #[ test]
1030+ fn test_minimum ( ) {
1031+ let x = Variable :: new ( vec ! [ 1. , 2. ] ) ;
1032+ let y = Variable :: new ( vec ! [ 3. , 5. ] ) ;
1033+
1034+ let out = ( & x) . pow ( 4f32 ) . minimum ( 2f32 * & y) ;
1035+
1036+ let mut graph = Graph :: new ( ) ;
1037+ graph. backward ( & out) ;
1038+
1039+ let x_grad = graph. get_grad ( & x) . unwrap ( ) ;
1040+ let y_grad = graph. get_grad ( & y) . unwrap ( ) ;
1041+ assert_eq ! ( x_grad, & [ 4f32 , 0f32 ] ) ;
1042+ assert_eq ! ( y_grad, & [ 0f32 , 2f32 ] ) ;
1043+ }
1044+
9031045 #[ test]
9041046 fn test_backward_pass_simple1 ( ) {
9051047 // 2x
0 commit comments