Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make HugrMut a trait #132

Merged
merged 11 commits into from
Jun 9, 2023
10 changes: 5 additions & 5 deletions src/algorithm/nest_cfgs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ impl<T: Copy + Clone + PartialEq + Eq + Hash> EdgeClassifier<T> {
pub(crate) mod test {
use super::*;
use crate::builder::{
BuildError, CFGBuilder, Container, Dataflow, DataflowSubContainer, HugrBuilder, HugrMutRef,
BuildError, CFGBuilder, Container, Dataflow, DataflowSubContainer, HugrBuilder,
ModuleBuilder, SubContainer,
};
use crate::ops::{
Expand Down Expand Up @@ -611,7 +611,7 @@ pub(crate) mod test {
dataflow_builder.finish_with_outputs([u].into_iter().chain(w))
}

fn build_if_then_else_merge<T: HugrMutRef>(
fn build_if_then_else_merge<T: AsMut<Hugr> + AsRef<Hugr>>(
cfg: &mut CFGBuilder<T>,
const_pred: &ConstID,
unit_const: &ConstID,
Expand All @@ -624,7 +624,7 @@ pub(crate) mod test {
Ok((split, merge))
}

fn build_then_else_merge_from_if<T: HugrMutRef>(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do these actually need both Mut and Ref or is the Mut requirement sufficient?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They need it because the impl block is also parameterized like that...as I say I think we could probably go through and replace quite a few AsRef + AsMuts with one or the other, but I've not done that thorough an analysis here; consistency is helpful in itself and we can reconsider later if we really want.

fn build_then_else_merge_from_if<T: AsMut<Hugr> + AsRef<Hugr>>(
cfg: &mut CFGBuilder<T>,
unit_const: &ConstID,
split: BasicBlockID,
Expand All @@ -649,7 +649,7 @@ pub(crate) mod test {
}

// Returns loop tail - caller must link header to tail, and provide 0th successor of tail
fn build_loop_from_header<T: HugrMutRef>(
fn build_loop_from_header<T: AsMut<Hugr> + AsRef<Hugr>>(
cfg: &mut CFGBuilder<T>,
const_pred: &ConstID,
header: BasicBlockID,
Expand All @@ -663,7 +663,7 @@ pub(crate) mod test {
}

// Result is header and tail. Caller must provide 0th successor of header (linking to tail), and 0th successor of tail.
fn build_loop<T: HugrMutRef>(
fn build_loop<T: AsMut<Hugr> + AsRef<Hugr>>(
cfg: &mut CFGBuilder<T>,
const_pred: &ConstID,
unit_const: &ConstID,
Expand Down
22 changes: 2 additions & 20 deletions src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//!
use thiserror::Error;

use crate::hugr::{HugrError, HugrMut, Node, ValidationError, Wire};
use crate::hugr::{HugrError, Node, ValidationError, Wire};
use crate::ops::handle::{BasicBlockID, CfgID, ConditionalID, DfgID, FuncID, TailLoopID};

use crate::types::LinearType;
Expand Down Expand Up @@ -70,26 +70,8 @@ pub enum BuildError {
CircuitError(#[from] circuit_builder::CircuitBuildError),
}

impl AsMut<HugrMut> for HugrMut {
fn as_mut(&mut self) -> &mut HugrMut {
self
}
}
impl AsRef<HugrMut> for HugrMut {
fn as_ref(&self) -> &HugrMut {
self
}
}

/// Trait allowing treating type as (im)mutable reference to [`HugrMut`]
pub trait HugrMutRef: AsMut<HugrMut> + AsRef<HugrMut> {}
impl HugrMutRef for HugrMut {}
impl HugrMutRef for &mut HugrMut {}

#[cfg(test)]
mod test {

use crate::hugr::HugrMut;
use crate::types::{ClassicType, LinearType, Signature, SimpleType};
use crate::Hugr;

Expand All @@ -112,7 +94,7 @@ mod test {

pub(super) fn build_main(
signature: Signature,
f: impl FnOnce(FunctionBuilder<&mut HugrMut>) -> Result<BuildHandle<FuncID<true>>, BuildError>,
f: impl FnOnce(FunctionBuilder<&mut Hugr>) -> Result<BuildHandle<FuncID<true>>, BuildError>,
) -> Result<Hugr, BuildError> {
let mut module_builder = ModuleBuilder::new();
let f_builder = module_builder.declare_and_def("main", signature)?;
Expand Down
59 changes: 30 additions & 29 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,22 @@ use crate::hugr::HugrMut;
pub trait Container {
/// The container node.
fn container_node(&self) -> Node;
/// The underlying [`HugrMut`] being used to build the HUGR.
fn base(&mut self) -> &mut HugrMut;
/// Immutable reference to HUGR being built.
/// The underlying [`Hugr`] being built
fn hugr_mut(&mut self) -> &mut Hugr;
/// Immutable reference to HUGR being built
fn hugr(&self) -> &Hugr;
/// Add an [`OpType`] as the final child of the container.
fn add_child_op(&mut self, op: impl Into<OpType>) -> Result<Node, BuildError> {
let parent = self.container_node();
Ok(self.base().add_op_with_parent(parent, op)?)
Ok(self.hugr_mut().add_op_with_parent(parent, op)?)
}

/// Adds a non-dataflow edge between two nodes. The kind is given by the operation's [`other_inputs`] or [`other_outputs`]
///
/// [`other_inputs`]: crate::ops::OpTrait::other_inputs
/// [`other_outputs`]: crate::ops::OpTrait::other_outputs
fn add_other_wire(&mut self, src: Node, dst: Node) -> Result<Wire, BuildError> {
let (src_port, _) = self.base().add_other_edge(src, dst)?;
let (src_port, _) = self.hugr_mut().add_other_edge(src, dst)?;
Ok(Wire::new(src, Port::new_outgoing(src_port)))
}

Expand Down Expand Up @@ -156,7 +156,7 @@ pub trait Dataflow: Container {
&mut self,
inputs: impl IntoIterator<Item = (SimpleType, Wire)>,
output_types: TypeRow,
) -> Result<DFGBuilder<&mut HugrMut>, BuildError> {
) -> Result<DFGBuilder<&mut Hugr>, BuildError> {
let (input_types, input_wires): (Vec<SimpleType>, Vec<Wire>) = inputs.into_iter().unzip();
let (dfg_n, _) = add_op_with_wires(
self,
Expand All @@ -166,7 +166,7 @@ pub trait Dataflow: Container {
input_wires,
)?;

DFGBuilder::create_with_io(self.base(), dfg_n, input_types.into(), output_types)
DFGBuilder::create_with_io(self.hugr_mut(), dfg_n, input_types.into(), output_types)
}

/// Return a builder for a [`crate::ops::CFG`] node,
Expand All @@ -183,7 +183,7 @@ pub trait Dataflow: Container {
&mut self,
inputs: impl IntoIterator<Item = (SimpleType, Wire)>,
output_types: TypeRow,
) -> Result<CFGBuilder<&mut HugrMut>, BuildError> {
) -> Result<CFGBuilder<&mut Hugr>, BuildError> {
let (input_types, input_wires): (Vec<SimpleType>, Vec<Wire>) = inputs.into_iter().unzip();

let inputs: TypeRow = input_types.into();
Expand All @@ -196,7 +196,7 @@ pub trait Dataflow: Container {
},
input_wires,
)?;
CFGBuilder::create(self.base(), cfg_node, inputs, output_types)
CFGBuilder::create(self.hugr_mut(), cfg_node, inputs, output_types)
}

/// Load a static constant and return the local dataflow wire for that constant.
Expand All @@ -208,7 +208,7 @@ pub trait Dataflow: Container {
let cn = cid.node();
let c_out = self.hugr().num_outputs(cn);

self.base().add_ports(cn, Direction::Outgoing, 1);
self.hugr_mut().add_ports(cn, Direction::Outgoing, 1);

let load_n = self.add_dataflow_op(
ops::LoadConstant {
Expand Down Expand Up @@ -257,7 +257,7 @@ pub trait Dataflow: Container {
just_inputs: impl IntoIterator<Item = (SimpleType, Wire)>,
inputs_outputs: impl IntoIterator<Item = (SimpleType, Wire)>,
just_out_types: TypeRow,
) -> Result<TailLoopBuilder<&mut HugrMut>, BuildError> {
) -> Result<TailLoopBuilder<&mut Hugr>, BuildError> {
let (input_types, mut input_wires): (Vec<SimpleType>, Vec<Wire>) =
just_inputs.into_iter().unzip();
let (rest_types, rest_input_wires): (Vec<SimpleType>, Vec<Wire>) =
Expand All @@ -271,7 +271,7 @@ pub trait Dataflow: Container {
};
let (loop_node, _) = add_op_with_wires(self, tail_loop.clone(), input_wires)?;

TailLoopBuilder::create_with_io(self.base(), loop_node, &tail_loop)
TailLoopBuilder::create_with_io(self.hugr_mut(), loop_node, &tail_loop)
}

/// Return a builder for a [`crate::ops::Conditional`] node.
Expand All @@ -291,7 +291,7 @@ pub trait Dataflow: Container {
(predicate_inputs, predicate_wire): (impl IntoIterator<Item = TypeRow>, Wire),
other_inputs: impl IntoIterator<Item = (SimpleType, Wire)>,
output_types: TypeRow,
) -> Result<ConditionalBuilder<&mut HugrMut>, BuildError> {
) -> Result<ConditionalBuilder<&mut Hugr>, BuildError> {
let mut input_wires = vec![predicate_wire];
let (input_types, rest_input_wires): (Vec<SimpleType>, Vec<Wire>) =
other_inputs.into_iter().unzip();
Expand All @@ -312,7 +312,7 @@ pub trait Dataflow: Container {
)?;

Ok(ConditionalBuilder {
base: self.base(),
base: self.hugr_mut(),
conditional_node: conditional_id.node(),
n_out_wires,
case_nodes: vec![None; n_cases],
Expand Down Expand Up @@ -450,11 +450,11 @@ pub trait Dataflow: Container {
let const_in_port = signature.output.len();
let op_id = self.add_dataflow_op(ops::Call { signature }, input_wires)?;
let src_port: usize = self
.base()
.hugr_mut()
.add_ports(function.node(), Direction::Outgoing, 1)
.collect_vec()[0];

self.base()
self.hugr_mut()
.connect(function.node(), src_port, op_id.node(), const_in_port)?;
Ok(op_id)
}
Expand All @@ -475,7 +475,7 @@ fn add_op_with_wires<T: Dataflow + ?Sized>(

let op: OpType = op.into();
let sig = op.signature();
let op_node = data_builder.base().add_op_before(out, op)?;
let op_node = data_builder.hugr_mut().add_op_before(out, op)?;

wire_up_inputs(inputs, op_node, data_builder, inp)?;

Expand All @@ -498,7 +498,8 @@ fn wire_up_inputs<T: Dataflow + ?Sized>(
dst_port,
)?;
}
let op = data_builder.base().hugr().get_optype(op_node);
let base = data_builder.hugr_mut();
let op = base.get_optype(op_node);
let some_df_outputs = !op.signature().output.is_empty();
if !any_local_df_inputs && some_df_outputs {
// If op has no inputs add a StateOrder edge from input to place in
Expand All @@ -516,18 +517,17 @@ fn wire_up<T: Dataflow + ?Sized>(
dst: Node,
dst_port: usize,
) -> Result<bool, BuildError> {
let base = data_builder.base();
let base = data_builder.hugr_mut();
let src_offset = Port::new_outgoing(src_port);

let src_parent = base.hugr().get_parent(src);
let dst_parent = base.hugr().get_parent(dst);
let src_parent = base.get_parent(src);
let dst_parent = base.get_parent(dst);
let local_source = src_parent == dst_parent;

// Non-local value sources require a state edge to an ancestor of dst
if !local_source && get_value_kind(base, src, src_offset) == ValueKind::Classic {
let src_parent = src_parent.expect("Node has no parent");
let Some(src_sibling) =
iter::successors(dst_parent, |&p| base.hugr().get_parent(p))
iter::successors(dst_parent, |&p| base.get_parent(p))
.tuple_windows()
.find_map(|(ancestor, ancestor_parent)| {
(ancestor_parent == src_parent).then_some(ancestor)
Expand All @@ -548,18 +548,19 @@ fn wire_up<T: Dataflow + ?Sized>(
}

// Don't copy linear edges.
if base.hugr().linked_ports(src, src_offset).next().is_some() {
if base.linked_ports(src, src_offset).next().is_some() {
if let ValueKind::Linear(typ) = get_value_kind(base, src, src_offset) {
return Err(BuildError::NoCopyLinear(typ));
}
}

data_builder.base().connect(src, src_port, dst, dst_port)?;
data_builder
.hugr_mut()
.connect(src, src_port, dst, dst_port)?;
Ok(local_source
&& matches!(
data_builder
.base()
.hugr()
.hugr_mut()
.get_optype(dst)
.port_kind(Port::new_incoming(dst_port))
.unwrap(),
Expand All @@ -578,8 +579,8 @@ enum ValueKind {
/// Check the kind of a port is a classical Value and return it
/// Return None if Const kind
/// Panics if port not valid for Op or port is not Const/Value
fn get_value_kind(base: &HugrMut, src: Node, src_offset: Port) -> ValueKind {
let wire_kind = base.hugr().get_optype(src).port_kind(src_offset).unwrap();
fn get_value_kind(base: &Hugr, src: Node, src_offset: Port) -> ValueKind {
let wire_kind = base.get_optype(src).port_kind(src_offset).unwrap();
match wire_kind {
EdgeKind::Const(_) => ValueKind::Const,
EdgeKind::Value(simple_type) => match simple_type {
Expand Down
Loading