Skip to content
This repository was archived by the owner on Mar 5, 2025. It is now read-only.

refactor: clean up fat.rs #38

Merged
merged 3 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions src/emit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,8 +281,8 @@ impl<'c, H> Hash for Emission<'c, H> {
impl<'c, H> Clone for Emission<'c, H> {
fn clone(&self) -> Self {
match self {
Self::FuncDefn(arg0) => Self::FuncDefn(arg0.clone()),
Self::FuncDecl(arg0) => Self::FuncDecl(arg0.clone()),
Self::FuncDefn(arg0) => Self::FuncDefn(*arg0),
Self::FuncDecl(arg0) => Self::FuncDecl(*arg0),
}
}
}
Expand Down Expand Up @@ -357,10 +357,8 @@ impl<'c, H: HugrView> EmitHugr<'c, H> {
/// are not emitted directly, but instead by [hugr::ops::LoadConstant] emission. So
/// [FuncDefn] and [FuncDecl] are the only interesting children.
pub fn emit_module(mut self, node: FatNode<'c, hugr::ops::Module, H>) -> Result<Self> {
println!("emit module");
for c in node.children() {
println!("emit child: {}", &c);
match c.get() {
match c.as_ref() {
OpType::FuncDefn(ref fd) => {
self = self.emit_global(c.into_ot(fd))?;
}
Expand All @@ -384,7 +382,7 @@ impl<'c, H: HugrView> EmitHugr<'c, H> {
mut self,
node: FatNode<'c, FuncDefn, H>,
) -> Result<(Self, EmissionSet<'c, H>)> {
let func = self.module_context.get_func_defn(node.clone())?;
let func = self.module_context.get_func_defn(node)?;
let mut func_ctx = EmitFuncContext::new(self.module_context, func)?;
let ret_rmb = func_ctx.new_row_mail_box(node.signature.body().output.iter(), "ret")?;
ops::emit_dataflow_parent(
Expand Down
6 changes: 3 additions & 3 deletions src/emit/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@ pub struct EmitOpArgs<'c, OT, H> {
impl<'c, OT, H> EmitOpArgs<'c, OT, H> {
/// Get the internal [FatNode]
pub fn node(&self) -> FatNode<'c, OT, H> {
self.node.clone()
self.node
}
}

impl<'c, H: HugrView> EmitOpArgs<'c, OpType, H> {
/// Attempt to specialise the internal [FatNode].
pub fn try_into_ot<OT: 'c>(self) -> Result<EmitOpArgs<'c, OT, H>, Self>
pub fn try_into_ot<OT>(self) -> Result<EmitOpArgs<'c, OT, H>, Self>
where
&'c OpType: TryInto<&'c OT>,
for<'a> &'a OpType: TryInto<&'a OT>,
{
let EmitOpArgs {
node,
Expand Down
2 changes: 1 addition & 1 deletion src/emit/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ impl<'c, H: HugrView> EmitFuncContext<'c, H> {
pub fn node_outs_rmb<OT: 'c>(&mut self, node: FatNode<'c, OT, H>) -> Result<RowMailBox<'c>> {
let r = node
.out_value_types()
.map(|(port, hugr_type)| self.map_wire(node.clone(), port, &hugr_type))
.map(|(port, hugr_type)| self.map_wire(node, port, &hugr_type))
.collect::<Result<RowMailBox>>()?;
debug_assert!(zip_eq(node.out_value_types(), r.get_types())
.all(|((_, t), lt)| self.llvm_type(&t).unwrap() == lt));
Expand Down
32 changes: 14 additions & 18 deletions src/emit/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,9 @@ struct DataflowParentEmitter<'c, 'd, OT, H: HugrView> {
outputs: Option<RowPromise<'c>>,
}

impl<'c, 'd, OT: OpTrait + 'c, H: HugrView> DataflowParentEmitter<'c, 'd, OT, H>
impl<'c, 'd, OT: OpTrait, H: HugrView> DataflowParentEmitter<'c, 'd, OT, H>
where
&'c OpType: TryInto<&'c OT>,
// &'c OpType: TryInto<&'c OT>,
// <&'c OpType as TryInto<&'c OT>>::Error: std::fmt::Debug,
for<'a> &'a OpType: TryInto<&'a OT>,
{
pub fn new(context: &'d mut EmitFuncContext<'c, H>, args: EmitOpArgs<'c, OT, H>) -> Self {
Self {
Expand Down Expand Up @@ -111,10 +109,9 @@ where
}

pub fn emit_children(mut self) -> Result<()> {
use hugr::hugr::views::HierarchyView;
use petgraph::visit::Topo;
let node = self.node.clone();
if !OpTag::DataflowParent.is_superset(OpTrait::tag(node.get())) {
let node = self.node;
if !OpTag::DataflowParent.is_superset(node.tag()) {
Err(anyhow!("Not a dataflow parent"))?
};

Expand All @@ -124,15 +121,15 @@ where
debug_assert!(i.out_value_types().count() == self.inputs.as_ref().unwrap().len());
debug_assert!(o.in_value_types().count() == self.outputs.as_ref().unwrap().len());

let region: SiblingGraph = SiblingGraph::try_new(node.hugr(), node.node()).unwrap();
let region: SiblingGraph = node.try_new_hierarchy_view().unwrap();
Topo::new(&region.as_petgraph())
.iter(&region.as_petgraph())
.filter(|x| (*x != node.node()))
.map(|x| node.hugr().fat_optype(x))
.try_for_each(|node| {
let inputs_rmb = self.context.node_ins_rmb(node.clone())?;
let inputs_rmb = self.context.node_ins_rmb(node)?;
let inputs = inputs_rmb.read(self.builder(), [])?;
let outputs = self.context.node_outs_rmb(node.clone())?.promise();
let outputs = self.context.node_outs_rmb(node)?.promise();
self.emit(EmitOpArgs {
node,
inputs,
Expand All @@ -142,17 +139,16 @@ where
}
}

impl<'c, OT: OpTrait + 'c, H: HugrView> EmitOp<'c, OpType, H>
for DataflowParentEmitter<'c, '_, OT, H>
impl<'c, OT: OpTrait, H: HugrView> EmitOp<'c, OpType, H> for DataflowParentEmitter<'c, '_, OT, H>
where
&'c OpType: TryInto<&'c OT>,
for<'a> &'a OpType: TryInto<&'a OT>,
{
fn emit(&mut self, args: EmitOpArgs<'c, OpType, H>) -> Result<()> {
if !OpTag::DataflowChild.is_superset(args.node().tag()) {
Err(anyhow!("Not a dataflow child"))?
};

match args.node().get() {
match args.node().as_ref() {
OpType::Input(_) => {
let i = self.take_input()?;
args.outputs.finish(self.builder(), i)
Expand Down Expand Up @@ -283,12 +279,12 @@ pub fn emit_value<'c, H: HugrView>(
}
}

pub(crate) fn emit_dataflow_parent<'c, OT: OpTrait + 'c, H: HugrView>(
pub(crate) fn emit_dataflow_parent<'c, OT: OpTrait, H: HugrView>(
context: &mut EmitFuncContext<'c, H>,
args: EmitOpArgs<'c, OT, H>,
) -> Result<()>
where
&'c OpType: TryInto<&'c OT>,
for<'a> &'a OpType: TryInto<&'a OT>,
{
DataflowParentEmitter::new(context, args).emit_children()
}
Expand Down Expand Up @@ -347,7 +343,7 @@ fn emit_call<'c, H: HugrView>(
.node
.single_linked_output(args.node.called_function_port())
.unwrap();
let func = match func_node.get() {
let func = match func_node.as_ref() {
OpType::FuncDecl(_) => context.get_func_decl(func_node.try_into_ot().unwrap()),
OpType::FuncDefn(_) => context.get_func_defn(func_node.try_into_ot().unwrap()),
_ => Err(anyhow!("emit_call: Not a Decl or Defn")),
Expand All @@ -371,7 +367,7 @@ fn emit_optype<'c, H: HugrView>(
args: EmitOpArgs<'c, OpType, H>,
) -> Result<()> {
let node = args.node();
match node.get() {
match node.as_ref() {
OpType::MakeTuple(ref mt) => emit_make_tuple(context, args.into_ot(mt)),
OpType::UnpackTuple(ref ut) => emit_unpack_tuple(context, args.into_ot(ut)),
OpType::Tag(ref tag) => emit_tag(context, args.into_ot(tag)),
Expand Down
27 changes: 14 additions & 13 deletions src/emit/ops/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,14 @@ impl<'c, 'd, H: HugrView> CfgEmitter<'c, 'd, H> {
fn get_block_data<OT: 'c>(
&self,
node: &FatNode<'c, OT, H>,
) -> Result<&(BasicBlock<'c>, RowMailBox<'c>)>
) -> Result<(BasicBlock<'c>, RowMailBox<'c>)>
where
for<'a> &'a OpType: TryInto<&'a OT>,
{
self.bbs
.get(&node.clone().generalise())
.get(&node.generalise())
.ok_or(anyhow!("Couldn't get block data for: {}", node.index()))
.cloned()
}

/// Consume the emitter by emitting each child of the node.
Expand All @@ -99,23 +100,23 @@ impl<'c, 'd, H: HugrView> CfgEmitter<'c, 'd, H> {
// dataflowblock node, and then branch to the basic block of that entry
// node.
let inputs = self.take_inputs()?;
let (entry_bb, inputs_rmb) = self.get_block_data(&self.entry_node).cloned()?;
let (entry_bb, inputs_rmb) = self.get_block_data(&self.entry_node)?;
let builder = self.context.builder();
inputs_rmb.write(builder, inputs)?;
builder.build_unconditional_branch(entry_bb)?;

// emit each child by delegating to the `impl EmitOp<_>` of self.
for c in self.node.children() {
for child_node in self.node.children() {
let (inputs, outputs) = (vec![], RowMailBox::new_empty().promise());
self.emit(EmitOpArgs {
node: c.clone(),
node: child_node,
inputs,
outputs,
})?;
}

// move the builder to the end of the exit block
let (exit_bb, _) = self.get_block_data(&self.exit_node).cloned()?;
let (exit_bb, _) = self.get_block_data(&self.exit_node)?;
self.context.builder().position_at_end(exit_bb);
Ok(())
}
Expand All @@ -140,14 +141,14 @@ impl<'c, H: HugrView> EmitOp<'c, DataflowBlock, H> for CfgEmitter<'c, '_, H> {
}: EmitOpArgs<'c, DataflowBlock, H>,
) -> Result<()> {
// our entry basic block and our input RowMailBox
let (bb, inputs_rmb) = self.bbs.get(&node.clone().generalise()).unwrap();
let (bb, inputs_rmb) = self.get_block_data(&node)?;
// the basic block and mailbox of each of our successors
let successor_data = node
.output_neighbours()
.map(|succ| self.get_block_data(&succ).cloned())
.map(|succ| self.get_block_data(&succ))
.collect::<Result<Vec<_>>>()?;

self.context.build_positioned(*bb, |context| {
self.context.build_positioned(bb, |context| {
let (_, o) = node.get_io().unwrap();
// get the rowmailbox for our output node
let outputs_rmb = context.node_ins_rmb(o)?;
Expand All @@ -158,7 +159,7 @@ impl<'c, H: HugrView> EmitOp<'c, DataflowBlock, H> for CfgEmitter<'c, '_, H> {
emit_dataflow_parent(
context,
EmitOpArgs {
node: node.clone(),
node,
inputs,
outputs: outputs_rmb.promise(),
},
Expand All @@ -179,7 +180,7 @@ impl<'c, H: HugrView> EmitOp<'c, DataflowBlock, H> for CfgEmitter<'c, '_, H> {
.into_iter()
.enumerate()
.map(|(tag, (target_bb, target_rmb))| {
let bb = context.build_positioned_new_block("", Some(*bb), |context, bb| {
let bb = context.build_positioned_new_block("", Some(bb), |context, bb| {
let builder = context.builder();
let mut vals =
llvm_sum_type.build_untag(builder, tag as u32, outputs[0])?;
Expand All @@ -206,8 +207,8 @@ impl<'c, H: HugrView> EmitOp<'c, DataflowBlock, H> for CfgEmitter<'c, '_, H> {
impl<'c, H: HugrView> EmitOp<'c, ExitBlock, H> for CfgEmitter<'c, '_, H> {
fn emit(&mut self, args: EmitOpArgs<'c, ExitBlock, H>) -> Result<()> {
let outputs = self.take_outputs()?;
let (bb, inputs_rmb) = self.bbs.get(&args.node().generalise()).unwrap();
self.context.build_positioned(*bb, |context| {
let (bb, inputs_rmb) = self.get_block_data(&args.node())?;
self.context.build_positioned(bb, |context| {
let builder = context.builder();
outputs.finish(builder, inputs_rmb.read_vec(builder, [])?)
})
Expand Down
Loading