Skip to content

Commit d395a10

Browse files
committed
Use Rcs, not Arcs
1 parent a85e924 commit d395a10

File tree

3 files changed

+31
-31
lines changed

3 files changed

+31
-31
lines changed

src/graph.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
use hashbrown::HashMap;
2-
use std::sync::Arc;
1+
use std::rc::Rc;
32

3+
use hashbrown::HashMap;
44
use crate::{DType,ANode,NodeIdx,Node};
55
use crate::vecops::iadd;
66
use crate::pool::allocate_vec;
@@ -140,7 +140,7 @@ pub(crate) struct Run(NodeIdx, Vec<ANode>);
140140
impl Run {
141141
pub(crate) fn new(x: &ANode) -> ANode {
142142
let idx = NodeIdx::new();
143-
ANode::new(Arc::new(Run(idx, vec![x.clone()])))
143+
ANode::new(Rc::new(Run(idx, vec![x.clone()])))
144144
}
145145
}
146146

src/lib.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ pub use ops::{Variable,Constant};
1212
pub use pool::{clear_pool, use_shared_pool};
1313

1414
use std::sync::atomic::{AtomicUsize, Ordering};
15-
use std::sync::Arc;
15+
use std::rc::Rc;
1616
use std::ops::{Add,Sub,Mul,Div,Deref,Neg};
1717

1818
use crate::ops::*;
@@ -47,10 +47,10 @@ pub trait Node {
4747
}
4848

4949
#[derive(Clone)]
50-
pub struct ANode(Arc<dyn Node>);
50+
pub struct ANode(Rc<dyn Node>);
5151

5252
impl ANode {
53-
fn new(n: Arc<dyn Node>) -> Self {
53+
fn new(n: Rc<dyn Node>) -> Self {
5454
ANode(n)
5555
}
5656

@@ -98,7 +98,7 @@ impl FromConstant for Vec<f32> {
9898

9999

100100
impl Deref for ANode {
101-
type Target = Arc<dyn Node>;
101+
type Target = Rc<dyn Node>;
102102

103103
fn deref(&self) -> &Self::Target {
104104
&self.0

src/ops.rs

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
use std::sync::Arc;
1+
use std::rc::Rc;
22

33
use crate::*;
44
use crate::vecops::{add, iadd, sub, isub, mul, imul, div};
55
use crate::pool::{MPVec,allocate_vec};
66

77
enum Data {
88
Owned(Vec<DType>),
9-
Shared(Arc<Vec<DType>>),
9+
Shared(Rc<Vec<DType>>),
1010
Pooled(MPVec)
1111
}
1212

@@ -19,7 +19,7 @@ impl Computation {
1919
Computation { value: Data::Owned(value) }
2020
}
2121

22-
fn shared(value: Arc<Vec<DType>>) -> Self {
22+
fn shared(value: Rc<Vec<DType>>) -> Self {
2323
Computation { value: Data::Shared(value) }
2424
}
2525

@@ -41,16 +41,16 @@ pub struct Variable(NodeIdx, Computation);
4141
impl Variable {
4242
pub fn new(value: Vec<DType>) -> ANode {
4343
let v = Variable(NodeIdx::new(), Computation::new(value));
44-
ANode::new(Arc::new(v))
44+
ANode::new(Rc::new(v))
4545
}
4646

4747
pub fn scalar(value: DType) -> ANode {
4848
Variable::new(vec![value])
4949
}
5050

51-
pub fn shared(value: Arc<Vec<DType>>) -> ANode {
51+
pub fn shared(value: Rc<Vec<DType>>) -> ANode {
5252
let v = Variable(NodeIdx::new(), Computation::shared(value));
53-
ANode::new(Arc::new(v))
53+
ANode::new(Rc::new(v))
5454
}
5555

5656
}
@@ -78,14 +78,14 @@ pub struct Constant(NodeIdx, Computation);
7878
impl Constant {
7979
pub fn new(value: Vec<DType>) -> ANode {
8080
let c = Constant(NodeIdx::new(), Computation::new(value));
81-
ANode::new(Arc::new(c))
81+
ANode::new(Rc::new(c))
8282
}
8383

8484
pub fn scalar(value: DType) -> ANode {
8585
let mut v = allocate_vec(1);
8686
v.as_mut()[0] = value;
8787
let c = Constant(NodeIdx::new(), Computation::pooled(v));
88-
ANode::new(Arc::new(c))
88+
ANode::new(Rc::new(c))
8989
}
9090

9191
}
@@ -182,7 +182,7 @@ impl AddN {
182182
let idx = NodeIdx::new();
183183
let value = AddN::compute(&left, &right);
184184
let node = AddN(idx, vec![left, right], Computation::pooled(value));
185-
ANode::new(Arc::new(node))
185+
ANode::new(Rc::new(node))
186186
}
187187

188188
fn compute(left: &ANode, right: &ANode) -> MPVec {
@@ -230,7 +230,7 @@ impl Subtract {
230230
let idx = NodeIdx::new();
231231
let value = Subtract::compute(&left, &right);
232232
let node = Subtract(idx, vec![left, right], Computation::pooled(value));
233-
ANode::new(Arc::new(node))
233+
ANode::new(Rc::new(node))
234234
}
235235

236236
fn compute(left: &ANode, right: &ANode) -> MPVec {
@@ -278,7 +278,7 @@ impl Multiply {
278278
let idx = NodeIdx::new();
279279
let value = Multiply::compute(&left, &right);
280280
let node = Multiply(idx, vec![left, right], Computation::pooled(value));
281-
ANode::new(Arc::new(node))
281+
ANode::new(Rc::new(node))
282282
}
283283

284284
fn compute(left: &ANode, right: &ANode) -> MPVec {
@@ -333,7 +333,7 @@ impl Divide {
333333
let idx = NodeIdx::new();
334334
let value = Divide::compute(&left, &right);
335335
let node = Divide(idx, vec![left, right], Computation::pooled(value));
336-
ANode::new(Arc::new(node))
336+
ANode::new(Rc::new(node))
337337
}
338338

339339
fn compute(left: &ANode, right: &ANode) -> MPVec {
@@ -386,7 +386,7 @@ impl Power {
386386
let idx = NodeIdx::new();
387387
let value = Power::compute(&base, &exp);
388388
let node = Power(idx, vec![base, exp], Computation::pooled(value));
389-
ANode::new(Arc::new(node))
389+
ANode::new(Rc::new(node))
390390
}
391391

392392
fn compute(left: &ANode, right: &ANode) -> MPVec {
@@ -445,7 +445,7 @@ impl SumVec {
445445
let idx = NodeIdx::new();
446446
let value = SumVec::compute(&vec);
447447
let node = SumVec(idx, vec![vec], Computation::pooled(value));
448-
ANode::new(Arc::new(node))
448+
ANode::new(Rc::new(node))
449449
}
450450

451451
fn compute(left: &ANode) -> MPVec {
@@ -487,7 +487,7 @@ impl Cos {
487487
let idx = NodeIdx::new();
488488
let value = Cos::compute(&vec);
489489
let node = Cos(idx, vec![vec], Computation::pooled(value));
490-
ANode::new(Arc::new(node))
490+
ANode::new(Rc::new(node))
491491
}
492492

493493
fn compute(left: &ANode) -> MPVec {
@@ -529,7 +529,7 @@ impl Sin {
529529
let idx = NodeIdx::new();
530530
let value = Sin::compute(&vec);
531531
let node = Sin(idx, vec![vec], Computation::pooled(value));
532-
ANode::new(Arc::new(node))
532+
ANode::new(Rc::new(node))
533533
}
534534

535535
fn compute(left: &ANode) -> MPVec {
@@ -572,7 +572,7 @@ impl Ln {
572572
let idx = NodeIdx::new();
573573
let value = Ln::compute(&vec);
574574
let node = Ln(idx, vec![vec], Computation::pooled(value));
575-
ANode::new(Arc::new(node))
575+
ANode::new(Rc::new(node))
576576
}
577577

578578
fn compute(left: &ANode) -> MPVec {
@@ -614,7 +614,7 @@ impl Exp {
614614
let idx = NodeIdx::new();
615615
let value = Exp::compute(&vec);
616616
let node = Exp(idx, vec![vec], Computation::pooled(value));
617-
ANode::new(Arc::new(node))
617+
ANode::new(Rc::new(node))
618618
}
619619

620620
fn compute(left: &ANode) -> MPVec {
@@ -656,7 +656,7 @@ impl Negate {
656656
let idx = NodeIdx::new();
657657
let value = Negate::compute(&vec);
658658
let node = Negate(idx, vec![vec], Computation::pooled(value));
659-
ANode::new(Arc::new(node))
659+
ANode::new(Rc::new(node))
660660
}
661661

662662
fn compute(left: &ANode) -> MPVec {
@@ -698,7 +698,7 @@ impl BulkSum {
698698
let children: Vec<_> = vecs.collect();
699699
let value = BulkSum::compute(&children);
700700
let node = BulkSum(idx, children, Computation::pooled(value));
701-
ANode::new(Arc::new(node))
701+
ANode::new(Rc::new(node))
702702
}
703703

704704
fn compute(xs: &[ANode]) -> MPVec {
@@ -742,7 +742,7 @@ impl Maximum {
742742
let idx = NodeIdx::new();
743743
let value = Maximum::compute(&left, &right);
744744
let node = Maximum(idx, vec![left, right], Computation::pooled(value));
745-
ANode::new(Arc::new(node))
745+
ANode::new(Rc::new(node))
746746
}
747747

748748
fn compute(left: &ANode, right: &ANode) -> MPVec {
@@ -797,7 +797,7 @@ impl Minimum {
797797
let idx = NodeIdx::new();
798798
let value = Minimum::compute(&left, &right);
799799
let node = Minimum(idx, vec![left, right], Computation::pooled(value));
800-
ANode::new(Arc::new(node))
800+
ANode::new(Rc::new(node))
801801
}
802802

803803
fn compute(left: &ANode, right: &ANode) -> MPVec {
@@ -1251,15 +1251,15 @@ mod tests {
12511251

12521252
#[test]
12531253
fn test_updateable() {
1254-
let mut v = Arc::new(vec![0f32, 0f32]);
1254+
let mut v = Rc::new(vec![0f32, 0f32]);
12551255
let mut graph = Graph::new();
12561256
let grad = {
12571257
let x = Variable::shared(v.clone());
12581258
let res = (&x + 3f32).pow(2f32) + 3f32;
12591259
graph.backward(&res);
12601260
graph.get_grad(&x)
12611261
};
1262-
let v = Arc::get_mut(&mut v).unwrap();
1262+
let v = Rc::get_mut(&mut v).unwrap();
12631263
assert_eq!(v, &mut [0f32, 0f32]);
12641264
}
12651265

0 commit comments

Comments
 (0)