Skip to content
This repository was archived by the owner on Oct 31, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 156 additions & 4 deletions crates/rustc_codegen_spirv/src/linker/dce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
//! *references* a rooted thing is also rooted, not the other way around - but that's the basic
//! concept.

use rspirv::dr::{Function, Instruction, Module};
use rspirv::spirv::{Op, Word};
use rspirv::dr::{Function, Instruction, Module, Operand};
use rspirv::spirv::{Op, StorageClass, Word};
use rustc_data_structures::fx::FxHashSet;

pub fn dce(module: &mut Module) {
Expand Down Expand Up @@ -36,8 +36,29 @@ fn spread_roots(module: &Module, rooted: &mut FxHashSet<Word>) -> bool {
}
for func in &module.functions {
if rooted.contains(&func.def_id().unwrap()) {
for inst in func.all_inst_iter() {
any |= root(inst, rooted);
// NB (Mobius 2021) - since later insts are much more likely to reference
// earlier insts, by reversing the iteration order, we're more likely to root the
// entire relevant function at once.
// See https://github.com/EmbarkStudios/rust-gpu/pull/691#discussion_r681477091
for inst in func
.end
.iter()
.chain(
func.blocks
.iter()
.rev()
.flat_map(|b| b.instructions.iter().rev().chain(b.label.iter())),
)
.chain(func.parameters.iter().rev())
.chain(func.def.iter())
{
if !instruction_is_pure(inst) {
any |= root(inst, rooted);
} else if let Some(id) = inst.result_id {
if rooted.contains(&id) {
any |= root(inst, rooted);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hrm, this might be dreadfully inefficient, some profiling might be nice here! (how many times spread_roots is called before/after this change). It's probably way more efficient to iterate over instructions backwards instead of forwards.

Copy link
Contributor

@khyperia khyperia Aug 3, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, just checked, when running the multibuilder example on main, dce is called three times, and in each of those, spread_roots is called [3, 4, 1] times. With this PR, it's [8, 8, 5], a significant cost (definitely not zero-cost, especially with the last one being a no-op). But, reversing the iteration order of the function's instructions, it drops back down to [3, 4, 1].

It'll still likely be worse than the original in some cases (loops etc.), so it's not practically zero-cost, but still should be fine.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm. So do the reversing of iteration and write a comment saying why it is so?

}
}
}
}
}
Expand Down Expand Up @@ -90,6 +111,13 @@ fn kill_unrooted(module: &mut Module, rooted: &FxHashSet<Word>) {
module
.functions
.retain(|f| is_rooted(f.def.as_ref().unwrap(), rooted));
module.functions.iter_mut().for_each(|fun| {
fun.blocks.iter_mut().for_each(|block| {
block
.instructions
.retain(|inst| !instruction_is_pure(inst) || is_rooted(inst, rooted));
});
});
}

pub fn dce_phi(func: &mut Function) {
Expand All @@ -115,3 +143,127 @@ pub fn dce_phi(func: &mut Function) {
.retain(|inst| inst.class.opcode != Op::Phi || used.contains(&inst.result_id.unwrap()));
}
}

fn instruction_is_pure(inst: &Instruction) -> bool {
use Op::*;
match inst.class.opcode {
Nop
| Undef
| ConstantTrue
| ConstantFalse
| Constant
| ConstantComposite
| ConstantSampler
| ConstantNull
| AccessChain
| InBoundsAccessChain
| PtrAccessChain
| ArrayLength
| InBoundsPtrAccessChain
| CompositeConstruct
| CompositeExtract
| CopyObject
| Transpose
| ConvertFToU
| ConvertFToS
| ConvertSToF
| ConvertUToF
| UConvert
| SConvert
| FConvert
| QuantizeToF16
| ConvertPtrToU
| SatConvertSToU
| SatConvertUToS
| ConvertUToPtr
| PtrCastToGeneric
| GenericCastToPtr
| GenericCastToPtrExplicit
| Bitcast
| SNegate
| FNegate
| IAdd
| FAdd
| ISub
| FSub
| IMul
| FMul
| UDiv
| SDiv
| FDiv
| UMod
| SRem
| SMod
| FRem
| FMod
| VectorTimesScalar
| MatrixTimesScalar
| VectorTimesMatrix
| MatrixTimesVector
| MatrixTimesMatrix
| OuterProduct
| Dot
| IAddCarry
| ISubBorrow
| UMulExtended
| SMulExtended
| Any
| All
| IsNan
| IsInf
| IsFinite
| IsNormal
| SignBitSet
| LessOrGreater
| Ordered
| Unordered
| LogicalEqual
| LogicalNotEqual
| LogicalOr
| LogicalAnd
| LogicalNot
| Select
| IEqual
| INotEqual
| UGreaterThan
| SGreaterThan
| UGreaterThanEqual
| SGreaterThanEqual
| ULessThan
| SLessThan
| ULessThanEqual
| SLessThanEqual
| FOrdEqual
| FUnordEqual
| FOrdNotEqual
| FUnordNotEqual
| FOrdLessThan
| FUnordLessThan
| FOrdGreaterThan
| FUnordGreaterThan
| FOrdLessThanEqual
| FUnordLessThanEqual
| FOrdGreaterThanEqual
| FUnordGreaterThanEqual
| ShiftRightLogical
| ShiftRightArithmetic
| ShiftLeftLogical
| BitwiseOr
| BitwiseXor
| BitwiseAnd
| Not
| BitFieldInsert
| BitFieldSExtract
| BitFieldUExtract
| BitReverse
| BitCount
| Phi
| SizeOf
| CopyLogical
| PtrEqual
| PtrNotEqual
| PtrDiff => true,
Variable => inst.operands.get(0) == Some(&Operand::StorageClass(StorageClass::Function)),
_ => false,
}
}
35 changes: 17 additions & 18 deletions tests/ui/dis/index_user_dst.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,32 @@ OpLine %5 7 12
%10 = OpArrayLength %11 %8 0
OpLine %5 7 0
%12 = OpCompositeInsert %13 %6 %14 0
%15 = OpCompositeConstruct %13 %6 %10
OpLine %5 8 21
%16 = OpULessThan %17 %9 %10
%15 = OpULessThan %16 %9 %10
OpLine %5 8 21
OpSelectionMerge %18 None
OpBranchConditional %16 %19 %20
%19 = OpLabel
OpSelectionMerge %17 None
OpBranchConditional %15 %18 %19
%18 = OpLabel
OpLine %5 8 21
%21 = OpInBoundsAccessChain %22 %6 %9
%23 = OpLoad %24 %21
%20 = OpInBoundsAccessChain %21 %6 %9
%22 = OpLoad %23 %20
OpLine %5 10 1
OpReturn
%20 = OpLabel
%19 = OpLabel
OpLine %5 8 21
OpBranch %24
%24 = OpLabel
OpBranch %25
%25 = OpLabel
OpBranch %26
%26 = OpLabel
%27 = OpPhi %17 %28 %25 %28 %29
OpLoopMerge %30 %29 None
OpBranchConditional %27 %31 %30
%31 = OpLabel
OpBranch %29
%29 = OpLabel
OpBranch %26
%26 = OpPhi %16 %27 %24 %27 %28
OpLoopMerge %29 %28 None
OpBranchConditional %26 %30 %29
%30 = OpLabel
OpBranch %28
%28 = OpLabel
OpBranch %25
%29 = OpLabel
OpUnreachable
%18 = OpLabel
%17 = OpLabel
OpUnreachable
OpFunctionEnd