11use std:: rc:: Rc ;
22
33use hashbrown:: HashMap ;
4+ use hashbrown:: hash_map:: Entry ;
45use crate :: { DType , ANode , NodeIdx , Node } ;
56use crate :: vecops:: iadd;
67use crate :: pool:: { allocate_vec, MPVec } ;
@@ -19,19 +20,22 @@ impl Graph {
1920 }
2021 }
2122
23+ #[ inline]
2224 pub fn debug_nan ( & mut self , check : bool ) {
2325 self . nan_check = check;
2426 }
25-
2627
28+ #[ inline]
2729 pub fn get_grad ( & self , node : & ANode ) -> Option < & Vec < DType > > {
2830 self . gradients . get ( & node. get_id ( ) ) . map ( |v| v. as_ref ( ) )
2931 }
3032
33+ #[ inline]
3134 pub fn zero_grads ( & mut self ) {
3235 self . gradients . clear ( ) ;
3336 }
3437
38+ #[ inline]
3539 pub fn clear_memory ( & mut self ) {
3640 self . gradients . clear ( ) ;
3741 }
@@ -45,39 +49,50 @@ impl Graph {
4549 }
4650 }
4751
52+ #[ inline]
4853 fn get_temp_space ( & mut self , size : usize ) -> MPVec {
4954 allocate_vec ( size)
5055 }
5156
57+ #[ inline]
5258 fn add_grad ( & mut self , node : & ANode , grad : MPVec ) {
5359 self . gradients . insert ( node. get_id ( ) , grad) ;
5460 }
61+
62+ #[ inline]
63+ fn add_or_update_grad ( & mut self , node : & ANode , grad : MPVec ) {
64+ match self . gradients . entry ( node. get_id ( ) ) {
65+ Entry :: Occupied ( mut entry) => {
66+ iadd ( entry. get_mut ( ) , grad. as_slice ( ) ) ;
67+ } ,
68+ Entry :: Vacant ( mut entry) => {
69+ entry. insert ( grad) ;
70+ }
71+ }
72+ }
73+
5574
56- pub fn backward ( & mut self , end_node : & ANode ) {
75+ pub fn backward ( & mut self , end_node : & ANode ) {
5776 let out = Run :: new ( end_node) ;
5877 // dz/dz of course is 1
5978 let mut z_grad = self . get_or_create_grad ( & out) ;
6079 z_grad. fill ( 1f32 ) ;
6180
6281 // Allocate once
63- let mut grads = Vec :: new ( ) ;
6482 let mut temp_grads = Vec :: new ( ) ;
6583 self . add_grad ( & out, z_grad) ;
66- self . recurse ( & out, & mut grads , & mut temp_grads) ;
84+ self . recurse ( & out, & mut temp_grads) ;
6785 }
6886
69- fn recurse ( & mut self , node : & ANode , grads : & mut Vec < MPVec > , temp_grads : & mut Vec < MPVec > ) {
87+ fn recurse ( & mut self , node : & ANode , temp_grads : & mut Vec < MPVec > ) {
7088 if !node. is_leaf ( ) {
7189 let node_grad = self . get_or_create_grad ( node) ;
7290 if let Some ( children) = node. get_children ( ) {
73- grads. clear ( ) ;
7491 temp_grads. clear ( ) ;
7592 // Grab gradients
76- grads. extend ( children. iter ( )
77- . map ( |c| self . get_or_create_grad ( c) ) ) ;
7893
79- temp_grads. extend ( grads . iter ( )
80- . map ( |g | self . get_temp_space ( g . len ( ) ) ) ) ;
94+ temp_grads. extend ( children . iter ( )
95+ . map ( |c | self . get_temp_space ( c . value ( ) . len ( ) ) ) ) ;
8196
8297 node. compute_grad ( & node_grad, temp_grads) ;
8398
@@ -93,13 +108,10 @@ impl Graph {
93108 }
94109
95110 // Update grads
96- grads. iter_mut ( ) . zip ( temp_grads. into_iter ( ) ) . for_each ( |( g, tg) | {
97- iadd ( g, & tg) ;
98- } ) ;
99111
100112 // Re-add gradients
101- children. iter ( ) . zip ( grads . drain ( ..) ) . for_each ( |( c, g) | {
102- self . add_grad ( c, g) ;
113+ children. iter ( ) . zip ( temp_grads . drain ( ..) ) . for_each ( |( c, g) | {
114+ self . add_or_update_grad ( c, g) ;
103115 } ) ;
104116
105117 if node. requires_grad ( ) {
@@ -108,7 +120,7 @@ impl Graph {
108120
109121 // Run children
110122 for child in children. iter ( ) {
111- self . recurse ( child, grads , temp_grads) ;
123+ self . recurse ( child, temp_grads) ;
112124 }
113125
114126 } else {
0 commit comments