Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions hugr-passes/src/call_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,17 @@ pub enum CallGraphNode<N = Node> {
}

/// Details the [`Call`]s and [`LoadFunction`]s in a Hugr.
/// Each node in the `CallGraph` corresponds to a [`FuncDefn`] in the Hugr; each edge corresponds
/// to a [`Call`]/[`LoadFunction`] of the edge's target, contained in the edge's source.
///
/// For Hugrs whose root is neither a [Module](OpType::Module) nor a [`FuncDefn`], the call graph
/// will have an additional [`CallGraphNode::NonFuncRoot`] corresponding to the Hugr's root, with no incoming edges.
/// Each node in the `CallGraph` corresponds to a [`FuncDefn`] or [`FuncDecl`] in the Hugr;
/// each edge corresponds to a [`Call`]/[`LoadFunction`] of the edge's target, contained in
/// the edge's source.
///
/// For Hugrs whose entrypoint is neither a [Module](OpType::Module) nor a [`FuncDefn`], the
/// call graph will have an additional [`CallGraphNode::NonFuncRoot`] corresponding to the Hugr's
/// entrypoint, with no incoming edges.
///
/// [`Call`]: OpType::Call
/// [`FuncDecl`]: OpType::FuncDecl
/// [`FuncDefn`]: OpType::FuncDefn
/// [`LoadFunction`]: OpType::LoadFunction
pub struct CallGraph<N = Node> {
Expand All @@ -41,14 +45,13 @@ pub struct CallGraph<N = Node> {
}

impl<N: HugrNode> CallGraph<N> {
/// Makes a new `CallGraph` for a specified (subview) of a Hugr.
/// Calls to functions outside the view will be dropped.
/// Makes a new `CallGraph` for a Hugr.
pub fn new(hugr: &impl HugrView<Node = N>) -> Self {
let mut g = Graph::default();
let non_func_root =
(!hugr.get_optype(hugr.entrypoint()).is_module()).then_some(hugr.entrypoint());
let node_to_g = hugr
.entry_descendants()
.children(hugr.module_root())
.filter_map(|n| {
let weight = match hugr.get_optype(n) {
OpType::FuncDecl(_) => CallGraphNode::FuncDecl(n),
Expand Down Expand Up @@ -94,7 +97,7 @@ impl<N: HugrNode> CallGraph<N> {

/// Convert a Hugr [Node] into a petgraph node index.
/// Result will be `None` if `n` is not a [`FuncDefn`](OpType::FuncDefn),
/// [`FuncDecl`](OpType::FuncDecl) or the hugr root.
/// [`FuncDecl`](OpType::FuncDecl) or the [HugrView::entrypoint].
pub fn node_index(&self, n: N) -> Option<petgraph::graph::NodeIndex<u32>> {
self.node_to_g.get(&n).copied()
}
Expand Down
99 changes: 53 additions & 46 deletions hugr-passes/src/dead_funcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use super::call_graph::{CallGraph, CallGraphNode};
#[non_exhaustive]
/// Errors produced by [`RemoveDeadFuncsPass`].
pub enum RemoveDeadFuncsError<N = Node> {
/// The specified entry point is not a `FuncDefn` node or is not a child of the root.
/// The specified entry point is not a `FuncDefn` node
#[error(
"Entrypoint for RemoveDeadFuncsPass {node} was not a function definition in the root module"
)]
Expand All @@ -35,30 +35,17 @@ fn reachable_funcs<'a, H: HugrView>(
cg: &'a CallGraph<H::Node>,
h: &'a H,
entry_points: impl IntoIterator<Item = H::Node>,
) -> Result<impl Iterator<Item = H::Node> + 'a, RemoveDeadFuncsError<H::Node>> {
) -> impl Iterator<Item = H::Node> + 'a {
let g = cg.graph();
let mut entry_points = entry_points.into_iter();
let searcher = if h.get_optype(h.entrypoint()).is_module() {
let mut d = Dfs::new(g, 0.into());
d.stack.clear();
for n in entry_points {
if !h.get_optype(n).is_func_defn() || h.get_parent(n) != Some(h.entrypoint()) {
return Err(RemoveDeadFuncsError::InvalidEntryPoint { node: n });
}
d.stack.push(cg.node_index(n).unwrap());
}
d
} else {
if let Some(n) = entry_points.next() {
// Can't be a child of the module root as there isn't a module root!
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment has been bad since #2147 ...

return Err(RemoveDeadFuncsError::InvalidEntryPoint { node: n });
}
Dfs::new(g, cg.node_index(h.entrypoint()).unwrap())
};
Ok(searcher.iter(g).map(|i| match g.node_weight(i).unwrap() {
let mut d = Dfs::new(g, 0.into());
d.stack.clear(); // Remove the fake 0
for n in entry_points {
d.stack.push(cg.node_index(n).unwrap());
}
d.iter(g).map(|i| match g.node_weight(i).unwrap() {
CallGraphNode::FuncDefn(n) | CallGraphNode::FuncDecl(n) => *n,
CallGraphNode::NonFuncRoot => h.entrypoint(),
}))
})
}

#[derive(Debug, Clone, Default)]
Expand Down Expand Up @@ -86,14 +73,31 @@ impl<H: HugrMut<Node = Node>> ComposablePass<H> for RemoveDeadFuncsPass {
type Error = RemoveDeadFuncsError;
type Result = ();
fn run(&self, hugr: &mut H) -> Result<(), RemoveDeadFuncsError> {
let reachable = reachable_funcs(
&CallGraph::new(hugr),
hugr,
self.entry_points.iter().copied(),
)?
.collect::<HashSet<_>>();
let mut entry_points = Vec::new();
for &n in self.entry_points.iter() {
if !hugr.get_optype(n).is_func_defn() {
return Err(RemoveDeadFuncsError::InvalidEntryPoint { node: n });
}
debug_assert_eq!(hugr.get_parent(n), Some(hugr.module_root()));
entry_points.push(n);
}
if hugr.entrypoint() != hugr.module_root() {
entry_points.push(hugr.entrypoint())
}

let mut reachable =
reachable_funcs(&CallGraph::new(hugr), hugr, entry_points).collect::<HashSet<_>>();
// Also prevent removing the entrypoint itself
let mut n = Some(hugr.entrypoint());
while let Some(n2) = n {
n = hugr.get_parent(n2);
if n == Some(hugr.module_root()) {
reachable.insert(n2);
}
}

let unreachable = hugr
.entry_descendants()
.children(hugr.module_root())
.filter(|n| {
OpTag::Function.is_superset(hugr.get_optype(*n).tag()) && !reachable.contains(n)
})
Expand All @@ -108,17 +112,13 @@ impl<H: HugrMut<Node = Node>> ComposablePass<H> for RemoveDeadFuncsPass {
/// Deletes from the Hugr any functions that are not used by either [`Call`] or
/// [`LoadFunction`] nodes in reachable parts.
///
/// For [`Module`]-rooted Hugrs, `entry_points` may provide a list of entry points,
/// which must be children of the root. Note that if `entry_points` is empty, this will
/// result in all functions in the module being removed.
///
/// For non-[`Module`]-rooted Hugrs, `entry_points` must be empty; the root node is used.
/// `entry_points` may provide a list of entry points, which must be [`FuncDefn`]s (children of the root).
/// The [HugrView::entrypoint] will also be used unless it is the [HugrView::module_root].
/// Note that for a [`Module`]-rooted Hugr with no `entry_points` provided, this will remove
/// all functions from the module.
///
/// # Errors
/// * If there are any `entry_points` but the root of the hugr is not a [`Module`]
/// * If any node in `entry_points` is
/// * not a [`FuncDefn`], or
/// * not a child of the root
/// * If any node in `entry_points` is not a [`FuncDefn`]
///
/// [`Call`]: hugr_core::ops::OpType::Call
/// [`FuncDefn`]: hugr_core::ops::OpType::FuncDefn
Expand All @@ -138,22 +138,26 @@ pub fn remove_dead_funcs(
mod test {
use std::collections::HashMap;

use hugr_core::ops::handle::NodeHandle;
use itertools::Itertools;
use rstest::rstest;

use hugr_core::builder::{Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder};
use hugr_core::hugr::hugrmut::HugrMut;
use hugr_core::{HugrView, extension::prelude::usize_t, types::Signature};

use super::remove_dead_funcs;

#[rstest]
#[case([], vec![])] // No entry_points removes everything!
#[case(["main"], vec!["from_main", "main"])]
#[case(["from_main"], vec!["from_main"])]
#[case(["other1"], vec!["other1", "other2"])]
#[case(["other2"], vec!["other2"])]
#[case(["other1", "other2"], vec!["other1", "other2"])]
#[case(false, [], vec![])] // No entry_points removes everything!
#[case(true, [], vec!["from_main", "main"])]
#[case(false, ["main"], vec!["from_main", "main"])]
#[case(false, ["from_main"], vec!["from_main"])]
#[case(false, ["other1"], vec!["other1", "other2"])]
#[case(true, ["other2"], vec!["from_main", "main", "other2"])]
#[case(false, ["other1", "other2"], vec!["other1", "other2"])]
fn remove_dead_funcs_entry_points(
#[case] use_hugr_entrypoint: bool,
#[case] entry_points: impl IntoIterator<Item = &'static str>,
#[case] retained_funcs: Vec<&'static str>,
) -> Result<(), Box<dyn std::error::Error>> {
Expand All @@ -171,12 +175,15 @@ mod test {
let fm = fm.finish_with_outputs(f_inp)?;
let mut m = hb.define_function("main", Signature::new_endo(usize_t()))?;
let mc = m.call(fm.handle(), &[], m.input_wires())?;
m.finish_with_outputs(mc.outputs())?;
let m = m.finish_with_outputs(mc.outputs())?;

let mut hugr = hb.finish_hugr()?;
if use_hugr_entrypoint {
hugr.set_entrypoint(m.node());
}

let avail_funcs = hugr
.entry_descendants()
.children(hugr.module_root())
.filter_map(|n| {
hugr.get_optype(n)
.as_func_defn()
Expand Down
Loading