@@ -13,19 +13,31 @@ use super::simple_passes::outgoing_edges;
13
13
use super :: { apply_rewrite_rules, id} ;
14
14
use rspirv:: dr:: { Block , Function , Instruction , ModuleHeader , Operand } ;
15
15
use rspirv:: spirv:: { Op , Word } ;
16
- use rustc_data_structures:: fx:: { FxHashMap , FxHashSet } ;
16
+ use rustc_data_structures:: fx:: { FxHashMap , FxHashSet , FxIndexMap } ;
17
17
use rustc_middle:: bug;
18
18
use std:: collections:: hash_map;
19
19
20
+ // HACK(eddyb) newtype instead of type alias to avoid mistakes.
21
+ #[ derive( Copy , Clone , PartialEq , Eq , Hash ) ]
22
+ struct LabelId ( Word ) ;
23
+
20
24
pub fn mem2reg (
21
25
header : & mut ModuleHeader ,
22
26
types_global_values : & mut Vec < Instruction > ,
23
27
pointer_to_pointee : & FxHashMap < Word , Word > ,
24
28
constants : & FxHashMap < Word , u32 > ,
25
29
func : & mut Function ,
26
30
) {
27
- let reachable = compute_reachable ( & func. blocks ) ;
28
- let preds = compute_preds ( & func. blocks , & reachable) ;
31
+ // HACK(eddyb) this ad-hoc indexing might be useful elsewhere as well, but
32
+ // it's made completely irrelevant by SPIR-T so only applies to legacy code.
33
+ let mut blocks: FxIndexMap < _ , _ > = func
34
+ . blocks
35
+ . iter_mut ( )
36
+ . map ( |block| ( LabelId ( block. label_id ( ) . unwrap ( ) ) , block) )
37
+ . collect ( ) ;
38
+
39
+ let reachable = compute_reachable ( & blocks) ;
40
+ let preds = compute_preds ( & blocks, & reachable) ;
29
41
let idom = compute_idom ( & preds, & reachable) ;
30
42
let dominance_frontier = compute_dominance_frontier ( & preds, & idom) ;
31
43
loop {
@@ -34,31 +46,27 @@ pub fn mem2reg(
34
46
types_global_values,
35
47
pointer_to_pointee,
36
48
constants,
37
- & mut func . blocks ,
49
+ & mut blocks,
38
50
& dominance_frontier,
39
51
) ;
40
52
if !changed {
41
53
break ;
42
54
}
43
55
// mem2reg produces minimal SSA form, not pruned, so DCE the dead ones
44
- super :: dce:: dce_phi ( func ) ;
56
+ super :: dce:: dce_phi ( & mut blocks ) ;
45
57
}
46
58
}
47
59
48
- fn label_to_index ( blocks : & [ Block ] , id : Word ) -> usize {
49
- blocks
50
- . iter ( )
51
- . position ( |b| b. label_id ( ) . unwrap ( ) == id)
52
- . unwrap ( )
53
- }
54
-
55
- fn compute_reachable ( blocks : & [ Block ] ) -> Vec < bool > {
56
- fn recurse ( blocks : & [ Block ] , reachable : & mut [ bool ] , block : usize ) {
60
+ fn compute_reachable ( blocks : & FxIndexMap < LabelId , & mut Block > ) -> Vec < bool > {
61
+ fn recurse ( blocks : & FxIndexMap < LabelId , & mut Block > , reachable : & mut [ bool ] , block : usize ) {
57
62
if !reachable[ block] {
58
63
reachable[ block] = true ;
59
- for dest_id in outgoing_edges ( & blocks[ block] ) {
60
- let dest_idx = label_to_index ( blocks, dest_id) ;
61
- recurse ( blocks, reachable, dest_idx) ;
64
+ for dest_id in outgoing_edges ( blocks[ block] ) {
65
+ recurse (
66
+ blocks,
67
+ reachable,
68
+ blocks. get_index_of ( & LabelId ( dest_id) ) . unwrap ( ) ,
69
+ ) ;
62
70
}
63
71
}
64
72
}
@@ -67,17 +75,19 @@ fn compute_reachable(blocks: &[Block]) -> Vec<bool> {
67
75
reachable
68
76
}
69
77
70
- fn compute_preds ( blocks : & [ Block ] , reachable_blocks : & [ bool ] ) -> Vec < Vec < usize > > {
78
+ fn compute_preds (
79
+ blocks : & FxIndexMap < LabelId , & mut Block > ,
80
+ reachable_blocks : & [ bool ] ,
81
+ ) -> Vec < Vec < usize > > {
71
82
let mut result = vec ! [ vec![ ] ; blocks. len( ) ] ;
72
83
// Do not count unreachable blocks as valid preds of blocks
73
84
for ( source_idx, source) in blocks
74
- . iter ( )
85
+ . values ( )
75
86
. enumerate ( )
76
87
. filter ( |& ( b, _) | reachable_blocks[ b] )
77
88
{
78
89
for dest_id in outgoing_edges ( source) {
79
- let dest_idx = label_to_index ( blocks, dest_id) ;
80
- result[ dest_idx] . push ( source_idx) ;
90
+ result[ blocks. get_index_of ( & LabelId ( dest_id) ) . unwrap ( ) ] . push ( source_idx) ;
81
91
}
82
92
}
83
93
result
@@ -161,7 +171,7 @@ fn insert_phis_all(
161
171
types_global_values : & mut Vec < Instruction > ,
162
172
pointer_to_pointee : & FxHashMap < Word , Word > ,
163
173
constants : & FxHashMap < Word , u32 > ,
164
- blocks : & mut [ Block ] ,
174
+ blocks : & mut FxIndexMap < LabelId , & mut Block > ,
165
175
dominance_frontier : & [ FxHashSet < usize > ] ,
166
176
) -> bool {
167
177
let var_maps_and_types = blocks[ 0 ]
@@ -198,7 +208,11 @@ fn insert_phis_all(
198
208
rewrite_rules : FxHashMap :: default ( ) ,
199
209
} ;
200
210
renamer. rename ( 0 , None ) ;
201
- apply_rewrite_rules ( & renamer. rewrite_rules , blocks) ;
211
+ // FIXME(eddyb) shouldn't this full rescan of the function be done once?
212
+ apply_rewrite_rules (
213
+ & renamer. rewrite_rules ,
214
+ blocks. values_mut ( ) . map ( |block| & mut * * block) ,
215
+ ) ;
202
216
remove_nops ( blocks) ;
203
217
}
204
218
remove_old_variables ( blocks, & var_maps_and_types) ;
@@ -216,7 +230,7 @@ struct VarInfo {
216
230
fn collect_access_chains (
217
231
pointer_to_pointee : & FxHashMap < Word , Word > ,
218
232
constants : & FxHashMap < Word , u32 > ,
219
- blocks : & [ Block ] ,
233
+ blocks : & FxIndexMap < LabelId , & mut Block > ,
220
234
base_var : Word ,
221
235
base_var_ty : Word ,
222
236
) -> Option < FxHashMap < Word , VarInfo > > {
@@ -249,7 +263,7 @@ fn collect_access_chains(
249
263
// Loop in case a previous block references a later AccessChain
250
264
loop {
251
265
let mut changed = false ;
252
- for inst in blocks. iter ( ) . flat_map ( |b| & b. instructions ) {
266
+ for inst in blocks. values ( ) . flat_map ( |b| & b. instructions ) {
253
267
for ( index, op) in inst. operands . iter ( ) . enumerate ( ) {
254
268
if let Operand :: IdRef ( id) = op {
255
269
if variables. contains_key ( id) {
@@ -307,10 +321,10 @@ fn collect_access_chains(
307
321
// same var map (e.g. `s.x = s.y;`).
308
322
fn split_copy_memory (
309
323
header : & mut ModuleHeader ,
310
- blocks : & mut [ Block ] ,
324
+ blocks : & mut FxIndexMap < LabelId , & mut Block > ,
311
325
var_map : & FxHashMap < Word , VarInfo > ,
312
326
) {
313
- for block in blocks {
327
+ for block in blocks. values_mut ( ) {
314
328
let mut inst_index = 0 ;
315
329
while inst_index < block. instructions . len ( ) {
316
330
let inst = & block. instructions [ inst_index] ;
@@ -369,7 +383,7 @@ fn has_store(block: &Block, var_map: &FxHashMap<Word, VarInfo>) -> bool {
369
383
}
370
384
371
385
fn insert_phis (
372
- blocks : & [ Block ] ,
386
+ blocks : & FxIndexMap < LabelId , & mut Block > ,
373
387
dominance_frontier : & [ FxHashSet < usize > ] ,
374
388
var_map : & FxHashMap < Word , VarInfo > ,
375
389
) -> FxHashSet < usize > {
@@ -378,7 +392,7 @@ fn insert_phis(
378
392
let mut ever_on_work_list = FxHashSet :: default ( ) ;
379
393
let mut work_list = Vec :: new ( ) ;
380
394
let mut blocks_with_phi = FxHashSet :: default ( ) ;
381
- for ( block_idx, block) in blocks. iter ( ) . enumerate ( ) {
395
+ for ( block_idx, block) in blocks. values ( ) . enumerate ( ) {
382
396
if has_store ( block, var_map) {
383
397
ever_on_work_list. insert ( block_idx) ;
384
398
work_list. push ( block_idx) ;
@@ -423,10 +437,10 @@ fn top_stack_or_undef(
423
437
}
424
438
}
425
439
426
- struct Renamer < ' a > {
440
+ struct Renamer < ' a , ' b > {
427
441
header : & ' a mut ModuleHeader ,
428
442
types_global_values : & ' a mut Vec < Instruction > ,
429
- blocks : & ' a mut [ Block ] ,
443
+ blocks : & ' a mut FxIndexMap < LabelId , & ' b mut Block > ,
430
444
blocks_with_phi : FxHashSet < usize > ,
431
445
base_var_type : Word ,
432
446
var_map : & ' a FxHashMap < Word , VarInfo > ,
@@ -436,7 +450,7 @@ struct Renamer<'a> {
436
450
rewrite_rules : FxHashMap < Word , Word > ,
437
451
}
438
452
439
- impl Renamer < ' _ > {
453
+ impl Renamer < ' _ , ' _ > {
440
454
// Returns the phi definition.
441
455
fn insert_phi_value ( & mut self , block : usize , from_block : usize ) -> Word {
442
456
let from_block_label = self . blocks [ from_block] . label_id ( ) . unwrap ( ) ;
@@ -558,9 +572,8 @@ impl Renamer<'_> {
558
572
}
559
573
}
560
574
561
- for dest_id in outgoing_edges ( & self . blocks [ block] ) . collect :: < Vec < _ > > ( ) {
562
- // TODO: Don't do this find
563
- let dest_idx = label_to_index ( self . blocks , dest_id) ;
575
+ for dest_id in outgoing_edges ( self . blocks [ block] ) . collect :: < Vec < _ > > ( ) {
576
+ let dest_idx = self . blocks . get_index_of ( & LabelId ( dest_id) ) . unwrap ( ) ;
564
577
self . rename ( dest_idx, Some ( block) ) ;
565
578
}
566
579
@@ -570,16 +583,16 @@ impl Renamer<'_> {
570
583
}
571
584
}
572
585
573
- fn remove_nops ( blocks : & mut [ Block ] ) {
574
- for block in blocks {
586
+ fn remove_nops ( blocks : & mut FxIndexMap < LabelId , & mut Block > ) {
587
+ for block in blocks. values_mut ( ) {
575
588
block
576
589
. instructions
577
590
. retain ( |inst| inst. class . opcode != Op :: Nop ) ;
578
591
}
579
592
}
580
593
581
594
fn remove_old_variables (
582
- blocks : & mut [ Block ] ,
595
+ blocks : & mut FxIndexMap < LabelId , & mut Block > ,
583
596
var_maps_and_types : & [ ( FxHashMap < u32 , VarInfo > , u32 ) ] ,
584
597
) {
585
598
blocks[ 0 ] . instructions . retain ( |inst| {
@@ -590,7 +603,7 @@ fn remove_old_variables(
590
603
. all ( |( var_map, _) | !var_map. contains_key ( & result_id) )
591
604
}
592
605
} ) ;
593
- for block in blocks {
606
+ for block in blocks. values_mut ( ) {
594
607
block. instructions . retain ( |inst| {
595
608
!matches ! ( inst. class. opcode, Op :: AccessChain | Op :: InBoundsAccessChain )
596
609
|| inst. operands . iter ( ) . all ( |op| {
0 commit comments