Skip to content

Commit

Permalink
add copy nodes and state edges for inter-graph
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed May 9, 2023
1 parent 3d149ce commit 2527c16
Showing 1 changed file with 131 additions and 35 deletions.
166 changes: 131 additions & 35 deletions src/builder.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use std::collections::HashSet;
use std::iter;
use std::marker::PhantomData;

use itertools::Itertools;
use portgraph::{Direction, NodeIndex, PortOffset};
use smol_str::SmolStr;
use thiserror::Error;

use crate::hugr::validate::InterGraphEdgeError;
use crate::hugr::{HugrMut, ValidationError};
use crate::ops::controlflow::ControlFlowOp;
use crate::ops::{BasicBlockOp, BranchOp, ConstValue, DataflowOp, LeafOp, ModuleOp};
Expand Down Expand Up @@ -100,6 +102,10 @@ pub trait Dataflow: Container {
(self.io()[0], self.num_inputs()).into()
}

fn output(&self) -> OpID {
(self.io()[1], 0).into()
}

fn input_wires(&self) -> Vec<Wire> {
self.input().outputs()
}
Expand Down Expand Up @@ -400,56 +406,77 @@ fn wire_up_inputs<T: Dataflow + ?Sized>(
data_builder: &mut T,
inp: NodeIndex,
) -> Result<(), BuildError> {
let no_inputs = inputs.is_empty();
let mut any_local_inputs = false;
for (dst_port, Wire(src, src_port)) in inputs.into_iter().enumerate() {
wire_up(data_builder, src_port, src, op_node, dst_port)?;
any_local_inputs = wire_up(data_builder, src, src_port, op_node, dst_port)?;
}

if no_inputs {
if !any_local_inputs {
// If op has no inputs add a StateOrder edge from input to place in
// causal cone of Input node
data_builder.add_other_wire(inp, op_node)?;
};
Ok(())
}

/// Add edge from src to dst and report back if they do share a parent
fn wire_up<T: Dataflow + ?Sized>(
data_builder: &mut T,
mut src_port: usize,
mut src: NodeIndex,
op_node: NodeIndex,
mut src_port: usize,
dst: NodeIndex,
dst_port: usize,
) -> Result<(), BuildError> {
) -> Result<bool, BuildError> {
let base = data_builder.base();
let src_offset = PortOffset::new_outgoing(src_port);

if let Some((connected, connected_offset)) = base.hugr().linked_port(src, src_offset) {
let connected_op: Result<&LeafOp, ()> = base.hugr().get_optype(connected).try_into();
if let Ok(LeafOp::Copy { n_copies, typ }) = connected_op {
let copy_node = connected;
// If already connected to a copy node, add wire to the copy
let n_copies = *n_copies;
base.replace_op(
copy_node,
LeafOp::Copy {
n_copies: n_copies + 1,
typ: typ.clone(),
},
);
src_port = base
.add_ports(copy_node, Direction::Outgoing, 1)
.next()
.unwrap();
let src_parent = base.hugr().get_parent(src);
let dst_parent = base.hugr().get_parent(dst);
let local_source = src_parent == dst_parent;
if !local_source {
if let Some(copy_port) = if_copy_add_port(base, src) {
src_port = copy_port;
} else if let Some(typ) = check_classical_value(base, src, src_offset)? {
let src_parent = base.hugr().get_parent(src).expect("Node has no parent");

let final_child = base
.hugr()
.children(src_parent)
.next_back()
.expect("Parent must have at least one child.");
let copy_node = base.add_op_before(final_child, LeafOp::Copy { n_copies: 1, typ })?;

base.connect(src, src_port, copy_node, 0)?;

// Copy node has to have state edge to an ancestor of dst
let Some(src_sibling) = iter::successors(dst_parent, |&p| base.hugr().get_parent(p))
.tuple_windows()
.find_map(|(ancestor, ancestor_parent)| {
(ancestor_parent == src_parent).then_some(ancestor)
}) else {
let val_err: ValidationError = InterGraphEdgeError::NoRelation {
from: src,
from_offset: PortOffset::new_outgoing(src_port),
to: dst,
to_offset: PortOffset::new_incoming(dst_port),
}.into();
return Err(val_err.into());
};

base.add_other_edge(copy_node, src_sibling)?;

src = copy_node;
} else {
// Need to insert a copy - first check can be copied
let wire_kind = base.hugr().get_optype(src).port_kind(src_offset);
let Some(EdgeKind::Value(simple_type)) = wire_kind else {panic!("Wires can only be Value kind")};
src_port = 0;
}
}

let typ = match simple_type {
SimpleType::Classic(typ) => typ,
SimpleType::Linear(typ) => return Err(BuildError::NoCopyLinear(typ)),
};
if let Some((connected, connected_offset)) = base.hugr().linked_port(src, src_offset) {
if let Some(copy_port) = if_copy_add_port(base, src) {
src_port = copy_port;
src = connected;
}
// Need to insert a copy - first check can be copied
else if let Some(typ) = check_classical_value(base, src, src_offset)? {
// TODO API consistency in using PortOffset vs. usize
base.disconnect(src, src_port, Direction::Outgoing)?;

Expand All @@ -462,10 +489,51 @@ fn wire_up<T: Dataflow + ?Sized>(
src_port = 1;
}
}
data_builder
.base()
.connect(src, src_port, op_node, dst_port)?;
Ok(())
data_builder.base().connect(src, src_port, dst, dst_port)?;
Ok(local_source)
}

/// Check the kind of a port is a classical Value and return it
/// Return None if Const kind
/// Panics if port not valid for Op or port is not Const/Value
fn check_classical_value(
base: &HugrMut,
src: NodeIndex,
src_offset: PortOffset,
) -> Result<Option<ClassicType>, BuildError> {
let wire_kind = base.hugr().get_optype(src).port_kind(src_offset).unwrap();
let typ = match wire_kind {
EdgeKind::Const(_) => None,
EdgeKind::Value(simple_type) => match simple_type {
SimpleType::Classic(typ) => Some(typ),
SimpleType::Linear(typ) => return Err(BuildError::NoCopyLinear(typ)),
},
_ => {
panic!("Wires can only be Value kind")
}
};

Ok(typ)
}

// Return newly added port to copy node if src node is a copy
fn if_copy_add_port(base: &mut HugrMut, src: NodeIndex) -> Option<usize> {
let src_op: Result<&LeafOp, ()> = base.hugr().get_optype(src).try_into();
if let Ok(LeafOp::Copy { n_copies, typ }) = src_op {
let copy_node = src;
// If already connected to a copy node, add wire to the copy
let n_copies = *n_copies;
base.replace_op(
copy_node,
LeafOp::Copy {
n_copies: n_copies + 1,
typ: typ.clone(),
},
);
base.add_ports(copy_node, Direction::Outgoing, 1).next()
} else {
None
}
}

impl<'f> DeltaBuilder<'f> {
Expand Down Expand Up @@ -1038,6 +1106,7 @@ mod test {

func_builder.finish_with_outputs(kappa_id.outputs())?
};

module_builder.finish()
};

Expand Down Expand Up @@ -1284,4 +1353,31 @@ mod test {

assert_eq!(builder(), Err(BuildError::NoCopyLinear(LinearType::Qubit)));
}

#[test]
fn simple_inter_graph_edge() {
let builder = || {
let mut module_builder = ModuleBuilder::new();

let mut f_build =
module_builder.declare_and_def("main", type_row![BIT], type_row![BIT])?;

let [i1] = f_build.input_wires_arr();
let noop = f_build.add_dataflow_op(LeafOp::Noop(BIT), [i1])?;
let i1 = noop.out_wire(0);

let mut nested = f_build.delta_builder(vec![], type_row![BIT])?;

let id = nested.add_dataflow_op(LeafOp::Noop(BIT), [i1])?;

let nested = nested.finish_with_outputs([id.out_wire(0)])?;

f_build.finish_with_outputs([nested.out_wire(0)])?;

crate::utils::test::viz_dotstr(&module_builder.hugr().dot_string());
module_builder.finish()
};

assert_matches!(builder(), Ok(_));
}
}

0 comments on commit 2527c16

Please sign in to comment.