1
1
//! Pass for removing dead code, i.e. that computes values that are then discarded
2
2
3
3
use hugr_core:: hugr:: internal:: HugrInternals ;
4
- use hugr_core:: { HugrView , Visibility , hugr:: hugrmut:: HugrMut , ops:: OpType } ;
4
+ use hugr_core:: { HugrView , hugr:: hugrmut:: HugrMut , ops:: OpType } ;
5
5
use std:: convert:: Infallible ;
6
6
use std:: fmt:: { Debug , Formatter } ;
7
7
use std:: {
8
8
collections:: { HashMap , HashSet , VecDeque } ,
9
9
sync:: Arc ,
10
10
} ;
11
11
12
- use crate :: { ComposablePass , IncludeExports } ;
12
+ use crate :: ComposablePass ;
13
13
14
- /// Configuration for Dead Code Elimination pass, i.e. which removes nodes
15
- /// beneath the [HugrView::entrypoint] that compute only unneeded values.
14
+ /// Configuration for Dead Code Elimination pass
16
15
#[ derive( Clone ) ]
17
16
pub struct DeadCodeElimPass < H : HugrView > {
18
17
/// Nodes that are definitely needed - e.g. `FuncDefns`, but could be anything.
19
- /// [HugrView::entrypoint] is assumed to be needed even if not mentioned here.
18
+ /// Hugr Root is assumed to be an entry point even if not mentioned here.
20
19
entry_points : Vec < H :: Node > ,
21
20
/// Callback identifying nodes that must be preserved even if their
22
21
/// results are not used. Defaults to [`PreserveNode::default_for`].
23
22
preserve_callback : Arc < PreserveCallback < H > > ,
24
- include_exports : IncludeExports ,
25
23
}
26
24
27
25
impl < H : HugrView + ' static > Default for DeadCodeElimPass < H > {
28
26
fn default ( ) -> Self {
29
27
Self {
30
28
entry_points : Default :: default ( ) ,
31
29
preserve_callback : Arc :: new ( PreserveNode :: default_for) ,
32
- include_exports : IncludeExports :: default ( ) ,
33
30
}
34
31
}
35
32
}
@@ -42,13 +39,11 @@ impl<H: HugrView> Debug for DeadCodeElimPass<H> {
42
39
#[ derive( Debug ) ]
43
40
struct DCEDebug < ' a , N > {
44
41
entry_points : & ' a Vec < N > ,
45
- include_exports : IncludeExports ,
46
42
}
47
43
48
44
Debug :: fmt (
49
45
& DCEDebug {
50
46
entry_points : & self . entry_points ,
51
- include_exports : self . include_exports ,
52
47
} ,
53
48
f,
54
49
)
@@ -74,12 +69,12 @@ pub enum PreserveNode {
74
69
75
70
impl PreserveNode {
76
71
/// A conservative default for a given node. Just examines the node's [`OpType`]:
77
- /// * Assumes all Calls must be preserved. (One could scan the called `FuncDefn` for
78
- /// termination, but would also need to check for cycles in the `CallGraph`.)
72
+ /// * Assumes all Calls must be preserved. (One could scan the called `FuncDefn`, but would
73
+ /// also need to check for cycles in the [ `CallGraph`](super::call_graph::CallGraph) .)
79
74
/// * Assumes all CFGs must be preserved. (One could, for example, allow acyclic
80
75
/// CFGs to be removed.)
81
- /// * Assumes all `TailLoops` must be preserved. (One could use some analysis, e.g.
82
- /// dataflow, to allow removal of `TailLoops` with a bounded number of iterations .)
76
+ /// * Assumes all `TailLoops` must be preserved. (One could, for example, use dataflow
77
+ /// analysis to allow removal of `TailLoops` that never [Continue](hugr_core::ops::TailLoop::CONTINUE_TAG) .)
83
78
pub fn default_for < H : HugrView > ( h : & H , n : H :: Node ) -> PreserveNode {
84
79
match h. get_optype ( n) {
85
80
OpType :: CFG ( _) | OpType :: TailLoop ( _) | OpType :: Call ( _) => PreserveNode :: MustKeep ,
@@ -96,33 +91,16 @@ impl<H: HugrView> DeadCodeElimPass<H> {
96
91
self
97
92
}
98
93
99
- /// Mark some nodes as reachable, i.e. so we cannot eliminate any code used to
100
- /// evaluate their results. The [`HugrView::entrypoint`] is assumed to be reachable;
101
- /// if that is the [`HugrView::module_root`], then any public [FuncDefn] and
102
- /// [FuncDecl]s are also considered reachable by default,
103
- /// but this can be change by [`Self::include_module_exports`].
104
- ///
105
- /// [FuncDecl]: OpType::FuncDecl
106
- /// [FuncDefn]: OpType::FuncDefn
94
+ /// Mark some nodes as entry points to the Hugr, i.e. so we cannot eliminate any code
95
+ /// used to evaluate these nodes.
96
+ /// [`HugrView::entrypoint`] is assumed to be an entry point;
97
+ /// for Module roots the client will want to mark some of the `FuncDefn` children
98
+ /// as entry points too.
107
99
pub fn with_entry_points ( mut self , entry_points : impl IntoIterator < Item = H :: Node > ) -> Self {
108
100
self . entry_points . extend ( entry_points) ;
109
101
self
110
102
}
111
103
112
- /// Sets whether the exported [FuncDefn](OpType::FuncDefn)s and
113
- /// [FuncDecl](OpType::FuncDecl)s are considered reachable.
114
- ///
115
- /// Note that for non-module-entry Hugrs this has no effect, since we only remove
116
- /// code beneath the entrypoint: this cannot be affected by other module children.
117
- ///
118
- /// So, for module-rooted-Hugrs: [IncludeExports::OnlyIfEntrypointIsModuleRoot] is
119
- /// equivalent to [IncludeExports::Always]; and [IncludeExports::Never] will remove
120
- /// all children, unless some are explicity added by [Self::with_entry_points].
121
- pub fn include_module_exports ( mut self , include : IncludeExports ) -> Self {
122
- self . include_exports = include;
123
- self
124
- }
125
-
126
104
fn find_needed_nodes ( & self , h : & H ) -> HashSet < H :: Node > {
127
105
let mut must_preserve = HashMap :: new ( ) ;
128
106
let mut needed = HashSet :: new ( ) ;
@@ -133,23 +111,19 @@ impl<H: HugrView> DeadCodeElimPass<H> {
133
111
continue ;
134
112
}
135
113
for ch in h. children ( n) {
136
- let must_keep = match h. get_optype ( ch) {
114
+ if self . must_preserve ( h, & mut must_preserve, ch)
115
+ || matches ! (
116
+ h. get_optype( ch) ,
137
117
OpType :: Case ( _) // Include all Cases in Conditionals
138
118
| OpType :: DataflowBlock ( _) // and all Basic Blocks in CFGs
139
119
| OpType :: ExitBlock ( _)
140
120
| OpType :: AliasDecl ( _) // and all Aliases (we do not track their uses in types)
141
121
| OpType :: AliasDefn ( _)
142
122
| OpType :: Input ( _) // Also Dataflow input/output, these are necessary for legality
143
- | OpType :: Output ( _) => true ,
144
- // FuncDefns (as children of Module) only if public and including exports
145
- // (will be included if static predecessors of Call/LoadFunction below,
146
- // regardless of Visibility or self.include_exports)
147
- OpType :: FuncDefn ( fd) => fd. visibility ( ) == & Visibility :: Public && self . include_exports . for_hugr ( h) ,
148
- OpType :: FuncDecl ( fd) => fd. visibility ( ) == & Visibility :: Public && self . include_exports . for_hugr ( h) ,
149
- // No Const, unless reached along static edges
150
- _ => false
151
- } ;
152
- if must_keep || self . must_preserve ( h, & mut must_preserve, ch) {
123
+ | OpType :: Output ( _) // Do not include FuncDecl / FuncDefn / Const unless reachable by static edges
124
+ // (from Call/LoadConst/LoadFunction):
125
+ )
126
+ {
153
127
q. push_back ( ch) ;
154
128
}
155
129
}
@@ -167,6 +141,7 @@ impl<H: HugrView> DeadCodeElimPass<H> {
167
141
if let Some ( res) = cache. get ( & n) {
168
142
return * res;
169
143
}
144
+ #[ allow( deprecated) ]
170
145
let res = match self . preserve_callback . as_ref ( ) ( h, n) {
171
146
PreserveNode :: MustKeep => true ,
172
147
PreserveNode :: CanRemoveIgnoringChildren => false ,
@@ -199,57 +174,18 @@ impl<H: HugrMut> ComposablePass<H> for DeadCodeElimPass<H> {
199
174
mod test {
200
175
use std:: sync:: Arc ;
201
176
202
- use hugr_core:: builder:: {
203
- CFGBuilder , Container , Dataflow , DataflowSubContainer , HugrBuilder , ModuleBuilder ,
204
- } ;
177
+ use hugr_core:: Hugr ;
178
+ use hugr_core:: builder:: { CFGBuilder , Container , Dataflow , DataflowSubContainer , HugrBuilder } ;
205
179
use hugr_core:: extension:: prelude:: { ConstUsize , usize_t} ;
206
- use hugr_core:: ops:: { OpTag , OpTrait , Value , handle:: NodeHandle } ;
207
- use hugr_core:: { Hugr , HugrView , type_row, types:: Signature } ;
180
+ use hugr_core:: ops:: { OpTag , OpTrait , handle:: NodeHandle } ;
181
+ use hugr_core:: types:: Signature ;
182
+ use hugr_core:: { HugrView , ops:: Value , type_row} ;
208
183
use itertools:: Itertools ;
209
- use rstest:: rstest;
210
184
211
- use crate :: { ComposablePass , IncludeExports } ;
185
+ use crate :: ComposablePass ;
212
186
213
187
use super :: { DeadCodeElimPass , PreserveNode } ;
214
188
215
- #[ rstest]
216
- #[ case( false , IncludeExports :: Never , true ) ]
217
- #[ case( false , IncludeExports :: OnlyIfEntrypointIsModuleRoot , false ) ]
218
- #[ case( false , IncludeExports :: Always , false ) ]
219
- #[ case( true , IncludeExports :: Never , true ) ]
220
- #[ case( true , IncludeExports :: OnlyIfEntrypointIsModuleRoot , false ) ]
221
- #[ case( true , IncludeExports :: Always , false ) ]
222
- fn test_module_exports (
223
- #[ case] include_dfn : bool ,
224
- #[ case] module_exports : IncludeExports ,
225
- #[ case] decl_removed : bool ,
226
- ) {
227
- let mut mb = ModuleBuilder :: new ( ) ;
228
- let dfn = mb
229
- . define_function ( "foo" , Signature :: new_endo ( usize_t ( ) ) )
230
- . unwrap ( ) ;
231
- let ins = dfn. input_wires ( ) ;
232
- let dfn = dfn. finish_with_outputs ( ins) . unwrap ( ) ;
233
- let dcl = mb
234
- . declare ( "bar" , Signature :: new_endo ( usize_t ( ) ) . into ( ) )
235
- . unwrap ( ) ;
236
- let mut h = mb. finish_hugr ( ) . unwrap ( ) ;
237
- let mut dce = DeadCodeElimPass :: < Hugr > :: default ( ) . include_module_exports ( module_exports) ;
238
- if include_dfn {
239
- dce = dce. with_entry_points ( [ dfn. node ( ) ] ) ;
240
- }
241
- dce. run ( & mut h) . unwrap ( ) ;
242
- let defn_retained = include_dfn;
243
- let decl_retained = !decl_removed;
244
- let children = h. children ( h. module_root ( ) ) . collect_vec ( ) ;
245
- assert_eq ! ( defn_retained, children. iter( ) . contains( & dfn. node( ) ) ) ;
246
- assert_eq ! ( decl_retained, children. iter( ) . contains( & dcl. node( ) ) ) ;
247
- assert_eq ! (
248
- children. len( ) ,
249
- ( defn_retained as usize ) + ( decl_retained as usize )
250
- ) ;
251
- }
252
-
253
189
#[ test]
254
190
fn test_cfg_callback ( ) {
255
191
let mut cb = CFGBuilder :: new ( Signature :: new_endo ( type_row ! [ ] ) ) . unwrap ( ) ;
0 commit comments