Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion hugr-core/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub mod module;
pub mod sum;
pub mod tag;
pub mod validate;
use crate::core::HugrNode;
use crate::extension::resolution::{
collect_op_extension, collect_op_types_extensions, ExtensionCollectionError,
};
Expand All @@ -20,6 +21,7 @@ use crate::types::{EdgeKind, Signature, Substitution};
use crate::{Direction, OutgoingPort, Port};
use crate::{IncomingPort, PortIndex};
use derive_more::Display;
use handle::NodeHandle;
use paste::paste;
use portgraph::NodeIndex;

Expand All @@ -41,7 +43,6 @@ pub use tag::OpTag;
#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
/// The concrete operation types for a node in the HUGR.
// TODO: Link the NodeHandles to the OpType.
#[non_exhaustive]
#[allow(missing_docs)]
#[serde(tag = "op")]
Expand Down Expand Up @@ -377,6 +378,19 @@ pub trait OpTrait: Sized + Clone {
/// Tag identifying the operation.
fn tag(&self) -> OpTag;

/// Tries to create a specific [`NodeHandle`] for a node with this operation
/// type.
///
/// Fails if the operation's [`OpTrait::tag`] does not match the
/// [`NodeHandle::TAG`] of the requested handle.
fn try_node_handle<N, H>(&self, node: N) -> Option<H>
where
N: HugrNode,
H: NodeHandle<N> + From<N>,
{
H::TAG.is_superset(self.tag()).then(|| node.into())
}

/// The signature of the operation.
///
/// Only dataflow operations have a signature, otherwise returns None.
Expand Down
73 changes: 39 additions & 34 deletions hugr-core/src/ops/handle.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
//! Handles to nodes in HUGR.
use crate::core::HugrNode;
use crate::types::{Type, TypeBound};
use crate::Node;

Expand All @@ -9,12 +10,12 @@ use super::{AliasDecl, OpTag};

/// Common trait for handles to a node.
/// Typically wrappers around [`Node`].
pub trait NodeHandle: Clone {
pub trait NodeHandle<N = Node>: Clone {
/// The most specific operation tag associated with the handle.
const TAG: OpTag;

/// Index of underlying node.
fn node(&self) -> Node;
fn node(&self) -> N;

/// Operation tag for the handle.
#[inline]
Expand All @@ -23,7 +24,7 @@ pub trait NodeHandle: Clone {
}

/// Cast the handle to a different more general tag.
fn try_cast<T: NodeHandle + From<Node>>(&self) -> Option<T> {
fn try_cast<T: NodeHandle<N> + From<N>>(&self) -> Option<T> {
T::TAG.is_superset(Self::TAG).then(|| self.node().into())
}

Expand All @@ -36,54 +37,54 @@ pub trait NodeHandle: Clone {
/// Trait for handles that contain children.
///
/// The allowed children handles are defined by the associated type.
pub trait ContainerHandle: NodeHandle {
pub trait ContainerHandle<N = Node>: NodeHandle<N> {
/// Handle type for the children of this node.
type ChildrenHandle: NodeHandle;
type ChildrenHandle: NodeHandle<N>;
}

#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)]
/// Handle to a [DataflowOp](crate::ops::dataflow).
pub struct DataflowOpID(Node);
pub struct DataflowOpID<N = Node>(N);

#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)]
/// Handle to a [DFG](crate::ops::DFG) node.
pub struct DfgID(Node);
pub struct DfgID<N = Node>(N);

#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)]
/// Handle to a [CFG](crate::ops::CFG) node.
pub struct CfgID(Node);
pub struct CfgID<N = Node>(N);

#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)]
/// Handle to a module [Module](crate::ops::Module) node.
pub struct ModuleRootID(Node);
pub struct ModuleRootID<N = Node>(N);

#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)]
/// Handle to a [module op](crate::ops::module) node.
pub struct ModuleID(Node);
pub struct ModuleID<N = Node>(N);

#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)]
/// Handle to a [def](crate::ops::OpType::FuncDefn)
/// or [declare](crate::ops::OpType::FuncDecl) node.
///
/// The `DEF` const generic is used to indicate whether the function is
/// defined or just declared.
pub struct FuncID<const DEF: bool>(Node);
pub struct FuncID<const DEF: bool, N = Node>(N);

#[derive(Debug, Clone, PartialEq, Eq)]
/// Handle to an [AliasDefn](crate::ops::OpType::AliasDefn)
/// or [AliasDecl](crate::ops::OpType::AliasDecl) node.
///
/// The `DEF` const generic is used to indicate whether the function is
/// defined or just declared.
pub struct AliasID<const DEF: bool> {
node: Node,
pub struct AliasID<const DEF: bool, N = Node> {
node: N,
name: SmolStr,
bound: TypeBound,
}

impl<const DEF: bool> AliasID<DEF> {
impl<const DEF: bool, N> AliasID<DEF, N> {
/// Construct new AliasID
pub fn new(node: Node, name: SmolStr, bound: TypeBound) -> Self {
pub fn new(node: N, name: SmolStr, bound: TypeBound) -> Self {
Self { node, name, bound }
}

Expand All @@ -99,27 +100,27 @@ impl<const DEF: bool> AliasID<DEF> {

#[derive(DerFrom, Debug, Clone, PartialEq, Eq)]
/// Handle to a [Const](crate::ops::OpType::Const) node.
pub struct ConstID(Node);
pub struct ConstID<N = Node>(N);

#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)]
/// Handle to a [DataflowBlock](crate::ops::DataflowBlock) or [Exit](crate::ops::ExitBlock) node.
pub struct BasicBlockID(Node);
pub struct BasicBlockID<N = Node>(N);

#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)]
/// Handle to a [Case](crate::ops::Case) node.
pub struct CaseID(Node);
pub struct CaseID<N = Node>(N);

#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)]
/// Handle to a [TailLoop](crate::ops::TailLoop) node.
pub struct TailLoopID(Node);
pub struct TailLoopID<N = Node>(N);

#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)]
/// Handle to a [Conditional](crate::ops::Conditional) node.
pub struct ConditionalID(Node);
pub struct ConditionalID<N = Node>(N);

#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)]
/// Handle to a dataflow container node.
pub struct DataflowParentID(Node);
pub struct DataflowParentID<N = Node>(N);

/// Implements the `NodeHandle` trait for a tuple struct that contains just a
/// NodeIndex. Takes the name of the struct, and the corresponding OpTag.
Expand All @@ -131,11 +132,11 @@ macro_rules! impl_nodehandle {
impl_nodehandle!($name, $tag, 0);
};
($name:ident, $tag:expr, $node_attr:tt) => {
impl NodeHandle for $name {
impl<N: HugrNode> NodeHandle<N> for $name<N> {
const TAG: OpTag = $tag;

#[inline]
fn node(&self) -> Node {
fn node(&self) -> N {
self.$node_attr
}
}
Expand All @@ -156,35 +157,35 @@ impl_nodehandle!(ConstID, OpTag::Const);

impl_nodehandle!(BasicBlockID, OpTag::DataflowBlock);

impl<const DEF: bool> NodeHandle for FuncID<DEF> {
impl<const DEF: bool, N: HugrNode> NodeHandle<N> for FuncID<DEF, N> {
const TAG: OpTag = OpTag::Function;
#[inline]
fn node(&self) -> Node {
fn node(&self) -> N {
self.0
}
}

impl<const DEF: bool> NodeHandle for AliasID<DEF> {
impl<const DEF: bool, N: HugrNode> NodeHandle<N> for AliasID<DEF, N> {
const TAG: OpTag = OpTag::Alias;
#[inline]
fn node(&self) -> Node {
fn node(&self) -> N {
self.node
}
}

impl NodeHandle for Node {
impl<N: HugrNode> NodeHandle<N> for N {
const TAG: OpTag = OpTag::Any;
#[inline]
fn node(&self) -> Node {
fn node(&self) -> N {
*self
}
}

/// Implements the `ContainerHandle` trait, with the given child handle type.
macro_rules! impl_containerHandle {
($name:path, $children:ident) => {
impl ContainerHandle for $name {
type ChildrenHandle = $children;
($name:ident, $children:ident) => {
impl<N: HugrNode> ContainerHandle<N> for $name<N> {
type ChildrenHandle = $children<N>;
}
};
}
Expand All @@ -197,5 +198,9 @@ impl_containerHandle!(CaseID, DataflowOpID);
impl_containerHandle!(ModuleRootID, ModuleID);
impl_containerHandle!(CfgID, BasicBlockID);
impl_containerHandle!(BasicBlockID, DataflowOpID);
impl_containerHandle!(FuncID<true>, DataflowOpID);
impl_containerHandle!(AliasID<true>, DataflowOpID);
impl<N: HugrNode> ContainerHandle<N> for FuncID<true, N> {
type ChildrenHandle = DataflowOpID<N>;
}
impl<N: HugrNode> ContainerHandle<N> for AliasID<true, N> {
type ChildrenHandle = DataflowOpID<N>;
}
Loading