Skip to content

Commit f0385a0

Browse files
authored
feat: add wrapper for tagging Some, Left, Right, Break, Continue (#1814)
Closes #1808
1 parent ab94518 commit f0385a0

File tree

3 files changed

+74
-8
lines changed

3 files changed

+74
-8
lines changed

hugr-py/src/hugr/ops.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,61 @@ def __repr__(self) -> str:
556556
return f"Tag({self.tag})"
557557

558558

559+
@dataclass
560+
class Some(Tag):
561+
"""Tag operation for the `Some` variant of an Option type.
562+
563+
Example:
564+
# construct a Some variant holding a row of Bool and Unit types
565+
>>> Some(tys.Bool, tys.Unit)
566+
Some
567+
"""
568+
569+
def __init__(self, *some_tys: tys.Type) -> None:
570+
super().__init__(1, tys.Option(*some_tys))
571+
572+
def __repr__(self) -> str:
573+
return "Some"
574+
575+
576+
@dataclass
577+
class Right(Tag):
578+
"""Tag operation for the `Right` variant of an type."""
579+
580+
def __init__(self, either_type: tys.Either) -> None:
581+
super().__init__(1, either_type)
582+
583+
def __repr__(self) -> str:
584+
return "Right"
585+
586+
587+
@dataclass
588+
class Left(Tag):
589+
"""Tag operation for the `Left` variant of an type."""
590+
591+
def __init__(self, either_type: tys.Either) -> None:
592+
super().__init__(0, either_type)
593+
594+
def __repr__(self) -> str:
595+
return "Left"
596+
597+
598+
class Continue(Left):
599+
"""Tag operation for the `Continue` variant of a TailLoop
600+
controlling Either type.
601+
"""
602+
603+
def __repr__(self) -> str:
604+
return "Continue"
605+
606+
607+
class Break(Right):
608+
"""Tag operation for the `Break` variant of a TailLoop controlling Either type."""
609+
610+
def __repr__(self) -> str:
611+
return "Break"
612+
613+
559614
class DfParentOp(Op, Protocol):
560615
"""Abstract parent of dataflow graph operations. Can be queried for the
561616
dataflow signature of its child graph.

hugr-py/tests/test_cond_loop.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from .conftest import QUANTUM_EXT, H, Measure, validate
1010

11-
SUM_T = tys.Sum([[tys.Qubit], [tys.Qubit, INT_T]])
11+
EITHER_T = tys.Either([tys.Qubit], [tys.Qubit, INT_T])
1212

1313

1414
def build_cond(h: Conditional) -> None:
@@ -25,15 +25,15 @@ def build_cond(h: Conditional) -> None:
2525

2626

2727
def test_cond() -> None:
28-
h = Conditional(SUM_T, [tys.Bool])
28+
h = Conditional(EITHER_T, [tys.Bool])
2929
build_cond(h)
3030
validate(h.hugr)
3131

3232

3333
def test_nested_cond() -> None:
3434
h = Dfg(tys.Qubit)
3535
(q,) = h.inputs()
36-
tagged_q = h.add(ops.Tag(0, SUM_T)(q))
36+
tagged_q = h.add(ops.Left(EITHER_T)(q))
3737

3838
with h.add_conditional(tagged_q, h.load(val.TRUE)) as cond:
3939
build_cond(cond)
@@ -42,12 +42,12 @@ def test_nested_cond() -> None:
4242
validate(h.hugr)
4343

4444
# build then insert
45-
con = Conditional(SUM_T, [tys.Bool])
45+
con = Conditional(EITHER_T, [tys.Bool])
4646
build_cond(con)
4747

4848
h = Dfg(tys.Qubit)
4949
(q,) = h.inputs()
50-
tagged_q = h.add(ops.Tag(0, SUM_T)(q))
50+
tagged_q = h.add(ops.Left(EITHER_T)(q))
5151
cond_n = h.insert_conditional(con, tagged_q, h.load(val.TRUE))
5252
h.set_outputs(*cond_n[:2])
5353
validate(h.hugr)
@@ -70,7 +70,7 @@ def test_if_else() -> None:
7070

7171
def test_incomplete() -> None:
7272
def _build_incomplete():
73-
with Conditional(SUM_T, [tys.Bool]) as c, c.add_case(0) as case0:
73+
with Conditional(EITHER_T, [tys.Bool]) as c, c.add_case(0) as case0:
7474
q, b = case0.inputs()
7575
case0.set_outputs(q, b)
7676

@@ -118,13 +118,13 @@ def test_complex_tail_loop() -> None:
118118
# if b is true, return first variant (just qubit)
119119
with tl.add_if(b, q) as if_:
120120
(q,) = if_.inputs()
121-
tagged_q = if_.add(ops.Tag(0, SUM_T)(q))
121+
tagged_q = if_.add(ops.Continue(EITHER_T)(q))
122122
if_.set_outputs(tagged_q)
123123

124124
# else return second variant (qubit, int)
125125
with if_.add_else() as else_:
126126
(q,) = else_.inputs()
127-
tagged_q_i = else_.add(ops.Tag(1, SUM_T)(q, else_.load(IntVal(1))))
127+
tagged_q_i = else_.add(ops.Break(EITHER_T)(q, else_.load(IntVal(1))))
128128
else_.set_outputs(tagged_q_i)
129129

130130
# finish with Sum output from if-else, and bool from inputs

hugr-py/tests/test_hugr_build.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,3 +330,14 @@ def test_dfg_unpack() -> None:
330330
dfg.set_outputs(*cond.outputs())
331331

332332
validate(dfg.hugr)
333+
334+
335+
def test_option() -> None:
336+
dfg = Dfg(tys.Bool)
337+
b = dfg.inputs()[0]
338+
339+
dfg.add_op(ops.Some(tys.Bool), b)
340+
341+
dfg.set_outputs(b)
342+
343+
validate(dfg.hugr)

0 commit comments

Comments
 (0)