Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 67 additions & 86 deletions hugr-passes/src/normalize_cfgs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,9 @@ pub enum NormalizeCFGResult<N = Node> {
CFGToDFG,
/// The CFG was preserved, but the entry or exit blocks may have changed.
CFGPreserved {
/// If `Some`, the new [DFG] containing what was previously in the entry block
entry_dfg: Option<N>,
/// Nodes that were in the entry block but have been moved to be siblings of the CFG.
/// (Either empty, or all the nodes in the entry block except [Input]/[Output].)
entry_nodes_moved: Vec<N>,
/// If `Some`, the new [DFG] of what was previously in the last block before the exit
exit_dfg: Option<N>,
/// The number of basic blocks merged together.
Expand Down Expand Up @@ -186,7 +187,7 @@ pub fn normalize_cfg<H: HugrMut>(
// However, we only do this if the Entry block has just one successor (i.e. we can remove
// the entry block altogether) - an extension would be to do this in other cases, preserving
// the Entry block as an empty branch.
let mut entry_dfg = None;
let mut entry_nodes_moved = Vec::new();
if let Some(succ) = h
.output_neighbours(entry)
.exactly_one()
Expand Down Expand Up @@ -217,51 +218,59 @@ pub fn normalize_cfg<H: HugrMut>(
unpack_before_output(h, h.get_io(cfg_node).unwrap()[1], result_tys);
return Ok(NormalizeCFGResult::CFGToDFG);
}
// 1b. Move entry block outside/before the CFG into a DFG; its successor becomes the entry block.
// 1b. Move entry block outside/before the CFG; its successor becomes the entry block.
let new_cfg_inputs = entry_blk.successor_input(0).unwrap();
// Look for nonlocal `Dom` edges from the entry block. (Ignore `Ext` edges.)
let dests = h.children(entry).flat_map(|n| h.output_neighbours(n));
let has_dom_outs = dests.dedup().any(|succ| {
ancestor_block(h, succ).expect("Dom edges within entry, Ext within CFG") != entry
});
if !has_dom_outs {
// Move entry block contents into DFG.
let dfg = h.add_node_with_parent(
cfg_parent,
DFG {
signature: Signature::new(entry_blk.inputs.clone(), new_cfg_inputs.clone()),
},
);
let [_, entry_output] = h.get_io(entry).unwrap();
while let Some(n) = h.first_child(entry) {
h.set_parent(n, dfg);
}
h.move_before_sibling(succ, entry);
h.remove_node(entry);
let nonlocal_srcs = h
.children(entry)
.filter(|n| {
h.output_neighbours(*n).any(|succ| {
ancestor_block(h, succ).expect("Dom edges within entry, Ext within CFG")
!= entry
})
})
.collect::<Vec<_>>();
// Move entry block contents into DFG.
let dfg = h.add_node_with_parent(
cfg_parent,
DFG {
signature: Signature::new(entry_blk.inputs.clone(), new_cfg_inputs.clone()),
},
);
let [_, entry_output] = h.get_io(entry).unwrap();
while let Some(n) = h.first_child(entry) {
h.set_parent(n, dfg);
}
h.move_before_sibling(succ, entry);
h.remove_node(entry);

unpack_before_output(h, entry_output, new_cfg_inputs.clone());
unpack_before_output(h, entry_output, new_cfg_inputs.clone());

// Inputs to CFG go directly to DFG
for inp in h.node_inputs(cfg_node).collect::<Vec<_>>() {
for src in h.linked_outputs(cfg_node, inp).collect::<Vec<_>>() {
h.connect(src.0, src.1, dfg, inp.index());
}
h.disconnect(cfg_node, inp);
// Inputs to CFG go directly to DFG
for inp in h.node_inputs(cfg_node).collect::<Vec<_>>() {
for src in h.linked_outputs(cfg_node, inp).collect::<Vec<_>>() {
h.connect(src.0, src.1, dfg, inp.index());
}
h.disconnect(cfg_node, inp);
}

// Update input ports
let cfg_ty = cfg_ty_mut(h, cfg_node);
let inputs_to_add =
new_cfg_inputs.len() as isize - cfg_ty.signature.input.len() as isize;
cfg_ty.signature.input = new_cfg_inputs;
h.add_ports(cfg_node, Direction::Incoming, inputs_to_add);
// Update input ports
let cfg_ty = cfg_ty_mut(h, cfg_node);
let inputs_to_add = new_cfg_inputs.len() as isize - cfg_ty.signature.input.len() as isize;
cfg_ty.signature.input = new_cfg_inputs;
h.add_ports(cfg_node, Direction::Incoming, inputs_to_add);

// Wire outputs of DFG directly to CFG
for src in h.node_outputs(dfg).collect::<Vec<_>>() {
h.connect(dfg, src, cfg_node, src.index());
}
entry_dfg = Some(dfg);
// Wire outputs of DFG directly to CFG
for src in h.node_outputs(dfg).collect::<Vec<_>>() {
h.connect(dfg, src, cfg_node, src.index());
}
// Inline DFG to ensure that any nonlocal (`Dom`) edges from it, become valid `Ext` edges
for n in nonlocal_srcs {
// With required Order edge. (Do this before inlining, in case n is Input.)
h.add_other_edge(n, cfg_node);
}
entry_nodes_moved.extend(h.children(dfg).skip(2)); // Skip Input/Output nodes
h.apply_patch(InlineDFG(dfg.into())).unwrap();
}
// 2. If the exit node has a single predecessor and that predecessor has no other successors...
let mut exit_dfg = None;
Expand Down Expand Up @@ -323,7 +332,7 @@ pub fn normalize_cfg<H: HugrMut>(
exit_dfg = Some(dfg);
}
Ok(NormalizeCFGResult::CFGPreserved {
entry_dfg,
entry_nodes_moved,
exit_dfg,
num_merged,
})
Expand Down Expand Up @@ -776,13 +785,14 @@ mod test {
let res = normalize_cfg(&mut h).unwrap();
h.validate().unwrap();
let NormalizeCFGResult::CFGPreserved {
entry_dfg: Some(dfg),
entry_nodes_moved,
exit_dfg: None,
num_merged: 0,
} = res
else {
panic!("Unexpected result");
};
assert_eq!(entry_nodes_moved.len(), 4); // Noop, Const, LoadConstant, UnpackTuple
assert_eq!(
h.children(h.entrypoint())
.map(|n| h.get_optype(n).tag())
Expand All @@ -793,20 +803,8 @@ mod test {
let func_children = child_tags_ext_ids(&h, func);
assert_eq!(
func_children.into_iter().sorted().collect_vec(),
["Cfg", "Dfg", "Input", "Output",]
);
assert_eq!(
h.children(func)
.filter(|n| h.get_optype(*n).is_dfg())
.collect_vec(),
[dfg]
);
assert_eq!(
child_tags_ext_ids(&h, dfg)
.into_iter()
.sorted()
.collect_vec(),
[
"Cfg",
"Const",
"Input",
"LoadConst",
Expand Down Expand Up @@ -849,13 +847,14 @@ mod test {
let res = normalize_cfg(&mut h).unwrap();
h.validate().unwrap();
let NormalizeCFGResult::CFGPreserved {
entry_dfg: None,
entry_nodes_moved,
exit_dfg: Some(dfg),
num_merged: 0,
} = res
else {
panic!("Unexpected result");
};
assert_eq!(entry_nodes_moved, []);
assert_eq!(
h.children(h.entrypoint())
.map(|n| h.get_optype(n).tag())
Expand Down Expand Up @@ -963,65 +962,47 @@ mod test {
assert_eq!(h.get_parent(tail_pred.node()), Some(tail_b.node()));

let mut res = NormalizeCFGPass::default().run(&mut h).unwrap();

h.validate().unwrap();
assert_eq!(
res.remove(&inner.node()),
Some(NormalizeCFGResult::CFGToDFG)
);
let Some(NormalizeCFGResult::CFGPreserved {
entry_dfg,
entry_nodes_moved,
exit_dfg: Some(tail_dfg),
num_merged: 0,
}) = res.remove(&h.entrypoint())
else {
panic!("Unexpected result")
};

assert!(res.is_empty());

assert_eq!(entry_nodes_moved.len(), 3);
// Now contains only one CFG with one BB (self-loop)
assert_eq!(
h.nodes()
.filter(|n| h.get_optype(*n).is_cfg())
.collect_vec(),
vec![h.entrypoint()]
.exactly_one()
.ok(),
Some(h.entrypoint())
);
let [loop_, exit] = if nonlocal {
let [entry, exit, loop_] = h.children(h.entrypoint()).collect_array().unwrap();
assert_eq!(h.get_parent(entry_pred.node()), Some(entry));
[loop_, exit]
} else {
h.children(h.entrypoint()).collect_array().unwrap()
};

assert_eq!(h.output_neighbours(loop_).collect_vec(), [loop_, exit]);

let [entry, exit] = h.children(h.entrypoint()).collect_array().unwrap();
assert_eq!(h.output_neighbours(entry).collect_vec(), [entry, exit]);
// Inner CFG is now a DFG (and still sibling of entry_pred)...
assert_eq!(h.get_parent(inner_pred), Some(inner.node()));
assert_eq!(h.get_optype(inner.node()).tag(), OpTag::Dfg);
assert_eq!(h.get_parent(inner.node()), h.get_parent(entry_pred.node()));

// Predicates lifted appropriately...
let func = h.get_parent(h.entrypoint()).unwrap();
assert_eq!(h.get_parent(entry_pred.node()), Some(func));

assert_eq!(h.get_parent(tail_pred.node()), Some(tail_dfg));
assert_eq!(h.get_optype(tail_dfg).tag(), OpTag::Dfg);
assert_eq!(h.get_parent(tail_dfg), Some(func));
let lifted_preds = if nonlocal {
assert!(entry_dfg.is_none());
// entry_pred not lifted, still connected to output
let [output] = h
.output_neighbours(entry_pred.node())
.collect_array()
.unwrap();
assert_eq!(h.get_optype(output).tag(), OpTag::Output);
vec![inner_pred.node(), tail_pred.node()]
} else {
assert_eq!(h.get_parent(entry_dfg.unwrap()), Some(func));
assert_eq!(h.get_parent(entry_pred.node()), entry_dfg);
vec![inner_pred.node(), entry_pred.node(), tail_pred.node()]
};

// ...and followed by UnpackTuple's
for n in lifted_preds {
for n in [inner_pred, entry_pred.node(), tail_pred.node()] {
let [unpack] = h.output_neighbours(n).collect_array().unwrap();
assert!(
h.get_optype(unpack)
Expand Down
Loading