@@ -3,163 +3,90 @@ use float_ord::FloatOrd;
33use simple_grad:: * ;
44
55fn vec_ops ( c : & mut Criterion ) {
6- let dims = 100 ;
6+ let dims = 256 ;
7+ let num_vecs = 10 ;
78 let mut embeddings = Vec :: new ( ) ;
9+ use_shared_pool ( false ) ;
810
9- for i in 0 ..100000 {
11+ for i in 0 ..num_vecs {
1012 let mut row = Vec :: with_capacity ( dims) ;
1113 for dim in 0 ..dims {
1214 row. push ( ( i* dim) as f32 ) ;
1315 }
1416 embeddings. push ( Variable :: new ( row) ) ;
1517 }
16- //c.bench_function("bench vecs", |b| b.iter(|| embeddings.clone().sum_all()));
1718
18- use_shared_pool ( false ) ;
19- let results = embeddings. clone ( ) . sum_all ( ) ;
20- c. bench_function ( "bench backward" , |b| b. iter ( || {
19+ c. bench_function ( "bench mul" , |b| b. iter ( || {
20+ let mut res = embeddings[ 0 ] . clone ( ) ;
21+ for i in 1 ..num_vecs {
22+ res = res * & embeddings[ i] ;
23+ }
24+ let res = res. sum ( ) ;
2125 let mut graph = Graph :: new ( ) ;
22- graph. backward ( & results ) ;
26+ graph. backward ( & res ) ;
2327 } ) ) ;
24- }
2528
26- fn bench_attention ( c : & mut Criterion ) {
27- let dims = 100 ;
28- let mut embeddings = Vec :: new ( ) ;
29-
30- for i in 0 ..20 {
31- let mut row = Vec :: with_capacity ( dims) ;
32- for dim in 0 ..dims {
33- row. push ( ( i* dim) as f32 ) ;
29+ c. bench_function ( "bench pow" , |b| b. iter ( || {
30+ let mut res = embeddings[ 0 ] . clone ( ) ;
31+ for i in 1 ..num_vecs {
32+ res = res. pow ( & embeddings[ i] ) ;
3433 }
35- embeddings. push ( ( Variable :: new ( row) , 1 ) ) ;
36- }
37- use_shared_pool ( false ) ;
38- c. bench_function ( "bench backward" , |b| b. iter ( || {
39- let e = attention_mean ( embeddings. iter ( ) , 20 , None ) ;
34+ let res = res. sum ( ) ;
4035 let mut graph = Graph :: new ( ) ;
41- graph. backward ( & e . sum ( ) ) ;
36+ graph. backward ( & res ) ;
4237 } ) ) ;
43- }
44-
45- pub fn attention_mean < ' a > (
46- it : impl Iterator < Item =& ' a ( ANode , usize ) > ,
47- attention_dims : usize ,
48- window : Option < usize >
49- ) -> ANode {
50-
51- let items: Vec < _ > = it. map ( |( node, count) | {
52- ( Attention :: new ( node, attention_dims) , * count)
53- } ) . collect ( ) ;
5438
55- if items. len ( ) == 1 {
56- return items[ 0 ] . 0 . value . clone ( )
57- }
58-
59- // Compute attention matrix
60- let attention_matrix = compute_attention_matrix ( & items, window) ;
61-
62- let att = compute_attention_softmax ( attention_matrix, attention_dims) ;
63-
64- let summed_weights = att. sum_all ( ) ;
65- let n = items. len ( ) as f32 ;
66- items. into_iter ( ) . enumerate ( )
67- . map ( |( i, ( at_i, _c) ) | at_i. value * summed_weights. slice ( i, 1 ) )
68- . collect :: < Vec < _ > > ( ) . sum_all ( ) / n
69- }
70-
71- fn compute_attention_matrix (
72- items : & [ ( Attention , usize ) ] ,
73- window : Option < usize >
74- ) -> Vec < Vec < ANode > > {
75-
76- // Get the attention for each feature
77- let zero = Constant :: scalar ( 0. ) ;
78- let mut scaled = vec ! [ vec![ zero; items. len( ) ] ; items. len( ) ] ;
79- for i in 0 ..items. len ( ) {
80- let ( j_start, j_end) = match window {
81- Some ( size) => {
82- let start = if size > i { 0 } else { i - size } ;
83- let stop = ( i + size + 1 ) . min ( items. len ( ) ) ;
84- ( start, stop)
85- } ,
86- None => ( 0 , items. len ( ) )
87- } ;
88-
89- let ( at_i, ic) = & items[ i] ;
90- let row = & mut scaled[ i] ;
91- for j in j_start..j_end {
92- let ( at_j, jc) = & items[ j] ;
93- let mut dot_i_j = ( & at_i. query ) . dot ( & at_j. key ) ;
94- let num = ic * jc;
95- if num >= 1 && window. is_none ( ) {
96- dot_i_j = dot_i_j * ( num as f32 ) ;
97- }
98- row[ j] = dot_i_j;
39+ c. bench_function ( "bench sub" , |b| b. iter ( || {
40+ let mut res = embeddings[ 0 ] . clone ( ) ;
41+ for i in 1 ..num_vecs {
42+ res = res - & embeddings[ i] ;
9943 }
100- }
101- scaled
102- }
103-
104-
105- fn compute_attention_softmax (
106- attention_matrix : Vec < Vec < ANode > > ,
107- d_k : usize
108- ) -> Vec < ANode > {
109- // Compute softmax
110- let d_k = Constant :: scalar ( ( d_k as f32 ) . sqrt ( ) ) ;
111-
112- // Compute softmax for each feature
113- let mut att = Vec :: with_capacity ( attention_matrix. len ( ) ) ;
114- for row in attention_matrix. into_iter ( ) {
115- let row = row. concat ( ) / & d_k;
116- let sm = softmax ( row) ;
117- att. push ( sm) ;
118- }
119-
120- att
121- }
122-
123- fn softmax ( numers : ANode ) -> ANode {
124-
125- let max_value = numers. value ( ) . iter ( )
126- . max_by_key ( |v| FloatOrd ( * * v) )
127- . expect ( "Shouldn't be non-zero!" ) ;
128- let mv = Constant :: scalar ( * max_value) ;
129- let n = ( numers - & mv) . exp ( ) ;
130- & n / n. sum ( )
131- }
44+ let res = res. sum ( ) ;
45+ let mut graph = Graph :: new ( ) ;
46+ graph. backward ( & res) ;
47+ } ) ) ;
13248
133- #[ derive( Clone ) ]
134- struct Attention {
135- query : ANode ,
136- key : ANode ,
137- value : ANode
138- }
49+ c. bench_function ( "bench add" , |b| b. iter ( || {
50+ let mut res = embeddings[ 0 ] . clone ( ) ;
51+ for i in 1 ..num_vecs {
52+ res = res + & embeddings[ i] ;
53+ }
54+ let res = res. sum ( ) ;
55+ let mut graph = Graph :: new ( ) ;
56+ graph. backward ( & res) ;
57+ } ) ) ;
13958
140- impl Attention {
141- fn new ( node : & ANode , attention_dims : usize ) -> Self {
142- let query = get_query_vec ( & node, attention_dims) ;
143- let key = get_key_vec ( & node, attention_dims) ;
144- let value = get_value_vec ( & node, attention_dims) ;
145- Attention { query, key, value}
146- }
147- }
59+ c. bench_function ( "bench sqrt" , |b| b. iter ( || {
60+ let mut res = embeddings[ 0 ] . clone ( ) ;
61+ for i in 1 ..num_vecs {
62+ res = res. pow ( 0.5 ) ;
63+ }
64+ let res = res. sum ( ) ;
65+ let mut graph = Graph :: new ( ) ;
66+ graph. backward ( & res) ;
67+ } ) ) ;
14868
149- fn get_value_vec ( emb : & ANode , dims : usize ) -> ANode {
150- let v = emb. value ( ) . len ( ) ;
151- emb. slice ( 2 * dims, v - 2 * dims)
152- }
69+ c. bench_function ( "bench pow2" , |b| b. iter ( || {
70+ let mut res = embeddings[ 0 ] . clone ( ) ;
71+ for i in 1 ..num_vecs {
72+ res = res. pow ( 2f32 ) ;
73+ }
74+ let res = res. sum ( ) ;
75+ let mut graph = Graph :: new ( ) ;
76+ graph. backward ( & res) ;
77+ } ) ) ;
15378
154- fn get_query_vec ( emb : & ANode , dims : usize ) -> ANode {
155- emb. slice ( 0 , dims)
156- }
79+ c. bench_function ( "bench exp" , |b| b. iter ( || {
80+ let mut res = embeddings[ 0 ] . clone ( ) ;
81+ for i in 1 ..num_vecs {
82+ res = res. exp ( ) ;
83+ }
84+ let res = res. sum ( ) ;
85+ let mut graph = Graph :: new ( ) ;
86+ graph. backward ( & res) ;
87+ } ) ) ;
15788
158- fn get_key_vec ( emb : & ANode , dims : usize ) -> ANode {
159- emb. slice ( dims, dims)
16089}
16190
162-
163- //criterion_group!(benches, vec_ops);
164- criterion_group ! ( benches, bench_attention) ;
91+ criterion_group ! ( benches, vec_ops) ;
16592criterion_main ! ( benches) ;
0 commit comments