Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,32 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Added

- **Circuit Conversion**: Added circuit-derived pre-scheduling support in `circuit2graph()`.
- Added `CircuitScheduleStrategy` with `PARALLEL` and `MINIMIZE_SPACE`.
- Added `schedule_strategy` argument to `circuit2graph()`.
- `circuit2graph()` now returns `(graph, gflow, scheduler)` and pre-populates `Scheduler` via manual scheduling.

### Changed

- **Graph State**: Made `meas_bases` read-only by returning `MappingProxyType` to avoid external mutation.
- **Graph State**: Added caching for `physical_nodes` snapshots and proper cache invalidation on node add/remove.
- **Docs/Examples**: Updated circuit conversion usage in README and `examples/pattern_from_circuit.py` for the new `circuit2graph()` return signature.

### Fixed

- **Feedforward**: Fixed self-loop removal in `dag_from_flow()` by correcting operator precedence so self-loops are removed from combined `xflow`/`zflow` dependencies.
- **Pauli Frame**: Initialize `_pauli_axis_cache` only when FTQC parity-check groups are provided, avoiding unnecessary cache creation in non-FTQC usage.

### Tests

- **Circuit Conversion**: Expanded scheduling tests in `tests/test_circuit.py`, including scheduler return contract, J/CZ/phase-gadget timing behavior, schedule validation, and `MINIMIZE_SPACE` behavior.
- **Integration**: Added circuit-level integration tests for `signal_shifting()` and `pauli_simplification()` with circuit-vs-pattern statevector equivalence checks.
- **Stim Compiler / Pauli Frame**: Updated tests to explicitly pass parity-check groups where logical-observable and cache initialization paths are exercised.

## [0.2.1] - 2026-01-16

### Added
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ circuit = Circuit(2)
circuit.apply_macro_gate(H(0))
circuit.apply_macro_gate(CNOT((0, 1)))

graph, feedforward = circuit2graph(circuit)
graph, feedforward, scheduler = circuit2graph(circuit)

# Compile into pattern
pattern = qompile(graph, feedforward)
Expand Down
4 changes: 2 additions & 2 deletions examples/pattern_from_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
circuit.cz(1, 2)

# %%
# convert circuit to graph and flow
graphstate, gflow = circuit2graph(circuit)
# convert circuit to graph, flow, and scheduler
graphstate, gflow, scheduler = circuit2graph(circuit)

# first, qompile it to standardized pattern
pattern = qompile(graphstate, gflow)
Expand Down
178 changes: 137 additions & 41 deletions graphqomb/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,137 @@
- `BaseCircuit`: An abstract base class for quantum circuits.
- `MBQCCircuit`: A circuit class composed solely of a unit gate set.
- `Circuit`: A class for circuits that include macro instructions.
- `circuit2graph`: A function that converts a circuit to a graph state and gflow.
- `CircuitScheduleStrategy`: Scheduling strategies for circuit conversion.
- `circuit2graph`: A function that converts a circuit to a graph state, gflow, and scheduler.
"""

from __future__ import annotations

import copy
import enum
import itertools
from abc import ABC, abstractmethod
from enum import Enum
from typing import TYPE_CHECKING

import typing_extensions

from graphqomb.common import Plane, PlannerMeasBasis
from graphqomb.gates import CZ, Gate, J, PhaseGadget, UnitGate
from graphqomb.graphstate import GraphState
from graphqomb.scheduler import Scheduler

if TYPE_CHECKING:
from collections.abc import Sequence


class CircuitScheduleStrategy(Enum):
"""Enumeration for manual scheduling strategies derived from circuit structure."""

PARALLEL = enum.auto()
MINIMIZE_SPACE = enum.auto()


class _Circuit2GraphContext:
"""Internal helper for converting circuits with a given scheduling strategy."""

graph: GraphState
gflow: dict[int, set[int]]
qindex2front_nodes: dict[int, int]
qindex2timestep: dict[int, int]
prepare_time: dict[int, int]
measure_time: dict[int, int]
minimize_qubits: bool
current_time: int

def __init__(self, graph: GraphState, strategy: CircuitScheduleStrategy) -> None:
if strategy == CircuitScheduleStrategy.PARALLEL:
self.minimize_qubits = False
elif strategy == CircuitScheduleStrategy.MINIMIZE_SPACE:
self.minimize_qubits = True
else:
msg = f"Invalid schedule strategy: {strategy}"
raise ValueError(msg)

self.graph = graph
self.gflow = {}
self.qindex2front_nodes = {}
self.qindex2timestep = {}
self.prepare_time = {}
self.measure_time = {}
self.current_time = 0

def apply_instruction(self, instruction: UnitGate) -> None:
"""Apply a unit gate to the graph conversion context.

Raises
------
TypeError
If the instruction type is not supported.
"""
if isinstance(instruction, J):
self._apply_j(instruction)
return
if isinstance(instruction, CZ):
self._apply_cz(instruction)
return
if isinstance(instruction, PhaseGadget):
self._apply_phase_gadget(instruction)
return
msg = f"Invalid instruction: {instruction}"
raise TypeError(msg)

def _apply_j(self, instruction: J) -> None:
new_node = self.graph.add_physical_node()
self.graph.add_physical_edge(self.qindex2front_nodes[instruction.qubit], new_node)
self.graph.assign_meas_basis(
self.qindex2front_nodes[instruction.qubit],
PlannerMeasBasis(Plane.XY, -instruction.angle),
)

timestep = self.qindex2timestep[instruction.qubit]
if self.minimize_qubits:
timestep = max(self.current_time, timestep)
self.prepare_time[new_node] = timestep
self.measure_time[self.qindex2front_nodes[instruction.qubit]] = timestep + 1
self.qindex2timestep[instruction.qubit] = timestep + 1
if self.minimize_qubits:
self.current_time = timestep + 1

self.gflow[self.qindex2front_nodes[instruction.qubit]] = {new_node}
self.qindex2front_nodes[instruction.qubit] = new_node

def _apply_cz(self, instruction: CZ) -> None:
self.graph.add_physical_edge(
self.qindex2front_nodes[instruction.qubits[0]],
self.qindex2front_nodes[instruction.qubits[1]],
)

aligned_time = max(self.qindex2timestep[instruction.qubits[0]], self.qindex2timestep[instruction.qubits[1]])
if self.minimize_qubits:
aligned_time = max(self.current_time, aligned_time)
self.current_time = aligned_time
self.qindex2timestep[instruction.qubits[0]] = aligned_time
self.qindex2timestep[instruction.qubits[1]] = aligned_time

def _apply_phase_gadget(self, instruction: PhaseGadget) -> None:
new_node = self.graph.add_physical_node()
self.graph.assign_meas_basis(new_node, PlannerMeasBasis(Plane.YZ, instruction.angle))
for qubit in instruction.qubits:
self.graph.add_physical_edge(self.qindex2front_nodes[qubit], new_node)

self.gflow[new_node] = {new_node}

max_timestep = max(self.qindex2timestep[qubit] for qubit in instruction.qubits)
if self.minimize_qubits:
max_timestep = max(self.current_time, max_timestep)
self.current_time = max_timestep + 1
self.prepare_time[new_node] = max_timestep
self.measure_time[new_node] = max_timestep + 1
for qubit in instruction.qubits:
self.qindex2timestep[qubit] = max_timestep + 1


class BaseCircuit(ABC):
"""
Abstract base class for quantum circuits.
Expand Down Expand Up @@ -208,64 +319,49 @@ def apply_macro_gate(self, gate: Gate) -> None:
self.__macro_gate_instructions.append(gate)


def circuit2graph(circuit: BaseCircuit) -> tuple[GraphState, dict[int, set[int]]]:
r"""Convert a circuit to a graph state and gflow.
def circuit2graph(
circuit: BaseCircuit,
schedule_strategy: CircuitScheduleStrategy = CircuitScheduleStrategy.PARALLEL,
) -> tuple[GraphState, dict[int, set[int]], Scheduler]:
r"""Convert a circuit to a graph state, gflow, and scheduler.

Parameters
----------
circuit : `BaseCircuit`
The quantum circuit to convert.
schedule_strategy : `CircuitScheduleStrategy`, optional
Strategy for scheduling preparation and measurement times derived from the circuit,
by default `CircuitScheduleStrategy.PARALLEL`.
The strategies are:

- `CircuitScheduleStrategy.PARALLEL`: schedule each qubit independently to reduce depth
- `CircuitScheduleStrategy.MINIMIZE_SPACE`: serialize operations to reduce prepared qubits

Returns
-------
`tuple`\[`GraphState`, `dict`\[`int`, `set`\[`int`\]\]\]
The graph state and gflow converted from the circuit.
`tuple`\[`GraphState`, `dict`\[`int`, `set`\[`int`\]\], `Scheduler`\]
The graph state, gflow, and scheduler converted from the circuit.
The scheduler is configured with automatic time scheduling derived from circuit structure.

Raises
------
TypeError
If the circuit contains an invalid instruction.
"""
graph = GraphState()
gflow: dict[int, set[int]] = {}

qindex2front_nodes: dict[int, int] = {}
context = _Circuit2GraphContext(graph, schedule_strategy)

# input nodes
for i in range(circuit.num_qubits):
node = graph.add_physical_node()
graph.register_input(node, i)
qindex2front_nodes[i] = node
context.qindex2front_nodes[i] = node
context.qindex2timestep[i] = 0

for instruction in circuit.unit_instructions():
if isinstance(instruction, J):
new_node = graph.add_physical_node()
graph.add_physical_edge(qindex2front_nodes[instruction.qubit], new_node)
graph.assign_meas_basis(
qindex2front_nodes[instruction.qubit],
PlannerMeasBasis(Plane.XY, -instruction.angle),
)

gflow[qindex2front_nodes[instruction.qubit]] = {new_node}
qindex2front_nodes[instruction.qubit] = new_node

elif isinstance(instruction, CZ):
graph.add_physical_edge(
qindex2front_nodes[instruction.qubits[0]],
qindex2front_nodes[instruction.qubits[1]],
)
elif isinstance(instruction, PhaseGadget):
new_node = graph.add_physical_node()
graph.assign_meas_basis(new_node, PlannerMeasBasis(Plane.YZ, instruction.angle))
for qubit in instruction.qubits:
graph.add_physical_edge(qindex2front_nodes[qubit], new_node)

gflow[new_node] = {new_node}
else:
msg = f"Invalid instruction: {instruction}"
raise TypeError(msg)
context.apply_instruction(instruction)

for qindex, node in qindex2front_nodes.items():
for qindex, node in context.qindex2front_nodes.items():
graph.register_output(node, qindex)

return graph, gflow
# manually schedule
scheduler = Scheduler(graph, context.gflow)
scheduler.manual_schedule(context.prepare_time, context.measure_time)

return graph, context.gflow, scheduler
2 changes: 1 addition & 1 deletion graphqomb/feedforward.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def dag_from_flow(
msg = "Invalid zflow object"
raise TypeError(msg)
for node in non_output_nodes:
target_nodes = xflow.get(node, set()) | zflow.get(node, set()) - {node} # remove self-loops
target_nodes = (xflow.get(node, set()) | zflow.get(node, set())) - {node} # remove self-loops
dag[node] = target_nodes
for output in output_nodes:
dag[output] = set()
Expand Down
22 changes: 15 additions & 7 deletions graphqomb/graphstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from abc import ABC
from collections.abc import Hashable, Iterable, Mapping, Sequence
from collections.abc import Set as AbstractSet
from types import MappingProxyType
from typing import TYPE_CHECKING, NamedTuple, TypeVar

import typing_extensions
Expand Down Expand Up @@ -83,12 +84,12 @@ def physical_edges(self) -> set[tuple[int, int]]:

@property
@abc.abstractmethod
def meas_bases(self) -> dict[int, MeasBasis]:
def meas_bases(self) -> MappingProxyType[int, MeasBasis]:
r"""Return measurement bases.

Returns
-------
`dict`\[`int`, `MeasBasis`\]
`types.MappingProxyType`\[`int`, `MeasBasis`\]
measurement bases of each physical node.
"""

Expand Down Expand Up @@ -199,6 +200,8 @@ class GraphState(BaseGraphState):

__node_counter: int

_cached_physical_nodes: frozenset[int] | None = None

def __init__(self) -> None:
self.__input_node_indices = {}
self.__output_node_indices = {}
Expand Down Expand Up @@ -244,7 +247,9 @@ def physical_nodes(self) -> set[int]:
`set`\[`int`\]
set of physical nodes.
"""
return self.__physical_nodes.copy()
if self._cached_physical_nodes is None:
self._cached_physical_nodes = frozenset(self.__physical_nodes)
return set(self._cached_physical_nodes)

@property
@typing_extensions.override
Expand All @@ -265,15 +270,15 @@ def physical_edges(self) -> set[tuple[int, int]]:

@property
@typing_extensions.override
def meas_bases(self) -> dict[int, MeasBasis]:
def meas_bases(self) -> MappingProxyType[int, MeasBasis]:
r"""Return measurement bases.

Returns
-------
`dict`\[`int`, `MeasBasis`\]
`types.MappingProxyType`\[`int`, `MeasBasis`\]
measurement bases of each physical node.
"""
return self.__meas_bases.copy()
return MappingProxyType(self.__meas_bases)

@property
def local_cliffords(self) -> dict[int, LocalClifford]:
Expand Down Expand Up @@ -356,6 +361,7 @@ def add_physical_node(self, coordinate: tuple[float, ...] | None = None) -> int:
if coordinate is not None:
self.__coordinates[node] = coordinate
self.__node_counter += 1
self._cached_physical_nodes = None

return node

Expand Down Expand Up @@ -416,6 +422,8 @@ def remove_physical_node(self, node: int) -> None:
self.__local_cliffords.pop(node, None)
self.__coordinates.pop(node, None)

self._cached_physical_nodes = None

def remove_physical_edge(self, node1: int, node2: int) -> None:
"""Remove a physical edge from the graph state.

Expand Down Expand Up @@ -561,7 +569,7 @@ def check_canonical_form(self) -> None:
if self.__local_cliffords:
msg = "Clifford operators are applied."
raise ValueError(msg)
for node in self.physical_nodes - set(self.output_node_indices):
for node in self.physical_nodes - self.output_node_indices.keys():
if self.meas_bases.get(node) is None:
msg = "All non-output nodes must have measurement basis."
raise ValueError(msg)
Expand Down
8 changes: 5 additions & 3 deletions graphqomb/pauli_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,11 @@ def __init__(
# Pre-compute Pauli axes for performance optimization
# Only cache nodes that have measurement bases
# NOTE: if non-Pauli measurements are involved, the stim_compile func will error out earlier
self._pauli_axis_cache = {
node: determine_pauli_axis(meas_basis) for node, meas_basis in graphstate.meas_bases.items()
}
self._pauli_axis_cache = (
{node: determine_pauli_axis(meas_basis) for node, meas_basis in graphstate.meas_bases.items()}
if parity_check_group
else {}
Comment on lines +86 to +89

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Populate Pauli-axis cache even without parity groups

Because _pauli_axis_cache is now only built when parity_check_group is non-empty, any call to PauliFrame.logical_observables_group() (used by stim_compile when logical_observables is provided) will hit self._pauli_axis_cache[node] and raise a KeyError if parity_check_group was omitted. This makes stim_compile(..., logical_observables=...) crash in the common case where callers don’t provide parity check groups, even though the API doesn’t require them.

Useful? React with 👍 / 👎.

) # only necessary for FTQC
# Cache for memoization of dependent chains
self._chain_cache = {}

Expand Down
Loading