Skip to content

Commit 3a460a4

Browse files
committed
Revert hugr-passes changes
1 parent 4f07f9b commit 3a460a4

File tree

8 files changed

+92
-288
lines changed

8 files changed

+92
-288
lines changed

hugr-passes/src/call_graph.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ pub enum CallGraphNode<N = Node> {
1818
FuncDecl(N),
1919
/// petgraph-node corresponds to a [`FuncDefn`](OpType::FuncDefn) node (specified) in the Hugr
2020
FuncDefn(N),
21-
/// petgraph-node corresponds to the entrypoint node of the hugr, that is not
21+
/// petgraph-node corresponds to the root node of the hugr, that is not
2222
/// a [`FuncDefn`](OpType::FuncDefn). Note that it will not be a [Module](OpType::Module)
2323
/// either, as such a node could not have outgoing edges, so is not represented in the petgraph.
2424
NonFuncRoot,

hugr-passes/src/const_fold.rs

Lines changed: 18 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -15,44 +15,38 @@ use hugr_core::{
1515
};
1616
use value_handle::ValueHandle;
1717

18+
use crate::dataflow::{
19+
ConstLoader, ConstLocation, DFContext, Machine, PartialValue, TailLoopTermination,
20+
partial_from_const,
21+
};
1822
use crate::dead_code::{DeadCodeElimPass, PreserveNode};
1923
use crate::{ComposablePass, composable::validate_if_test};
20-
use crate::{
21-
IncludeExports,
22-
dataflow::{
23-
ConstLoader, ConstLocation, DFContext, Machine, PartialValue, TailLoopTermination,
24-
partial_from_const,
25-
},
26-
};
2724

2825
#[derive(Debug, Clone, Default)]
2926
/// A configuration for the Constant Folding pass.
30-
///
31-
/// Note that by default we assume that only the entrypoint is reachable and
32-
/// only if it is not the module root; see [Self::with_inputs]. Mutation
33-
/// occurs anywhere beneath the entrypoint.
3427
pub struct ConstantFoldPass {
3528
allow_increase_termination: bool,
3629
/// Each outer key Node must be either:
37-
/// - a `FuncDefn` child of the module-root
38-
/// - the entrypoint
30+
/// - a `FuncDefn` child of the root, if the root is a module; or
31+
/// - the root, if the root is not a Module
3932
inputs: HashMap<Node, HashMap<IncomingPort, Value>>,
4033
}
4134

4235
#[derive(Clone, Debug, Error, PartialEq)]
4336
#[non_exhaustive]
4437
/// Errors produced by [`ConstantFoldPass`].
4538
pub enum ConstFoldError {
46-
/// Error raised when inputs are provided for a Node that is neither a dataflow
47-
/// parent, nor a [CFG](OpType::CFG), nor a [Conditional](OpType::Conditional).
39+
/// Error raised when a Node is specified as an entry-point but
40+
/// is neither a dataflow parent, nor a [CFG](OpType::CFG), nor
41+
/// a [Conditional](OpType::Conditional).
4842
#[error("{node} has OpType {op} which cannot be an entry-point")]
4943
InvalidEntryPoint {
5044
/// The node which was specified as an entry-point
5145
node: Node,
5246
/// The `OpType` of the node
5347
op: OpType,
5448
},
55-
/// Inputs were provided for a node that is not in the hugr.
49+
/// The chosen entrypoint is not in the hugr.
5650
#[error("Entry-point {node} is not part of the Hugr")]
5751
MissingEntryPoint {
5852
/// The missing node
@@ -73,25 +67,15 @@ impl ConstantFoldPass {
7367
}
7468

7569
/// Specifies a number of external inputs to an entry point of the Hugr.
76-
/// In normal use, for Module-rooted Hugrs, `node` is a `FuncDefn` (child of the root);
77-
/// for non-Module-rooted Hugrs, `node` is the [HugrView::entrypoint]. (This is not
70+
/// In normal use, for Module-rooted Hugrs, `node` is a `FuncDefn` child of the root;
71+
/// or for non-Module-rooted Hugrs, `node` is the root of the Hugr. (This is not
7872
/// enforced, but it must be a container and not a module itself.)
7973
///
8074
/// Multiple calls for the same entry-point combine their values, with later
8175
/// values on the same in-port replacing earlier ones.
8276
///
83-
/// Note that providing empty `inputs` indicates that we must preserve the ability
84-
/// to compute the result of `node` for all possible inputs.
85-
/// * If the entrypoint is the module-root, this method should be called for every
86-
/// [FuncDefn] that is externally callable
87-
/// * Otherwise, i.e. if the entrypoint is not the module-root,
88-
/// * The default is to assume the entrypoint is callable with any inputs;
89-
/// * If `node` is the entrypoint, this method allows to restrict the possible inputs
90-
/// * If `node` is beneath the entrypoint, this merely degrades the analysis. (We
91-
/// will mutate only beneath the entrypoint, but using results of analysing the
92-
/// whole Hugr wrt. the specified/any inputs too).
93-
///
94-
/// [FuncDefn]: hugr_core::ops::FuncDefn
77+
/// Note that if `inputs` is empty, this still marks the node as an entry-point, i.e.
78+
/// we must preserve nodes required to compute its result.
9579
pub fn with_inputs(
9680
mut self,
9781
node: Node,
@@ -113,7 +97,8 @@ impl<H: HugrMut<Node = Node> + 'static> ComposablePass<H> for ConstantFoldPass {
11397
///
11498
/// # Errors
11599
///
116-
/// [ConstFoldError] if inputs were provided via [`Self::with_inputs`] for an invalid node.
100+
/// [`ConstFoldError::InvalidEntryPoint`] if an entry-point added by [`Self::with_inputs`]
101+
/// was of an invalid [`OpType`]
117102
fn run(&self, hugr: &mut H) -> Result<(), ConstFoldError> {
118103
let fresh_node = Node::from(portgraph::NodeIndex::new(
119104
hugr.nodes().max().map_or(0, |n| n.index() + 1),
@@ -199,51 +184,25 @@ impl<H: HugrMut<Node = Node> + 'static> ComposablePass<H> for ConstantFoldPass {
199184
}
200185
}
201186

202-
const NO_INPUTS: [(IncomingPort, Value); 0] = [];
203-
204187
/// Exhaustively apply constant folding to a HUGR.
205188
/// If the Hugr's entrypoint is its [`Module`], assumes all [`FuncDefn`] children are reachable.
206-
/// Otherwise, assume that the [HugrView::entrypoint] is itself reachable.
207189
///
208190
/// [`FuncDefn`]: hugr_core::ops::OpType::FuncDefn
209191
/// [`Module`]: hugr_core::ops::OpType::Module
210-
#[deprecated(note = "Use fold_constants, or manually configure ConstantFoldPass")]
211192
pub fn constant_fold_pass<H: HugrMut<Node = Node> + 'static>(mut h: impl AsMut<H>) {
212193
let h = h.as_mut();
213194
let c = ConstantFoldPass::default();
214195
let c = if h.get_optype(h.entrypoint()).is_module() {
196+
let no_inputs: [(IncomingPort, _); 0] = [];
215197
h.children(h.entrypoint())
216198
.filter(|n| h.get_optype(*n).is_func_defn())
217-
.fold(c, |c, n| c.with_inputs(n, NO_INPUTS.clone()))
199+
.fold(c, |c, n| c.with_inputs(n, no_inputs.iter().cloned()))
218200
} else {
219201
c
220202
};
221203
validate_if_test(c, h).unwrap();
222204
}
223205

224-
/// Exhaustively apply constant folding to a HUGR.
225-
/// Assumes that the Hugr's entrypoint is reachable (if it is not a [`Module`]).
226-
/// Also uses `policy` to determine which public [`FuncDefn`] children of the [`HugrView::module_root`] are reachable.
227-
///
228-
/// [`Module`]: hugr_core::ops::OpType::Module
229-
/// [`FuncDefn`]: hugr_core::ops::OpType::FuncDefn
230-
pub fn fold_constants(h: &mut (impl HugrMut<Node = Node> + 'static), policy: IncludeExports) {
231-
let mut funcs = Vec::new();
232-
if !h.entrypoint_optype().is_module() {
233-
funcs.push(h.entrypoint());
234-
}
235-
if policy.for_hugr(&h) {
236-
funcs.extend(
237-
h.children(h.module_root())
238-
.filter(|n| h.get_optype(*n).is_func_defn()),
239-
)
240-
}
241-
let c = funcs.into_iter().fold(ConstantFoldPass::default(), |c, n| {
242-
c.with_inputs(n, NO_INPUTS.clone())
243-
});
244-
validate_if_test(c, h).unwrap();
245-
}
246-
247206
struct ConstFoldContext;
248207

249208
impl ConstLoader<ValueHandle<Node>> for ConstFoldContext {

hugr-passes/src/const_fold/test.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,10 @@ use hugr_core::std_extensions::logic::LogicOp;
2929
use hugr_core::types::{Signature, SumType, Type, TypeBound, TypeRow, TypeRowRV};
3030
use hugr_core::{Hugr, HugrView, IncomingPort, Node, type_row};
3131

32+
use crate::ComposablePass as _;
3233
use crate::dataflow::{DFContext, PartialValue, partial_from_const};
33-
use crate::{ComposablePass as _, IncludeExports};
3434

35-
use super::{ConstFoldContext, ConstantFoldPass, ValueHandle, fold_constants};
36-
37-
fn constant_fold_pass(h: &mut (impl HugrMut<Node = Node> + 'static)) {
38-
fold_constants(h, IncludeExports::Always);
39-
}
35+
use super::{ConstFoldContext, ConstantFoldPass, ValueHandle, constant_fold_pass};
4036

4137
#[rstest]
4238
#[case(ConstInt::new_u(4, 2).unwrap(), true)]

hugr-passes/src/dataflow/datalog.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ impl<H: HugrView, V: AbstractValue> Machine<H, V> {
116116
} else {
117117
let ep = self.0.entrypoint();
118118
let mut p = in_values.into_iter().peekable();
119-
// We must provide some inputs to the entrypoint so that they are Top rather than Bottom.
119+
// We must provide some inputs to the root so that they are Top rather than Bottom.
120120
// (However, this test will fail for DataflowBlock or Case roots, i.e. if no
121121
// inputs have been provided they will still see Bottom. We could store the "input"
122122
// values for even these nodes in self.1 and then convert to actual Wire values

hugr-passes/src/dead_code.rs

Lines changed: 27 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,32 @@
11
//! Pass for removing dead code, i.e. that computes values that are then discarded
22
33
use hugr_core::hugr::internal::HugrInternals;
4-
use hugr_core::{HugrView, Visibility, hugr::hugrmut::HugrMut, ops::OpType};
4+
use hugr_core::{HugrView, hugr::hugrmut::HugrMut, ops::OpType};
55
use std::convert::Infallible;
66
use std::fmt::{Debug, Formatter};
77
use std::{
88
collections::{HashMap, HashSet, VecDeque},
99
sync::Arc,
1010
};
1111

12-
use crate::{ComposablePass, IncludeExports};
12+
use crate::ComposablePass;
1313

14-
/// Configuration for Dead Code Elimination pass, i.e. which removes nodes
15-
/// beneath the [HugrView::entrypoint] that compute only unneeded values.
14+
/// Configuration for Dead Code Elimination pass
1615
#[derive(Clone)]
1716
pub struct DeadCodeElimPass<H: HugrView> {
1817
/// Nodes that are definitely needed - e.g. `FuncDefns`, but could be anything.
19-
/// [HugrView::entrypoint] is assumed to be needed even if not mentioned here.
18+
/// Hugr Root is assumed to be an entry point even if not mentioned here.
2019
entry_points: Vec<H::Node>,
2120
/// Callback identifying nodes that must be preserved even if their
2221
/// results are not used. Defaults to [`PreserveNode::default_for`].
2322
preserve_callback: Arc<PreserveCallback<H>>,
24-
include_exports: IncludeExports,
2523
}
2624

2725
impl<H: HugrView + 'static> Default for DeadCodeElimPass<H> {
2826
fn default() -> Self {
2927
Self {
3028
entry_points: Default::default(),
3129
preserve_callback: Arc::new(PreserveNode::default_for),
32-
include_exports: IncludeExports::default(),
3330
}
3431
}
3532
}
@@ -42,13 +39,11 @@ impl<H: HugrView> Debug for DeadCodeElimPass<H> {
4239
#[derive(Debug)]
4340
struct DCEDebug<'a, N> {
4441
entry_points: &'a Vec<N>,
45-
include_exports: IncludeExports,
4642
}
4743

4844
Debug::fmt(
4945
&DCEDebug {
5046
entry_points: &self.entry_points,
51-
include_exports: self.include_exports,
5247
},
5348
f,
5449
)
@@ -74,12 +69,12 @@ pub enum PreserveNode {
7469

7570
impl PreserveNode {
7671
/// A conservative default for a given node. Just examines the node's [`OpType`]:
77-
/// * Assumes all Calls must be preserved. (One could scan the called `FuncDefn` for
78-
/// termination, but would also need to check for cycles in the `CallGraph`.)
72+
/// * Assumes all Calls must be preserved. (One could scan the called `FuncDefn`, but would
73+
/// also need to check for cycles in the [`CallGraph`](super::call_graph::CallGraph).)
7974
/// * Assumes all CFGs must be preserved. (One could, for example, allow acyclic
8075
/// CFGs to be removed.)
81-
/// * Assumes all `TailLoops` must be preserved. (One could use some analysis, e.g.
82-
/// dataflow, to allow removal of `TailLoops` with a bounded number of iterations.)
76+
/// * Assumes all `TailLoops` must be preserved. (One could, for example, use dataflow
77+
/// analysis to allow removal of `TailLoops` that never [Continue](hugr_core::ops::TailLoop::CONTINUE_TAG).)
8378
pub fn default_for<H: HugrView>(h: &H, n: H::Node) -> PreserveNode {
8479
match h.get_optype(n) {
8580
OpType::CFG(_) | OpType::TailLoop(_) | OpType::Call(_) => PreserveNode::MustKeep,
@@ -96,33 +91,16 @@ impl<H: HugrView> DeadCodeElimPass<H> {
9691
self
9792
}
9893

99-
/// Mark some nodes as reachable, i.e. so we cannot eliminate any code used to
100-
/// evaluate their results. The [`HugrView::entrypoint`] is assumed to be reachable;
101-
/// if that is the [`HugrView::module_root`], then any public [FuncDefn] and
102-
/// [FuncDecl]s are also considered reachable by default,
103-
/// but this can be change by [`Self::include_module_exports`].
104-
///
105-
/// [FuncDecl]: OpType::FuncDecl
106-
/// [FuncDefn]: OpType::FuncDefn
94+
/// Mark some nodes as entry points to the Hugr, i.e. so we cannot eliminate any code
95+
/// used to evaluate these nodes.
96+
/// [`HugrView::entrypoint`] is assumed to be an entry point;
97+
/// for Module roots the client will want to mark some of the `FuncDefn` children
98+
/// as entry points too.
10799
pub fn with_entry_points(mut self, entry_points: impl IntoIterator<Item = H::Node>) -> Self {
108100
self.entry_points.extend(entry_points);
109101
self
110102
}
111103

112-
/// Sets whether the exported [FuncDefn](OpType::FuncDefn)s and
113-
/// [FuncDecl](OpType::FuncDecl)s are considered reachable.
114-
///
115-
/// Note that for non-module-entry Hugrs this has no effect, since we only remove
116-
/// code beneath the entrypoint: this cannot be affected by other module children.
117-
///
118-
/// So, for module-rooted-Hugrs: [IncludeExports::OnlyIfEntrypointIsModuleRoot] is
119-
/// equivalent to [IncludeExports::Always]; and [IncludeExports::Never] will remove
120-
/// all children, unless some are explicity added by [Self::with_entry_points].
121-
pub fn include_module_exports(mut self, include: IncludeExports) -> Self {
122-
self.include_exports = include;
123-
self
124-
}
125-
126104
fn find_needed_nodes(&self, h: &H) -> HashSet<H::Node> {
127105
let mut must_preserve = HashMap::new();
128106
let mut needed = HashSet::new();
@@ -133,23 +111,19 @@ impl<H: HugrView> DeadCodeElimPass<H> {
133111
continue;
134112
}
135113
for ch in h.children(n) {
136-
let must_keep = match h.get_optype(ch) {
114+
if self.must_preserve(h, &mut must_preserve, ch)
115+
|| matches!(
116+
h.get_optype(ch),
137117
OpType::Case(_) // Include all Cases in Conditionals
138118
| OpType::DataflowBlock(_) // and all Basic Blocks in CFGs
139119
| OpType::ExitBlock(_)
140120
| OpType::AliasDecl(_) // and all Aliases (we do not track their uses in types)
141121
| OpType::AliasDefn(_)
142122
| OpType::Input(_) // Also Dataflow input/output, these are necessary for legality
143-
| OpType::Output(_) => true,
144-
// FuncDefns (as children of Module) only if public and including exports
145-
// (will be included if static predecessors of Call/LoadFunction below,
146-
// regardless of Visibility or self.include_exports)
147-
OpType::FuncDefn(fd) => fd.visibility() == &Visibility::Public && self.include_exports.for_hugr(h),
148-
OpType::FuncDecl(fd) => fd.visibility() == &Visibility::Public && self.include_exports.for_hugr(h),
149-
// No Const, unless reached along static edges
150-
_ => false
151-
};
152-
if must_keep || self.must_preserve(h, &mut must_preserve, ch) {
123+
| OpType::Output(_) // Do not include FuncDecl / FuncDefn / Const unless reachable by static edges
124+
// (from Call/LoadConst/LoadFunction):
125+
)
126+
{
153127
q.push_back(ch);
154128
}
155129
}
@@ -167,6 +141,7 @@ impl<H: HugrView> DeadCodeElimPass<H> {
167141
if let Some(res) = cache.get(&n) {
168142
return *res;
169143
}
144+
#[allow(deprecated)]
170145
let res = match self.preserve_callback.as_ref()(h, n) {
171146
PreserveNode::MustKeep => true,
172147
PreserveNode::CanRemoveIgnoringChildren => false,
@@ -199,57 +174,18 @@ impl<H: HugrMut> ComposablePass<H> for DeadCodeElimPass<H> {
199174
mod test {
200175
use std::sync::Arc;
201176

202-
use hugr_core::builder::{
203-
CFGBuilder, Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder,
204-
};
177+
use hugr_core::Hugr;
178+
use hugr_core::builder::{CFGBuilder, Container, Dataflow, DataflowSubContainer, HugrBuilder};
205179
use hugr_core::extension::prelude::{ConstUsize, usize_t};
206-
use hugr_core::ops::{OpTag, OpTrait, Value, handle::NodeHandle};
207-
use hugr_core::{Hugr, HugrView, type_row, types::Signature};
180+
use hugr_core::ops::{OpTag, OpTrait, handle::NodeHandle};
181+
use hugr_core::types::Signature;
182+
use hugr_core::{HugrView, ops::Value, type_row};
208183
use itertools::Itertools;
209-
use rstest::rstest;
210184

211-
use crate::{ComposablePass, IncludeExports};
185+
use crate::ComposablePass;
212186

213187
use super::{DeadCodeElimPass, PreserveNode};
214188

215-
#[rstest]
216-
#[case(false, IncludeExports::Never, true)]
217-
#[case(false, IncludeExports::OnlyIfEntrypointIsModuleRoot, false)]
218-
#[case(false, IncludeExports::Always, false)]
219-
#[case(true, IncludeExports::Never, true)]
220-
#[case(true, IncludeExports::OnlyIfEntrypointIsModuleRoot, false)]
221-
#[case(true, IncludeExports::Always, false)]
222-
fn test_module_exports(
223-
#[case] include_dfn: bool,
224-
#[case] module_exports: IncludeExports,
225-
#[case] decl_removed: bool,
226-
) {
227-
let mut mb = ModuleBuilder::new();
228-
let dfn = mb
229-
.define_function("foo", Signature::new_endo(usize_t()))
230-
.unwrap();
231-
let ins = dfn.input_wires();
232-
let dfn = dfn.finish_with_outputs(ins).unwrap();
233-
let dcl = mb
234-
.declare("bar", Signature::new_endo(usize_t()).into())
235-
.unwrap();
236-
let mut h = mb.finish_hugr().unwrap();
237-
let mut dce = DeadCodeElimPass::<Hugr>::default().include_module_exports(module_exports);
238-
if include_dfn {
239-
dce = dce.with_entry_points([dfn.node()]);
240-
}
241-
dce.run(&mut h).unwrap();
242-
let defn_retained = include_dfn;
243-
let decl_retained = !decl_removed;
244-
let children = h.children(h.module_root()).collect_vec();
245-
assert_eq!(defn_retained, children.iter().contains(&dfn.node()));
246-
assert_eq!(decl_retained, children.iter().contains(&dcl.node()));
247-
assert_eq!(
248-
children.len(),
249-
(defn_retained as usize) + (decl_retained as usize)
250-
);
251-
}
252-
253189
#[test]
254190
fn test_cfg_callback() {
255191
let mut cb = CFGBuilder::new(Signature::new_endo(type_row![])).unwrap();

0 commit comments

Comments
 (0)