Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
e2f7966
:white_check_mark: Add tests for remove_physical_node
Oct 20, 2024
4ea1bf9
:sparkles: Add remove_physical_node to GraphState
Oct 20, 2024
684043c
:art: Fix docs
Oct 21, 2024
8131ec7
:art: Fix measurement_action
Oct 21, 2024
bd14bfc
:art: Fix measurement_action in local_complement
Oct 21, 2024
d6067b0
:white_check_mark: Add tests for remove_clifford with restricted cases
Oct 21, 2024
c96d58c
:construction: Add remove_clifford only valid for restricted cases
Oct 21, 2024
d3ed3d0
:white_check_mark: Add tests for Clifford removal with 4 new cases
Oct 23, 2024
9214c5e
:white_check_mark: Add tests for exception in Clifford removal
Oct 23, 2024
61e6d46
:white_check_mark: Add tests for exception in Clifford removal
Oct 23, 2024
e09d175
:sparkles: Update Clifford removal
Oct 23, 2024
5829cf4
:art: Modified after mypy & pyright
Oct 23, 2024
445ba19
:recycle: Refactor remove_clifford
Oct 23, 2024
cf565b9
:art: Improve test_remove_clifford_fails_for_special_clifford_vertex
Oct 24, 2024
3360f36
:construction: Add necessary test case
Oct 24, 2024
a74fc3c
:sparkles: Add new case for Clifford removal
Oct 24, 2024
8d61a02
:memo: Improve docs
Oct 24, 2024
cbbf673
:green_heart: Enable CI on all PRs
Oct 24, 2024
80c7f77
:recycle: Refactor test_remove_clifford
Oct 24, 2024
e7f932f
:recycle: Refactor tests for clifford removal
Oct 27, 2024
a831752
:art: Improve readability
Oct 27, 2024
51855af
:recycle: Refactor initial graph preparation for tests
Oct 27, 2024
277f2a6
:art: Modify to pass mypy & pyright
Oct 27, 2024
6570847
:bug: Fix set_output
Oct 28, 2024
c36c81c
:bug: Fix set_meas_angle and set_meas_plane
Oct 28, 2024
e77e2be
:art: Apply is_clifford_angle
Oct 28, 2024
21609bd
:recycle: Refactor test_graphstate.py
Oct 28, 2024
b23009a
:bug: Fix local_complement and pivot on output nodes
Oct 28, 2024
7655eee
:bug: Fix bug caused by incorrect output node handling
Oct 28, 2024
c878454
:art: Fix to pass mypy & pyright
Oct 28, 2024
eb781cd
:bug: Fix test_graphstate.py
Nov 4, 2024
cc007f2
:sparkles: Add remove_cliffords
Nov 4, 2024
cc27587
:bug: Fix remove_cliffords
Nov 4, 2024
b357223
:art: Improve readability of remove_cliffords
Nov 4, 2024
04cd779
:memo: Add docstrings to _step*_action
Nov 8, 2024
6133cad
:zap: Improve performance
Nov 8, 2024
88422c8
:bug: Fix bug in pivot_1
Nov 8, 2024
988b95d
:see_no_evil: Ignore .DS_Store
Nov 8, 2024
c93c45a
:truck: Move ZXGraphState into zxgraphstate.py
Nov 8, 2024
8a99e44
:truck: Move ZXGraphState into zxgraphstate.py
Nov 8, 2024
09d2037
:art: Rename _is_removable_clifford into _needs_nop
Nov 12, 2024
16304c4
:bug: Fix bug in removing cliffords from a random graph
Nov 12, 2024
f325f77
:zap: Fix condition check in remove_cliffords
Nov 12, 2024
45eda3f
:bug: Fix measurement actions
Dec 2, 2024
b00690a
:see_no_evil: Fix .gitignore
Dec 11, 2024
d8ec235
:truck: Move zxgraphstate tests from graphstate tests
Dec 13, 2024
17ddd8a
Merge branch 'mf' into remove-clifford-vertices
Dec 13, 2024
7becd35
:recycle: Refactor to reflect meas_bases
Dec 13, 2024
9f0f18f
Merge branch 'mf' into remove-clifford-vertices
Dec 14, 2024
a41d7d8
:art: Apply ruff
Dec 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 37 additions & 175 deletions graphix_zx/graphstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
This module provides:
- BaseGraphState: Abstract base class for Graph State.
- GraphState: Minimal implementation of Graph State.
- ZXGraphState: Graph State for the ZX-calculus.
- bipartite_edges: Return a set of edges for the complete bipartite graph between two sets of nodes.
"""

from __future__ import annotations
Expand All @@ -13,14 +11,10 @@
from itertools import product
from typing import TYPE_CHECKING

import numpy as np

from graphix_zx.common import MeasBasis, Plane, PlannerMeasBasis, default_meas_basis
from graphix_zx.common import MeasBasis, default_meas_basis
from graphix_zx.euler import update_lc_basis

if TYPE_CHECKING:
from typing import Callable

from graphix_zx.euler import LocalClifford


Expand Down Expand Up @@ -398,7 +392,7 @@ def local_cliffords(self) -> dict[int, LocalClifford]:
"""
return self.__local_cliffords

def check_meas_bases(self) -> None:
def check_meas_basis(self) -> None:
"""Check if the measurement basis is set for all physical nodes except output nodes.

Raises
Expand Down Expand Up @@ -483,6 +477,32 @@ def add_physical_edge(self, node1: int, node2: int) -> None:
self.__physical_edges[node1] |= {node2}
self.__physical_edges[node2] |= {node1}

def remove_physical_node(self, node: int) -> None:
"""Remove a physical node from the graph state.

Parameters
----------
node : int

Raises
------
ValueError
If the node does not exist.
"""
if node not in self.__physical_nodes:
msg = f"Node does not exist {node=}"
raise ValueError(msg)
self.ensure_node_exists(node)
self.__physical_nodes -= {node}
del self.__physical_edges[node]
self.__input_nodes -= {node}
self.__output_nodes -= {node}
self.__meas_bases.pop(node, None)
self.__q_indices.pop(node, None)
self.__local_cliffords.pop(node, None)
for neighbor in self.__physical_edges:
self.__physical_edges[neighbor] -= {node}

def remove_physical_edge(self, node1: int, node2: int) -> None:
"""Remove a physical edge from the graph state.

Expand Down Expand Up @@ -524,8 +544,17 @@ def set_output(self, node: int) -> None:
----------
node : int
node index

Raises
------
ValueError
1. If the node does not exist.
2. If the node has a measurement basis.
"""
self.ensure_node_exists(node)
if self.meas_bases.get(node) is not None:
msg = "Cannot set output node with measurement basis."
raise ValueError(msg)
self.__output_nodes |= {node}

def set_q_index(self, node: int, q_index: int = -1) -> None:
Expand Down Expand Up @@ -660,173 +689,6 @@ def append(self, other: BaseGraphState) -> None:
self.set_q_index(node, q_index)


class ZXGraphState(GraphState):
"""Graph State for the ZX-calculus.

Attributes
----------
input_nodes : set[int]
set of input nodes
output_nodes : set[int]
set of output nodes
physical_nodes : set[int]
set of physical nodes
physical_edges : dict[int, set[int]]
physical edges
meas_bases : dict[int, MeasBasis]
q_indices : dict[int, int]
qubit indices
local_cliffords : dict[int, LocalClifford]
local clifford operators
"""

def __init__(self) -> None:
super().__init__()

def _update_connections(self, rmv_edges: set[tuple[int, int]], new_edges: set[tuple[int, int]]) -> None:
for edge in rmv_edges:
self.remove_physical_edge(edge[0], edge[1])
for edge in new_edges:
self.add_physical_edge(edge[0], edge[1])

def _update_node_measurement(
self, measurement_action: dict[Plane, tuple[Plane, Callable[[float], float]]], v: int
) -> None:
new_plane, new_angle_func = measurement_action[self.meas_bases[v].plane]
if new_plane:
new_angle = new_angle_func(v) % (2.0 * np.pi)
self.set_meas_basis(v, PlannerMeasBasis(new_plane, new_angle))

def local_complement(self, node: int) -> None:
"""Local complement operation on the graph state: G*u.

Parameters
----------
node : int
node index

Raises
------
ValueError
If the node is an input node, an output node, or the graph is not a ZX-diagram.
"""
self.ensure_node_exists(node)
if node in self.input_nodes or node in self.output_nodes:
msg = "Cannot apply local complement to input node nor output node."
raise ValueError(msg)
self.check_meas_bases()

nbrs: set[int] = self.get_neighbors(node)
nbr_pairs = bipartite_edges(nbrs, nbrs)
new_edges = nbr_pairs - self.physical_edges
rmv_edges = self.physical_edges & nbr_pairs

self._update_connections(rmv_edges, new_edges)

# update node measurement
measurement_action = {
Plane.XY: (Plane.XZ, lambda v: 0.5 * np.pi - self.meas_bases[v].angle),
Plane.XZ: (Plane.XY, lambda v: self.meas_bases[v].angle - 0.5 * np.pi),
Plane.YZ: (Plane.YZ, lambda v: self.meas_bases[v].angle + 0.5 * np.pi),
}

self._update_node_measurement(measurement_action, node)

# update neighbors measurement
measurement_action = {
Plane.XY: (Plane.XY, lambda v: self.meas_bases[v].angle - 0.5 * np.pi),
Plane.XZ: (Plane.YZ, lambda v: self.meas_bases[v].angle),
Plane.YZ: (Plane.XZ, lambda v: -self.meas_bases[v].angle),
}

for v in nbrs:
self._update_node_measurement(measurement_action, v)

def _swap(self, node1: int, node2: int) -> None:
"""Swap nodes u and v in the graph state.

Parameters
----------
node1 : int
node index
node2 : int
node index
"""
node1_nbrs = self.get_neighbors(node1) - {node2}
node2_nbrs = self.get_neighbors(node2) - {node1}
nbr_b = node1_nbrs - node2_nbrs
nbr_c = node2_nbrs - node1_nbrs
for b in nbr_b:
self.remove_physical_edge(node1, b)
self.add_physical_edge(node2, b)
for c in nbr_c:
self.remove_physical_edge(node2, c)
self.add_physical_edge(node1, c)

def pivot(self, node1: int, node2: int) -> None:
"""Pivot operation on the graph state: G∧(uv) (= G*u*v*u = G*v*u*v) for neighboring nodes u and v.

In order to maintain the ZX-diagram simple, pi-spiders are shifted properly.

Parameters
----------
node1 : int
node index
node2 : int
node index

Raises
------
ValueError
If the nodes are input nodes, output nodes, or the graph is not a ZX-diagram.
"""
self.ensure_node_exists(node1)
self.ensure_node_exists(node2)
if node1 in self.input_nodes or node2 in self.input_nodes:
msg = "Cannot apply pivot to input node"
raise ValueError(msg)
if node1 in self.output_nodes or node2 in self.output_nodes:
msg = "Cannot apply pivot to output node"
raise ValueError(msg)
self.check_meas_bases()

node1_nbrs = self.get_neighbors(node1) - {node2}
node2_nbrs = self.get_neighbors(node2) - {node1}
nbr_a = node1_nbrs & node2_nbrs
nbr_b = node1_nbrs - node2_nbrs
nbr_c = node2_nbrs - node1_nbrs
nbr_pairs = [
bipartite_edges(nbr_a, nbr_b),
bipartite_edges(nbr_a, nbr_c),
bipartite_edges(nbr_b, nbr_c),
]
rmv_edges = set().union(*(p & self.physical_edges for p in nbr_pairs))
add_edges = set().union(*(p - self.physical_edges for p in nbr_pairs))

self._update_connections(rmv_edges, add_edges)
self._swap(node1, node2)

# update node1 and node2 measurement
measurement_action = {
Plane.XY: (Plane.YZ, lambda v: self.meas_bases[v].angle),
Plane.XZ: (Plane.XZ, lambda v: (0.5 * np.pi - self.meas_bases[v].angle)),
Plane.YZ: (Plane.XY, lambda v: self.meas_bases[v].angle),
}

for a in [node1, node2]:
self._update_node_measurement(measurement_action, a)

# update nodes measurement of nbr_a
measurement_action = {
Plane.XY: (Plane.XY, lambda v: (self.meas_bases[v].angle + np.pi)),
Plane.XZ: (Plane.YZ, lambda v: -self.meas_bases[v].angle),
Plane.YZ: (Plane.XZ, lambda v: -self.meas_bases[v].angle),
}

for w in nbr_a:
self._update_node_measurement(measurement_action, w)


def bipartite_edges(node_set1: set[int], node_set2: set[int]) -> set[tuple[int, int]]:
"""Return a set of edges for the complete bipartite graph between two sets of nodes.

Expand Down
Loading