
Commit fc221a7

Adds updated bench algorithms

1 parent 2455b17 commit fc221a7

File tree

1 file changed: benches/bench_algos.rs

Lines changed: 62 additions & 135 deletions
@@ -3,163 +3,90 @@ use float_ord::FloatOrd;
 use simple_grad::*;
 
 fn vec_ops(c: &mut Criterion) {
-    let dims = 100;
+    let dims = 256;
+    let num_vecs = 10;
     let mut embeddings = Vec::new();
+    use_shared_pool(false);
 
-    for i in 0..100000 {
+    for i in 0..num_vecs {
         let mut row = Vec::with_capacity(dims);
         for dim in 0..dims {
             row.push((i*dim) as f32);
         }
         embeddings.push(Variable::new(row));
     }
-    //c.bench_function("bench vecs", |b| b.iter(|| embeddings.clone().sum_all()));
 
-    use_shared_pool(false);
-    let results = embeddings.clone().sum_all();
-    c.bench_function("bench backward", |b| b.iter(|| {
+    c.bench_function("bench mul", |b| b.iter(|| {
+        let mut res = embeddings[0].clone();
+        for i in 1..num_vecs {
+            res = res * &embeddings[i];
+        }
+        let res = res.sum();
         let mut graph = Graph::new();
-        graph.backward(&results);
+        graph.backward(&res);
     }));
-}
 
-fn bench_attention(c: &mut Criterion) {
-    let dims = 100;
-    let mut embeddings = Vec::new();
-
-    for i in 0..20 {
-        let mut row = Vec::with_capacity(dims);
-        for dim in 0..dims {
-            row.push((i*dim) as f32);
+    c.bench_function("bench pow", |b| b.iter(|| {
+        let mut res = embeddings[0].clone();
+        for i in 1..num_vecs {
+            res = res.pow(&embeddings[i]);
         }
-        embeddings.push((Variable::new(row), 1));
-    }
-    use_shared_pool(false);
-    c.bench_function("bench backward", |b| b.iter(|| {
-        let e = attention_mean(embeddings.iter(), 20, None);
+        let res = res.sum();
         let mut graph = Graph::new();
-        graph.backward(&e.sum());
+        graph.backward(&res);
     }));
-}
-
-pub fn attention_mean<'a>(
-    it: impl Iterator<Item=&'a (ANode, usize)>,
-    attention_dims: usize,
-    window: Option<usize>
-) -> ANode {
-
-    let items: Vec<_> = it.map(|(node, count)| {
-        (Attention::new(node, attention_dims), *count)
-    }).collect();
 
-    if items.len() == 1 {
-        return items[0].0.value.clone()
-    }
-
-    // Compute attention matrix
-    let attention_matrix = compute_attention_matrix(&items, window);
-
-    let att = compute_attention_softmax(attention_matrix, attention_dims);
-
-    let summed_weights = att.sum_all();
-    let n = items.len() as f32;
-    items.into_iter().enumerate()
-        .map(|(i, (at_i, _c))| at_i.value * summed_weights.slice(i, 1))
-        .collect::<Vec<_>>().sum_all() / n
-}
-
-fn compute_attention_matrix(
-    items: &[(Attention, usize)],
-    window: Option<usize>
-) -> Vec<Vec<ANode>> {
-
-    // Get the attention for each feature
-    let zero = Constant::scalar(0.);
-    let mut scaled = vec![vec![zero; items.len()]; items.len()];
-    for i in 0..items.len() {
-        let (j_start, j_end) = match window {
-            Some(size) => {
-                let start = if size > i { 0 } else {i - size };
-                let stop = (i + size + 1).min(items.len());
-                (start, stop)
-            },
-            None => (0, items.len())
-        };
-
-        let (at_i, ic) = &items[i];
-        let row = &mut scaled[i];
-        for j in j_start..j_end {
-            let (at_j, jc) = &items[j];
-            let mut dot_i_j = (&at_i.query).dot(&at_j.key);
-            let num = ic * jc;
-            if num >= 1 && window.is_none() {
-                dot_i_j = dot_i_j * (num as f32);
-            }
-            row[j] = dot_i_j;
+    c.bench_function("bench sub", |b| b.iter(|| {
+        let mut res = embeddings[0].clone();
+        for i in 1..num_vecs {
+            res = res - &embeddings[i];
         }
-    }
-    scaled
-}
-
-
-fn compute_attention_softmax(
-    attention_matrix: Vec<Vec<ANode>>,
-    d_k: usize
-) -> Vec<ANode> {
-    // Compute softmax
-    let d_k = Constant::scalar((d_k as f32).sqrt());
-
-    // Compute softmax for each feature
-    let mut att = Vec::with_capacity(attention_matrix.len());
-    for row in attention_matrix.into_iter() {
-        let row = row.concat() / &d_k;
-        let sm = softmax(row);
-        att.push(sm);
-    }
-
-    att
-}
-
-fn softmax(numers: ANode) -> ANode {
-
-    let max_value = numers.value().iter()
-        .max_by_key(|v| FloatOrd(**v))
-        .expect("Shouldn't be non-zero!");
-    let mv = Constant::scalar(*max_value);
-    let n = (numers - &mv).exp();
-    &n / n.sum()
-}
+        let res = res.sum();
+        let mut graph = Graph::new();
+        graph.backward(&res);
+    }));
 
-#[derive(Clone)]
-struct Attention {
-    query: ANode,
-    key: ANode,
-    value: ANode
-}
+    c.bench_function("bench add", |b| b.iter(|| {
+        let mut res = embeddings[0].clone();
+        for i in 1..num_vecs {
+            res = res + &embeddings[i];
+        }
+        let res = res.sum();
+        let mut graph = Graph::new();
+        graph.backward(&res);
+    }));
 
-impl Attention {
-    fn new(node: &ANode, attention_dims: usize) -> Self {
-        let query = get_query_vec(&node, attention_dims);
-        let key = get_key_vec(&node, attention_dims);
-        let value = get_value_vec(&node, attention_dims);
-        Attention {query, key, value}
-    }
-}
+    c.bench_function("bench sqrt", |b| b.iter(|| {
+        let mut res = embeddings[0].clone();
+        for i in 1..num_vecs {
+            res = res.pow(0.5);
+        }
+        let res = res.sum();
+        let mut graph = Graph::new();
+        graph.backward(&res);
+    }));
 
-fn get_value_vec(emb: &ANode, dims: usize) -> ANode {
-    let v = emb.value().len();
-    emb.slice(2*dims, v - 2*dims)
-}
+    c.bench_function("bench pow2", |b| b.iter(|| {
+        let mut res = embeddings[0].clone();
+        for i in 1..num_vecs {
+            res = res.pow(2f32);
+        }
+        let res = res.sum();
+        let mut graph = Graph::new();
+        graph.backward(&res);
+    }));
 
-fn get_query_vec(emb: &ANode, dims: usize) -> ANode {
-    emb.slice(0, dims)
-}
+    c.bench_function("bench exp", |b| b.iter(|| {
+        let mut res = embeddings[0].clone();
+        for i in 1..num_vecs {
+            res = res.exp();
+        }
+        let res = res.sum();
+        let mut graph = Graph::new();
+        graph.backward(&res);
+    }));
 
-fn get_key_vec(emb: &ANode, dims: usize) -> ANode {
-    emb.slice(dims, dims)
 }
 
-
-//criterion_group!(benches, vec_ops);
-criterion_group!(benches, bench_attention);
+criterion_group!(benches, vec_ops);
 criterion_main!(benches);
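
Each new benchmark repeats the same shape: fold one operation across the vectors, reduce with sum(), then run a backward pass. If more ops get added, that shape could be factored out. A minimal sketch follows; the bench_binary_op helper is hypothetical, not part of this commit, and it assumes the same imports as the file (criterion::Criterion, simple_grad::*) and that Variable::new yields a cloneable ANode, as the diff suggests.

// Hypothetical helper, not in this commit: factors out the shared
// fold -> sum -> backward shape of the benchmarks above.
fn bench_binary_op<F>(c: &mut Criterion, name: &str, embeddings: &[ANode], op: F)
where
    F: Fn(ANode, &ANode) -> ANode,
{
    c.bench_function(name, |b| b.iter(|| {
        // Fold the operation across every vector, starting from the first.
        let mut res = embeddings[0].clone();
        for e in &embeddings[1..] {
            res = op(res, e);
        }
        // Reduce to a scalar and run the backward pass, as each bench does.
        let res = res.sum();
        let mut graph = Graph::new();
        graph.backward(&res);
    }));
}

// Usage, mirroring "bench mul" and "bench sub":
// bench_binary_op(c, "bench mul", &embeddings, |a, b| a * b);
// bench_binary_op(c, "bench sub", &embeddings, |a, b| a - b);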

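To run these benchmarks, a Criterion bench target needs the default harness disabled in Cargo.toml, roughly as below; this is a sketch, and the criterion version actually pinned by this repo isn't shown on this page.

[dev-dependencies]
criterion = "0.3"  # assumed version; check the repo's Cargo.toml

[[bench]]
name = "bench_algos"
harness = false

With that in place, cargo bench --bench bench_algos runs the whole group, and Criterion accepts a name filter to run a single one, e.g. cargo bench --bench bench_algos -- "bench mul".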