1- use std:: sync :: Arc ;
1+ use std:: rc :: Rc ;
22
33use crate :: * ;
44use crate :: vecops:: { add, iadd, sub, isub, mul, imul, div} ;
55use crate :: pool:: { MPVec , allocate_vec} ;
66
77enum Data {
88 Owned ( Vec < DType > ) ,
9- Shared ( Arc < Vec < DType > > ) ,
9+ Shared ( Rc < Vec < DType > > ) ,
1010 Pooled ( MPVec )
1111}
1212
@@ -19,7 +19,7 @@ impl Computation {
1919 Computation { value : Data :: Owned ( value) }
2020 }
2121
22- fn shared ( value : Arc < Vec < DType > > ) -> Self {
22+ fn shared ( value : Rc < Vec < DType > > ) -> Self {
2323 Computation { value : Data :: Shared ( value) }
2424 }
2525
@@ -41,16 +41,16 @@ pub struct Variable(NodeIdx, Computation);
4141impl Variable {
4242 pub fn new ( value : Vec < DType > ) -> ANode {
4343 let v = Variable ( NodeIdx :: new ( ) , Computation :: new ( value) ) ;
44- ANode :: new ( Arc :: new ( v) )
44+ ANode :: new ( Rc :: new ( v) )
4545 }
4646
4747 pub fn scalar ( value : DType ) -> ANode {
4848 Variable :: new ( vec ! [ value] )
4949 }
5050
51- pub fn shared ( value : Arc < Vec < DType > > ) -> ANode {
51+ pub fn shared ( value : Rc < Vec < DType > > ) -> ANode {
5252 let v = Variable ( NodeIdx :: new ( ) , Computation :: shared ( value) ) ;
53- ANode :: new ( Arc :: new ( v) )
53+ ANode :: new ( Rc :: new ( v) )
5454 }
5555
5656}
@@ -78,14 +78,14 @@ pub struct Constant(NodeIdx, Computation);
7878impl Constant {
7979 pub fn new ( value : Vec < DType > ) -> ANode {
8080 let c = Constant ( NodeIdx :: new ( ) , Computation :: new ( value) ) ;
81- ANode :: new ( Arc :: new ( c) )
81+ ANode :: new ( Rc :: new ( c) )
8282 }
8383
8484 pub fn scalar ( value : DType ) -> ANode {
8585 let mut v = allocate_vec ( 1 ) ;
8686 v. as_mut ( ) [ 0 ] = value;
8787 let c = Constant ( NodeIdx :: new ( ) , Computation :: pooled ( v) ) ;
88- ANode :: new ( Arc :: new ( c) )
88+ ANode :: new ( Rc :: new ( c) )
8989 }
9090
9191}
@@ -182,7 +182,7 @@ impl AddN {
182182 let idx = NodeIdx :: new ( ) ;
183183 let value = AddN :: compute ( & left, & right) ;
184184 let node = AddN ( idx, vec ! [ left, right] , Computation :: pooled ( value) ) ;
185- ANode :: new ( Arc :: new ( node) )
185+ ANode :: new ( Rc :: new ( node) )
186186 }
187187
188188 fn compute ( left : & ANode , right : & ANode ) -> MPVec {
@@ -230,7 +230,7 @@ impl Subtract {
230230 let idx = NodeIdx :: new ( ) ;
231231 let value = Subtract :: compute ( & left, & right) ;
232232 let node = Subtract ( idx, vec ! [ left, right] , Computation :: pooled ( value) ) ;
233- ANode :: new ( Arc :: new ( node) )
233+ ANode :: new ( Rc :: new ( node) )
234234 }
235235
236236 fn compute ( left : & ANode , right : & ANode ) -> MPVec {
@@ -278,7 +278,7 @@ impl Multiply {
278278 let idx = NodeIdx :: new ( ) ;
279279 let value = Multiply :: compute ( & left, & right) ;
280280 let node = Multiply ( idx, vec ! [ left, right] , Computation :: pooled ( value) ) ;
281- ANode :: new ( Arc :: new ( node) )
281+ ANode :: new ( Rc :: new ( node) )
282282 }
283283
284284 fn compute ( left : & ANode , right : & ANode ) -> MPVec {
@@ -333,7 +333,7 @@ impl Divide {
333333 let idx = NodeIdx :: new ( ) ;
334334 let value = Divide :: compute ( & left, & right) ;
335335 let node = Divide ( idx, vec ! [ left, right] , Computation :: pooled ( value) ) ;
336- ANode :: new ( Arc :: new ( node) )
336+ ANode :: new ( Rc :: new ( node) )
337337 }
338338
339339 fn compute ( left : & ANode , right : & ANode ) -> MPVec {
@@ -386,7 +386,7 @@ impl Power {
386386 let idx = NodeIdx :: new ( ) ;
387387 let value = Power :: compute ( & base, & exp) ;
388388 let node = Power ( idx, vec ! [ base, exp] , Computation :: pooled ( value) ) ;
389- ANode :: new ( Arc :: new ( node) )
389+ ANode :: new ( Rc :: new ( node) )
390390 }
391391
392392 fn compute ( left : & ANode , right : & ANode ) -> MPVec {
@@ -445,7 +445,7 @@ impl SumVec {
445445 let idx = NodeIdx :: new ( ) ;
446446 let value = SumVec :: compute ( & vec) ;
447447 let node = SumVec ( idx, vec ! [ vec] , Computation :: pooled ( value) ) ;
448- ANode :: new ( Arc :: new ( node) )
448+ ANode :: new ( Rc :: new ( node) )
449449 }
450450
451451 fn compute ( left : & ANode ) -> MPVec {
@@ -487,7 +487,7 @@ impl Cos {
487487 let idx = NodeIdx :: new ( ) ;
488488 let value = Cos :: compute ( & vec) ;
489489 let node = Cos ( idx, vec ! [ vec] , Computation :: pooled ( value) ) ;
490- ANode :: new ( Arc :: new ( node) )
490+ ANode :: new ( Rc :: new ( node) )
491491 }
492492
493493 fn compute ( left : & ANode ) -> MPVec {
@@ -529,7 +529,7 @@ impl Sin {
529529 let idx = NodeIdx :: new ( ) ;
530530 let value = Sin :: compute ( & vec) ;
531531 let node = Sin ( idx, vec ! [ vec] , Computation :: pooled ( value) ) ;
532- ANode :: new ( Arc :: new ( node) )
532+ ANode :: new ( Rc :: new ( node) )
533533 }
534534
535535 fn compute ( left : & ANode ) -> MPVec {
@@ -572,7 +572,7 @@ impl Ln {
572572 let idx = NodeIdx :: new ( ) ;
573573 let value = Ln :: compute ( & vec) ;
574574 let node = Ln ( idx, vec ! [ vec] , Computation :: pooled ( value) ) ;
575- ANode :: new ( Arc :: new ( node) )
575+ ANode :: new ( Rc :: new ( node) )
576576 }
577577
578578 fn compute ( left : & ANode ) -> MPVec {
@@ -614,7 +614,7 @@ impl Exp {
614614 let idx = NodeIdx :: new ( ) ;
615615 let value = Exp :: compute ( & vec) ;
616616 let node = Exp ( idx, vec ! [ vec] , Computation :: pooled ( value) ) ;
617- ANode :: new ( Arc :: new ( node) )
617+ ANode :: new ( Rc :: new ( node) )
618618 }
619619
620620 fn compute ( left : & ANode ) -> MPVec {
@@ -656,7 +656,7 @@ impl Negate {
656656 let idx = NodeIdx :: new ( ) ;
657657 let value = Negate :: compute ( & vec) ;
658658 let node = Negate ( idx, vec ! [ vec] , Computation :: pooled ( value) ) ;
659- ANode :: new ( Arc :: new ( node) )
659+ ANode :: new ( Rc :: new ( node) )
660660 }
661661
662662 fn compute ( left : & ANode ) -> MPVec {
@@ -698,7 +698,7 @@ impl BulkSum {
698698 let children: Vec < _ > = vecs. collect ( ) ;
699699 let value = BulkSum :: compute ( & children) ;
700700 let node = BulkSum ( idx, children, Computation :: pooled ( value) ) ;
701- ANode :: new ( Arc :: new ( node) )
701+ ANode :: new ( Rc :: new ( node) )
702702 }
703703
704704 fn compute ( xs : & [ ANode ] ) -> MPVec {
@@ -742,7 +742,7 @@ impl Maximum {
742742 let idx = NodeIdx :: new ( ) ;
743743 let value = Maximum :: compute ( & left, & right) ;
744744 let node = Maximum ( idx, vec ! [ left, right] , Computation :: pooled ( value) ) ;
745- ANode :: new ( Arc :: new ( node) )
745+ ANode :: new ( Rc :: new ( node) )
746746 }
747747
748748 fn compute ( left : & ANode , right : & ANode ) -> MPVec {
@@ -797,7 +797,7 @@ impl Minimum {
797797 let idx = NodeIdx :: new ( ) ;
798798 let value = Minimum :: compute ( & left, & right) ;
799799 let node = Minimum ( idx, vec ! [ left, right] , Computation :: pooled ( value) ) ;
800- ANode :: new ( Arc :: new ( node) )
800+ ANode :: new ( Rc :: new ( node) )
801801 }
802802
803803 fn compute ( left : & ANode , right : & ANode ) -> MPVec {
@@ -1251,15 +1251,15 @@ mod tests {
12511251
12521252 #[ test]
12531253 fn test_updateable ( ) {
1254- let mut v = Arc :: new ( vec ! [ 0f32 , 0f32 ] ) ;
1254+ let mut v = Rc :: new ( vec ! [ 0f32 , 0f32 ] ) ;
12551255 let mut graph = Graph :: new ( ) ;
12561256 let grad = {
12571257 let x = Variable :: shared ( v. clone ( ) ) ;
12581258 let res = ( & x + 3f32 ) . pow ( 2f32 ) + 3f32 ;
12591259 graph. backward ( & res) ;
12601260 graph. get_grad ( & x)
12611261 } ;
1262- let v = Arc :: get_mut ( & mut v) . unwrap ( ) ;
1262+ let v = Rc :: get_mut ( & mut v) . unwrap ( ) ;
12631263 assert_eq ! ( v, & mut [ 0f32 , 0f32 ] ) ;
12641264 }
12651265
0 commit comments