Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added

- **Graph State**: Bulk initialization methods for GraphState ([#120](https://github.com/TeamGraphix/graphqomb/issues/120))
- Added `from_graph()` class method for direct graph-based initialization
- Added `from_base_graph_state()` class method for initialization from base GraphState objects
- Improved initialization flexibility for diverse use cases

### Performance

- **Pauli Frame**: Optimized `_collect_dependent_chain` method with memoization and caching
Expand All @@ -22,6 +29,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added tests for dependent chain collection across X, Y, Z measurement axes
- Added tests for detector groups and logical observables
- Improved test coverage from 77.78% to 97% for pauli_frame.py
- **Graph State**: Added comprehensive test suite for bulk initialization methods
- Added tests for `from_graph()` initialization
- Added tests for `from_base_graph_state()` initialization
- Added tests for graph consistency and state equivalence

## [0.1.1] - 2025-10-23

Expand Down
21 changes: 7 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,20 +93,13 @@ print(simulator.state)
from graphqomb.graphstate import GraphState
from graphqomb.visualizer import visualize

# Create a graph state
graph = GraphState()
node1 = graph.add_physical_node()
node2 = graph.add_physical_node()
node3 = graph.add_physical_node()

# Register input/output nodes
q_index = 0
graph.register_input(node1, q_index)
graph.register_output(node3, q_index)

# Add edges
graph.add_physical_edge(node1, node2)
graph.add_physical_edge(node2, node3)
# Create a graph state using from_graph
graph, node_map = GraphState.from_graph(
nodes=["input", "middle", "output"],
edges=[("input", "middle"), ("middle", "output")],
inputs=["input"],
outputs=["output"]
)

# Visualize the graph
visualize(graph)
Expand Down
123 changes: 56 additions & 67 deletions examples/draw_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import matplotlib.pyplot as plt
import numpy as np

from graphqomb.common import Axis, AxisMeasBasis, Plane, PlannerMeasBasis, Sign
from graphqomb.common import Axis, AxisMeasBasis, MeasBasis, Plane, PlannerMeasBasis, Sign
from graphqomb.graphstate import GraphState
from graphqomb.random_objects import generate_random_flow_graph
from graphqomb.visualizer import visualize
Expand All @@ -26,44 +26,32 @@
# %%
# Create a demo graph with different measurement planes and input/output nodes

demo_graph = GraphState()

# Add input nodes
input_node1 = demo_graph.add_physical_node()
input_node2 = demo_graph.add_physical_node()
demo_graph.register_input(input_node1, 0)
demo_graph.register_input(input_node2, 1)

# Set measurement bases for input nodes (XY plane with different angles)
demo_graph.assign_meas_basis(input_node1, AxisMeasBasis(Axis.X, Sign.PLUS))
demo_graph.assign_meas_basis(input_node2, PlannerMeasBasis(Plane.XY, np.pi / 6))

# Add internal nodes with different measurement planes
internal_node1 = demo_graph.add_physical_node()
internal_node2 = demo_graph.add_physical_node()
internal_node3 = demo_graph.add_physical_node()

# Set measurement bases for internal nodes
# XZ plane (blue) with angle π/4
demo_graph.assign_meas_basis(internal_node1, PlannerMeasBasis(Plane.XZ, np.pi / 4))
# YZ plane (red) with angle π/3
demo_graph.assign_meas_basis(internal_node2, PlannerMeasBasis(Plane.YZ, np.pi / 3))
# XZ plane (blue) with angle π/2
demo_graph.assign_meas_basis(internal_node3, PlannerMeasBasis(Plane.XZ, np.pi / 2))

# Add output nodes
output_node1 = demo_graph.add_physical_node()
output_node2 = demo_graph.add_physical_node()
demo_graph.register_output(output_node1, 0)
demo_graph.register_output(output_node2, 1)

# Create edges to connect the graph
demo_graph.add_physical_edge(input_node1, internal_node1)
demo_graph.add_physical_edge(input_node2, internal_node2)
demo_graph.add_physical_edge(internal_node1, internal_node3)
demo_graph.add_physical_edge(internal_node2, internal_node3)
demo_graph.add_physical_edge(internal_node3, output_node1)
demo_graph.add_physical_edge(internal_node1, output_node2)
# Define graph structure with named nodes
nodes = ["input1", "input2", "internal1", "internal2", "internal3", "output1", "output2"]
edges = [
("input1", "internal1"),
("input2", "internal2"),
("internal1", "internal3"),
("internal2", "internal3"),
("internal3", "output1"),
("input1", "output2"),
]
inputs = ["input1", "input2"]
outputs = ["output1", "output2"]

# Define measurement bases for nodes
meas_bases: dict[str, MeasBasis] = {
"input1": AxisMeasBasis(Axis.X, Sign.PLUS),
"input2": PlannerMeasBasis(Plane.XY, np.pi / 6),
"internal1": PlannerMeasBasis(Plane.XZ, np.pi / 4), # XZ plane with angle π/4
"internal2": PlannerMeasBasis(Plane.YZ, np.pi / 3), # YZ plane with angle π/3
"internal3": PlannerMeasBasis(Plane.XZ, np.pi / 2), # XZ plane with angle π/2
}

# Create graph state from structure
demo_graph, node_map = GraphState.from_graph(
nodes=nodes, edges=edges, inputs=inputs, outputs=outputs, meas_bases=meas_bases
)

print("Demo graph with XZ and YZ measurement planes:")
print(f"Input nodes: {list(demo_graph.input_node_indices.keys())}")
Expand All @@ -86,34 +74,35 @@

# %%
# Create another demo graph with Pauli measurements (θ=0, π)
pauli_demo_graph = GraphState()

# Add nodes for Pauli measurements
pauli_input = pauli_demo_graph.add_physical_node()
pauli_demo_graph.register_input(pauli_input, 0)

# Create internal nodes with Pauli measurements
x_measurement_node = pauli_demo_graph.add_physical_node() # X measurement: XY plane, θ=0
y_measurement_node = pauli_demo_graph.add_physical_node() # Y measurement: YZ plane, θ=π/2
z_measurement_node = pauli_demo_graph.add_physical_node() # Z measurement: XZ plane, θ=π

# Set Pauli measurement bases
pauli_demo_graph.assign_meas_basis(pauli_input, AxisMeasBasis(Axis.X, Sign.PLUS)) # X+
pauli_demo_graph.assign_meas_basis(x_measurement_node, AxisMeasBasis(Axis.X, Sign.PLUS)) # X+
pauli_demo_graph.assign_meas_basis(y_measurement_node, AxisMeasBasis(Axis.Y, Sign.PLUS)) # Y+
pauli_demo_graph.assign_meas_basis(z_measurement_node, AxisMeasBasis(Axis.Z, Sign.MINUS)) # Z-

# Add output node
pauli_output = pauli_demo_graph.add_physical_node()
pauli_demo_graph.register_output(pauli_output, 0)

# Connect nodes
pauli_demo_graph.add_physical_edge(pauli_input, x_measurement_node)
pauli_demo_graph.add_physical_edge(x_measurement_node, y_measurement_node)
pauli_demo_graph.add_physical_edge(y_measurement_node, z_measurement_node)
pauli_demo_graph.add_physical_edge(z_measurement_node, pauli_output)

print("\\nPauli measurement demo graph:")
# Define Pauli measurement graph structure
pauli_nodes = ["input", "x_meas", "y_meas", "z_meas", "output"]
pauli_edges = [
("input", "x_meas"),
("x_meas", "y_meas"),
("y_meas", "z_meas"),
("z_meas", "output"),
]
pauli_inputs = ["input"]
pauli_outputs = ["output"]

# Define Pauli measurement bases
pauli_meas_bases = {
"input": AxisMeasBasis(Axis.X, Sign.PLUS), # X+
"x_meas": AxisMeasBasis(Axis.X, Sign.PLUS), # X+: XY plane, θ=0
"y_meas": AxisMeasBasis(Axis.Y, Sign.PLUS), # Y+: YZ plane, θ=π/2
"z_meas": AxisMeasBasis(Axis.Z, Sign.MINUS), # Z-: XZ plane, θ=π
}

# Create Pauli measurement graph state from structure
pauli_demo_graph, pauli_node_map = GraphState.from_graph(
nodes=pauli_nodes,
edges=pauli_edges,
inputs=pauli_inputs,
outputs=pauli_outputs,
meas_bases=pauli_meas_bases,
)

print("\nPauli measurement demo graph:")
print(f"Input nodes: {list(pauli_demo_graph.input_node_indices.keys())}")
print(f"Output nodes: {list(pauli_demo_graph.output_node_indices.keys())}")
print("Pauli measurement nodes (will show bordered patterns):")
Expand Down
180 changes: 177 additions & 3 deletions graphqomb/graphstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,20 @@
import itertools
import operator
from abc import ABC
from typing import TYPE_CHECKING, NamedTuple
from collections.abc import Hashable, Iterable, Mapping, Sequence
from collections.abc import Set as AbstractSet
from typing import TYPE_CHECKING, NamedTuple, TypeVar

import typing_extensions

from graphqomb.common import MeasBasis, Plane, PlannerMeasBasis
from graphqomb.euler import update_lc_basis, update_lc_lc

if TYPE_CHECKING:
from collections.abc import Set as AbstractSet

from graphqomb.euler import LocalClifford

NodeT = TypeVar("NodeT", bound=Hashable)


class BaseGraphState(ABC):
"""Abstract base class for Graph State."""
Expand Down Expand Up @@ -618,6 +620,178 @@ def _expand_output_local_cliffords(self) -> dict[int, LocalCliffordExpansion]:

return node_index_addition_map

@classmethod
def from_graph( # noqa: C901, PLR0912
cls,
nodes: Iterable[NodeT],
edges: Iterable[tuple[NodeT, NodeT]],
inputs: Sequence[NodeT] | None = None,
outputs: Sequence[NodeT] | None = None,
meas_bases: Mapping[NodeT, MeasBasis] | None = None,
) -> tuple[GraphState, dict[NodeT, int]]:
r"""Create a graph state from nodes and edges with arbitrary node types.

This factory method allows creating a graph state from any hashable node type
(e.g., strings, tuples, custom objects). The method internally maps external
node identifiers to integer indices used by GraphState.

Parameters
----------
nodes : `collections.abc.Iterable`\[NodeT\]
Nodes to add to the graph. Can be any hashable type.
edges : `collections.abc.Iterable`\[`tuple`\[NodeT, NodeT\]\]
Edges as pairs of node identifiers.
inputs : `collections.abc.Sequence`\[NodeT\] | `None`, optional
Input nodes in order. Qubit indices are assigned sequentially (0, 1, 2, ...).
Default is None (no inputs).
outputs : `collections.abc.Sequence`\[NodeT\] | `None`, optional
Output nodes in order. Qubit indices are assigned sequentially (0, 1, 2, ...).
Default is None (no outputs).
meas_bases : `collections.abc.Mapping`\[NodeT, `MeasBasis`\] | `None`, optional
Measurement bases for nodes. Nodes not specified can be set later.
Default is None (no bases assigned initially).

Returns
-------
`tuple`\[`GraphState`, `dict`\[NodeT, `int`\]\]
- Created GraphState instance
- Mapping from external node IDs to internal integer indices

Raises
------
ValueError
If duplicate nodes, invalid edges, or invalid input/output nodes.
"""
# Convert nodes to list to preserve order
nodes_list = list(nodes)

# Check for duplicate nodes
if len(nodes_list) != len(set(nodes_list)):
msg = "Duplicate nodes in input"
raise ValueError(msg)

node_set = set(nodes_list)

# Validate inputs
if inputs is not None:
for input_node in inputs:
if input_node not in node_set:
msg = f"Input node {input_node} not in nodes collection"
raise ValueError(msg)

# Validate outputs
if outputs is not None:
for output_node in outputs:
if output_node not in node_set:
msg = f"Output node {output_node} not in nodes collection"
raise ValueError(msg)

# Convert edges to list for validation
edges_list = list(edges)

# Validate edges
for node1, node2 in edges_list:
if node1 not in node_set:
msg = f"Edge references non-existent node {node1}"
raise ValueError(msg)
if node2 not in node_set:
msg = f"Edge references non-existent node {node2}"
raise ValueError(msg)

# Create GraphState instance
graph_state = cls()

# Add nodes and create node mapping
node_map: dict[NodeT, int] = {}
for node in nodes_list:
new_node = graph_state.add_physical_node()
node_map[node] = new_node

# Add edges
for node1, node2 in edges_list:
idx1 = node_map[node1]
idx2 = node_map[node2]
graph_state.add_physical_edge(idx1, idx2)

# Register inputs with sequential qubit indices
if inputs is not None:
for q_index, input_node in enumerate(inputs):
graph_state.register_input(node_map[input_node], q_index)

# Register outputs with sequential qubit indices
if outputs is not None:
for q_index, output_node in enumerate(outputs):
graph_state.register_output(node_map[output_node], q_index)

# Assign measurement bases
if meas_bases is not None:
for node, meas_basis in meas_bases.items():
if node in node_set:
graph_state.assign_meas_basis(node_map[node], meas_basis)

return graph_state, node_map

@classmethod
def from_base_graph_state(
cls,
base: BaseGraphState,
copy_local_cliffords: bool = True,
) -> tuple[GraphState, dict[int, int]]:
r"""Create a new GraphState from an existing BaseGraphState instance.

This method creates a complete copy of the graph structure, including nodes,
edges, input/output registrations, and measurement bases. Useful for creating
mutable copies or converting between GraphState implementations.

Parameters
----------
base : `BaseGraphState`
The source graph state to copy from.
copy_local_cliffords : `bool`, optional
Whether to copy local Clifford operators if the source is a GraphState.
If True and the source has local Cliffords, they are copied.
If False, local Cliffords are not copied (canonical form only).
Default is True.

Returns
-------
`tuple`\[`GraphState`, `dict`\[`int`, `int`\]\]
- Created GraphState instance
- Mapping from source node indices to new node indices
"""
# Create new GraphState instance
graph_state = cls()

# Create node mapping
node_map: dict[int, int] = {}
for node in base.physical_nodes:
new_node = graph_state.add_physical_node()
node_map[node] = new_node

# Add edges using node mapping
for node1, node2 in base.physical_edges:
graph_state.add_physical_edge(node_map[node1], node_map[node2])

# Register inputs with same qubit indices
for input_node, q_index in base.input_node_indices.items():
graph_state.register_input(node_map[input_node], q_index)

# Register outputs with same qubit indices
for output_node, q_index in base.output_node_indices.items():
graph_state.register_output(node_map[output_node], q_index)

# Copy measurement bases
for node, meas_basis in base.meas_bases.items():
graph_state.assign_meas_basis(node_map[node], meas_basis)

# Copy local Clifford operators if requested and source is GraphState
if copy_local_cliffords and isinstance(base, GraphState):
for node, lc in base.local_cliffords.items():
# Access private attribute to copy local cliffords
graph_state.apply_local_clifford(node_map[node], lc)

return graph_state, node_map


class LocalCliffordExpansion(NamedTuple):
"""Local Clifford expansion map for each input/output node."""
Expand Down
Loading