Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Performance

- **Pauli Frame**: Optimized `_collect_dependent_chain` method with memoization and caching
- Added Pauli axis cache to avoid redundant basis computations
- Implemented chain memoization cache to prevent recalculating dependent chains
- Optimized set operations for better performance in large graph states

### Tests

- **Pauli Frame**: Added comprehensive test suite for PauliFrame module
- Added tests for basic methods (x_flip, z_flip, meas_flip, children, parents)
- Added tests for Pauli axis cache initialization and chain cache memoization
- Added tests for dependent chain collection across X, Y, Z measurement axes
- Added tests for detector groups and logical observables
- Improved test coverage from 77.78% to 97% for pauli_frame.py

## [0.1.1] - 2025-10-23

### Added
Expand Down
46 changes: 35 additions & 11 deletions graphqomb/pauli_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ class PauliFrame:
Current Z Pauli state for each node
parity_check_group : `list`\[`set`\[`int`\]\]
Parity check group for FTQC
inv_xflow : `dict`\[`int`, `int`\]
inv_xflow : `dict`\[`int`, `set`\[`int`\]\]
Inverse X correction flow for each measurement flip
inv_zflow : `dict`\[`int`, `int`\]
inv_zflow : `dict`\[`int`, `set`\[`int`\]\]
Inverse Z correction flow for each measurement flip
"""

Expand All @@ -50,6 +50,8 @@ class PauliFrame:
parity_check_group: list[set[int]]
inv_xflow: dict[int, set[int]]
inv_zflow: dict[int, set[int]]
_pauli_axis_cache: dict[int, Axis | None]
_chain_cache: dict[int, frozenset[int]]

def __init__(
self,
Expand Down Expand Up @@ -78,6 +80,15 @@ def __init__(
self.inv_zflow[target].add(node)
self.inv_zflow[node] -= {node}

# Pre-compute Pauli axes for performance optimization
# Only cache nodes that have measurement bases
# NOTE: if non-Pauli measurements are involved, the stim_compile func will error out earlier
self._pauli_axis_cache = {
node: determine_pauli_axis(meas_basis) for node, meas_basis in graphstate.meas_bases.items()
}
# Cache for memoization of dependent chains
self._chain_cache = {}

def x_flip(self, node: int) -> None:
"""Flip the X Pauli mask for the given node.

Expand Down Expand Up @@ -196,31 +207,44 @@ def _collect_dependent_chain(self, node: int) -> set[int]:
ValueError
If an unexpected output basis or measurement plane is encountered.
"""
# Check memoization cache
if node in self._chain_cache:
return set(self._chain_cache[node])

chain: set[int] = set()
untracked = {node}
tracked: set[int] = set()

while untracked:
current = untracked.pop()
chain ^= {current}

parents: set[int] = set()
# Optimized XOR operation: toggle membership
if current in chain:
chain.remove(current)
else:
chain.add(current)

# Use pre-computed Pauli axis from cache
axis = self._pauli_axis_cache[current]

# NOTE: might have to support plane instead of axis
axis = determine_pauli_axis(self.graphstate.meas_bases[current])
if axis == Axis.X:
parents = self.inv_zflow.get(current, set())
# Use defaultdict direct access (no need for .get with default)
parents = self.inv_zflow[current]
elif axis == Axis.Y:
parents = self.inv_xflow.get(current, set()) ^ self.inv_zflow.get(current, set())
# Optimized symmetric difference for Y axis
parents = self.inv_xflow[current].symmetric_difference(self.inv_zflow[current])
elif axis == Axis.Z:
parents = self.inv_xflow.get(current, set())
parents = self.inv_xflow[current]
else:
msg = f"Unexpected measurement axis: {axis}"
raise ValueError(msg)

for p in parents:
if p not in tracked:
untracked.add(p)
# Add untracked parents in bulk
untracked.update(p for p in parents if p not in tracked)
tracked.add(current)

# Store result in cache for future calls
self._chain_cache[node] = frozenset(chain)

return chain
Loading