Skip to content

Commit fc0db1b

Browse files
eddybFirestar99
authored andcommitted
WIP: mem2reg speedup
1 parent 3f96056 commit fc0db1b

File tree

3 files changed

+62
-45
lines changed

3 files changed

+62
-45
lines changed

crates/rustc_codegen_spirv/src/linker/dce.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
//! *references* a rooted thing is also rooted, not the other way around - but that's the basic
88
//! concept.
99
10-
use rspirv::dr::{Function, Instruction, Module, Operand};
10+
use rspirv::dr::{Block, Function, Instruction, Module, Operand};
1111
use rspirv::spirv::{Decoration, LinkageType, Op, StorageClass, Word};
12-
use rustc_data_structures::fx::FxIndexSet;
12+
use rustc_data_structures::fx::{FxIndexMap, FxIndexSet};
13+
use std::hash::Hash;
1314

1415
pub fn dce(module: &mut Module) {
1516
let mut rooted = collect_roots(module);
@@ -137,11 +138,11 @@ fn kill_unrooted(module: &mut Module, rooted: &FxIndexSet<Word>) {
137138
}
138139
}
139140

140-
pub fn dce_phi(func: &mut Function) {
141+
pub fn dce_phi(blocks: &mut FxIndexMap<impl Eq + Hash, &mut Block>) {
141142
let mut used = FxIndexSet::default();
142143
loop {
143144
let mut changed = false;
144-
for inst in func.all_inst_iter() {
145+
for inst in blocks.values().flat_map(|block| &block.instructions) {
145146
if inst.class.opcode != Op::Phi || used.contains(&inst.result_id.unwrap()) {
146147
for op in &inst.operands {
147148
if let Some(id) = op.id_ref_any() {
@@ -154,7 +155,7 @@ pub fn dce_phi(func: &mut Function) {
154155
break;
155156
}
156157
}
157-
for block in &mut func.blocks {
158+
for block in blocks.values_mut() {
158159
block
159160
.instructions
160161
.retain(|inst| inst.class.opcode != Op::Phi || used.contains(&inst.result_id.unwrap()));

crates/rustc_codegen_spirv/src/linker/mem2reg.rs

Lines changed: 52 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,31 @@ use super::simple_passes::outgoing_edges;
1313
use super::{apply_rewrite_rules, id};
1414
use rspirv::dr::{Block, Function, Instruction, ModuleHeader, Operand};
1515
use rspirv::spirv::{Op, Word};
16-
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
16+
use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexMap};
1717
use rustc_middle::bug;
1818
use std::collections::hash_map;
1919

20+
// HACK(eddyb) newtype instead of type alias to avoid mistakes.
21+
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
22+
struct LabelId(Word);
23+
2024
pub fn mem2reg(
2125
header: &mut ModuleHeader,
2226
types_global_values: &mut Vec<Instruction>,
2327
pointer_to_pointee: &FxHashMap<Word, Word>,
2428
constants: &FxHashMap<Word, u32>,
2529
func: &mut Function,
2630
) {
27-
let reachable = compute_reachable(&func.blocks);
28-
let preds = compute_preds(&func.blocks, &reachable);
31+
// HACK(eddyb) this ad-hoc indexing might be useful elsewhere as well, but
32+
// it's made completely irrelevant by SPIR-T so only applies to legacy code.
33+
let mut blocks: FxIndexMap<_, _> = func
34+
.blocks
35+
.iter_mut()
36+
.map(|block| (LabelId(block.label_id().unwrap()), block))
37+
.collect();
38+
39+
let reachable = compute_reachable(&blocks);
40+
let preds = compute_preds(&blocks, &reachable);
2941
let idom = compute_idom(&preds, &reachable);
3042
let dominance_frontier = compute_dominance_frontier(&preds, &idom);
3143
loop {
@@ -34,31 +46,27 @@ pub fn mem2reg(
3446
types_global_values,
3547
pointer_to_pointee,
3648
constants,
37-
&mut func.blocks,
49+
&mut blocks,
3850
&dominance_frontier,
3951
);
4052
if !changed {
4153
break;
4254
}
4355
// mem2reg produces minimal SSA form, not pruned, so DCE the dead ones
44-
super::dce::dce_phi(func);
56+
super::dce::dce_phi(&mut blocks);
4557
}
4658
}
4759

48-
fn label_to_index(blocks: &[Block], id: Word) -> usize {
49-
blocks
50-
.iter()
51-
.position(|b| b.label_id().unwrap() == id)
52-
.unwrap()
53-
}
54-
55-
fn compute_reachable(blocks: &[Block]) -> Vec<bool> {
56-
fn recurse(blocks: &[Block], reachable: &mut [bool], block: usize) {
60+
fn compute_reachable(blocks: &FxIndexMap<LabelId, &mut Block>) -> Vec<bool> {
61+
fn recurse(blocks: &FxIndexMap<LabelId, &mut Block>, reachable: &mut [bool], block: usize) {
5762
if !reachable[block] {
5863
reachable[block] = true;
59-
for dest_id in outgoing_edges(&blocks[block]) {
60-
let dest_idx = label_to_index(blocks, dest_id);
61-
recurse(blocks, reachable, dest_idx);
64+
for dest_id in outgoing_edges(blocks[block]) {
65+
recurse(
66+
blocks,
67+
reachable,
68+
blocks.get_index_of(&LabelId(dest_id)).unwrap(),
69+
);
6270
}
6371
}
6472
}
@@ -67,17 +75,19 @@ fn compute_reachable(blocks: &[Block]) -> Vec<bool> {
6775
reachable
6876
}
6977

70-
fn compute_preds(blocks: &[Block], reachable_blocks: &[bool]) -> Vec<Vec<usize>> {
78+
fn compute_preds(
79+
blocks: &FxIndexMap<LabelId, &mut Block>,
80+
reachable_blocks: &[bool],
81+
) -> Vec<Vec<usize>> {
7182
let mut result = vec![vec![]; blocks.len()];
7283
// Do not count unreachable blocks as valid preds of blocks
7384
for (source_idx, source) in blocks
74-
.iter()
85+
.values()
7586
.enumerate()
7687
.filter(|&(b, _)| reachable_blocks[b])
7788
{
7889
for dest_id in outgoing_edges(source) {
79-
let dest_idx = label_to_index(blocks, dest_id);
80-
result[dest_idx].push(source_idx);
90+
result[blocks.get_index_of(&LabelId(dest_id)).unwrap()].push(source_idx);
8191
}
8292
}
8393
result
@@ -161,7 +171,7 @@ fn insert_phis_all(
161171
types_global_values: &mut Vec<Instruction>,
162172
pointer_to_pointee: &FxHashMap<Word, Word>,
163173
constants: &FxHashMap<Word, u32>,
164-
blocks: &mut [Block],
174+
blocks: &mut FxIndexMap<LabelId, &mut Block>,
165175
dominance_frontier: &[FxHashSet<usize>],
166176
) -> bool {
167177
let var_maps_and_types = blocks[0]
@@ -198,7 +208,11 @@ fn insert_phis_all(
198208
rewrite_rules: FxHashMap::default(),
199209
};
200210
renamer.rename(0, None);
201-
apply_rewrite_rules(&renamer.rewrite_rules, blocks);
211+
// FIXME(eddyb) shouldn't this full rescan of the function be done once?
212+
apply_rewrite_rules(
213+
&renamer.rewrite_rules,
214+
blocks.values_mut().map(|block| &mut **block),
215+
);
202216
remove_nops(blocks);
203217
}
204218
remove_old_variables(blocks, &var_maps_and_types);
@@ -216,7 +230,7 @@ struct VarInfo {
216230
fn collect_access_chains(
217231
pointer_to_pointee: &FxHashMap<Word, Word>,
218232
constants: &FxHashMap<Word, u32>,
219-
blocks: &[Block],
233+
blocks: &FxIndexMap<LabelId, &mut Block>,
220234
base_var: Word,
221235
base_var_ty: Word,
222236
) -> Option<FxHashMap<Word, VarInfo>> {
@@ -249,7 +263,7 @@ fn collect_access_chains(
249263
// Loop in case a previous block references a later AccessChain
250264
loop {
251265
let mut changed = false;
252-
for inst in blocks.iter().flat_map(|b| &b.instructions) {
266+
for inst in blocks.values().flat_map(|b| &b.instructions) {
253267
for (index, op) in inst.operands.iter().enumerate() {
254268
if let Operand::IdRef(id) = op {
255269
if variables.contains_key(id) {
@@ -307,10 +321,10 @@ fn collect_access_chains(
307321
// same var map (e.g. `s.x = s.y;`).
308322
fn split_copy_memory(
309323
header: &mut ModuleHeader,
310-
blocks: &mut [Block],
324+
blocks: &mut FxIndexMap<LabelId, &mut Block>,
311325
var_map: &FxHashMap<Word, VarInfo>,
312326
) {
313-
for block in blocks {
327+
for block in blocks.values_mut() {
314328
let mut inst_index = 0;
315329
while inst_index < block.instructions.len() {
316330
let inst = &block.instructions[inst_index];
@@ -369,7 +383,7 @@ fn has_store(block: &Block, var_map: &FxHashMap<Word, VarInfo>) -> bool {
369383
}
370384

371385
fn insert_phis(
372-
blocks: &[Block],
386+
blocks: &FxIndexMap<LabelId, &mut Block>,
373387
dominance_frontier: &[FxHashSet<usize>],
374388
var_map: &FxHashMap<Word, VarInfo>,
375389
) -> FxHashSet<usize> {
@@ -378,7 +392,7 @@ fn insert_phis(
378392
let mut ever_on_work_list = FxHashSet::default();
379393
let mut work_list = Vec::new();
380394
let mut blocks_with_phi = FxHashSet::default();
381-
for (block_idx, block) in blocks.iter().enumerate() {
395+
for (block_idx, block) in blocks.values().enumerate() {
382396
if has_store(block, var_map) {
383397
ever_on_work_list.insert(block_idx);
384398
work_list.push(block_idx);
@@ -423,10 +437,10 @@ fn top_stack_or_undef(
423437
}
424438
}
425439

426-
struct Renamer<'a> {
440+
struct Renamer<'a, 'b> {
427441
header: &'a mut ModuleHeader,
428442
types_global_values: &'a mut Vec<Instruction>,
429-
blocks: &'a mut [Block],
443+
blocks: &'a mut FxIndexMap<LabelId, &'b mut Block>,
430444
blocks_with_phi: FxHashSet<usize>,
431445
base_var_type: Word,
432446
var_map: &'a FxHashMap<Word, VarInfo>,
@@ -436,7 +450,7 @@ struct Renamer<'a> {
436450
rewrite_rules: FxHashMap<Word, Word>,
437451
}
438452

439-
impl Renamer<'_> {
453+
impl Renamer<'_, '_> {
440454
// Returns the phi definition.
441455
fn insert_phi_value(&mut self, block: usize, from_block: usize) -> Word {
442456
let from_block_label = self.blocks[from_block].label_id().unwrap();
@@ -558,9 +572,8 @@ impl Renamer<'_> {
558572
}
559573
}
560574

561-
for dest_id in outgoing_edges(&self.blocks[block]).collect::<Vec<_>>() {
562-
// TODO: Don't do this find
563-
let dest_idx = label_to_index(self.blocks, dest_id);
575+
for dest_id in outgoing_edges(self.blocks[block]).collect::<Vec<_>>() {
576+
let dest_idx = self.blocks.get_index_of(&LabelId(dest_id)).unwrap();
564577
self.rename(dest_idx, Some(block));
565578
}
566579

@@ -570,16 +583,16 @@ impl Renamer<'_> {
570583
}
571584
}
572585

573-
fn remove_nops(blocks: &mut [Block]) {
574-
for block in blocks {
586+
fn remove_nops(blocks: &mut FxIndexMap<LabelId, &mut Block>) {
587+
for block in blocks.values_mut() {
575588
block
576589
.instructions
577590
.retain(|inst| inst.class.opcode != Op::Nop);
578591
}
579592
}
580593

581594
fn remove_old_variables(
582-
blocks: &mut [Block],
595+
blocks: &mut FxIndexMap<LabelId, &mut Block>,
583596
var_maps_and_types: &[(FxHashMap<u32, VarInfo>, u32)],
584597
) {
585598
blocks[0].instructions.retain(|inst| {
@@ -590,7 +603,7 @@ fn remove_old_variables(
590603
.all(|(var_map, _)| !var_map.contains_key(&result_id))
591604
}
592605
});
593-
for block in blocks {
606+
for block in blocks.values_mut() {
594607
block.instructions.retain(|inst| {
595608
!matches!(inst.class.opcode, Op::AccessChain | Op::InBoundsAccessChain)
596609
|| inst.operands.iter().all(|op| {

crates/rustc_codegen_spirv/src/linker/mod.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,10 @@ fn id(header: &mut ModuleHeader) -> Word {
8888
result
8989
}
9090

91-
fn apply_rewrite_rules(rewrite_rules: &FxHashMap<Word, Word>, blocks: &mut [Block]) {
91+
fn apply_rewrite_rules<'a>(
92+
rewrite_rules: &FxHashMap<Word, Word>,
93+
blocks: impl IntoIterator<Item = &'a mut Block>,
94+
) {
9295
let apply = |inst: &mut Instruction| {
9396
if let Some(ref mut id) = &mut inst.result_id {
9497
if let Some(&rewrite) = rewrite_rules.get(id) {

0 commit comments

Comments
 (0)