From 23976e286cba1e87a69a19c71035a324f36b8e42 Mon Sep 17 00:00:00 2001 From: Dave Bacon Date: Thu, 24 Mar 2022 17:32:41 -0700 Subject: [PATCH] Fix broken caching in CliffordGate and add test (#5142) Looks like copy-pasta errors to me. Also update a test to being parameterized and complete. --- cirq-core/cirq/ops/clifford_gate.py | 12 +++---- cirq-core/cirq/ops/clifford_gate_test.py | 41 +++++++++++++++++------- 2 files changed, 35 insertions(+), 18 deletions(-) diff --git a/cirq-core/cirq/ops/clifford_gate.py b/cirq-core/cirq/ops/clifford_gate.py index 553923b1524..91649fdc1aa 100644 --- a/cirq-core/cirq/ops/clifford_gate.py +++ b/cirq-core/cirq/ops/clifford_gate.py @@ -610,18 +610,18 @@ def I(cls): @property def X(cls): if getattr(cls, '_X', None) is None: - cls._Z = cls._generate_clifford_from_known_gate(1, pauli_gates.X) - return cls._Z + cls._X = cls._generate_clifford_from_known_gate(1, pauli_gates.X) + return cls._X @property def Y(cls): - if getattr(cls, '_X', None) is None: - cls._Z = cls._generate_clifford_from_known_gate(1, pauli_gates.Y) - return cls._Z + if getattr(cls, '_Y', None) is None: + cls._Y = cls._generate_clifford_from_known_gate(1, pauli_gates.Y) + return cls._Y @property def Z(cls): - if getattr(cls, '_X', None) is None: + if getattr(cls, '_Z', None) is None: cls._Z = cls._generate_clifford_from_known_gate(1, pauli_gates.Z) return cls._Z diff --git a/cirq-core/cirq/ops/clifford_gate_test.py b/cirq-core/cirq/ops/clifford_gate_test.py index f430f987d60..1a4424f0435 100644 --- a/cirq-core/cirq/ops/clifford_gate_test.py +++ b/cirq-core/cirq/ops/clifford_gate_test.py @@ -529,18 +529,26 @@ def test_text_diagram_info(gate, sym, exp): ) -def test_from_unitary(): - def _test(clifford_gate): - u = cirq.unitary(clifford_gate) - result_gate = cirq.SingleQubitCliffordGate.from_unitary(u) - assert result_gate == clifford_gate - - _test(cirq.SingleQubitCliffordGate.I) - _test(cirq.SingleQubitCliffordGate.H) - _test(cirq.SingleQubitCliffordGate.X) - _test(cirq.SingleQubitCliffordGate.Y) - _test(cirq.SingleQubitCliffordGate.Z) - _test(cirq.SingleQubitCliffordGate.X_nsqrt) +@pytest.mark.parametrize( + "clifford_gate", + ( + cirq.SingleQubitCliffordGate.I, + cirq.SingleQubitCliffordGate.H, + cirq.SingleQubitCliffordGate.X, + cirq.SingleQubitCliffordGate.Y, + cirq.SingleQubitCliffordGate.Z, + cirq.SingleQubitCliffordGate.X_sqrt, + cirq.SingleQubitCliffordGate.Y_sqrt, + cirq.SingleQubitCliffordGate.Z_sqrt, + cirq.SingleQubitCliffordGate.X_nsqrt, + cirq.SingleQubitCliffordGate.Y_nsqrt, + cirq.SingleQubitCliffordGate.Z_nsqrt, + ), +) +def test_from_unitary(clifford_gate): + u = cirq.unitary(clifford_gate) + result_gate = cirq.SingleQubitCliffordGate.from_unitary(u) + assert result_gate == clifford_gate def test_from_unitary_with_phase_shift(): @@ -612,6 +620,15 @@ def test_common_clifford_gate(clifford_gate, standard_gate): cirq.testing.assert_allclose_up_to_global_phase(u_c, u_s, atol=1e-8) +@pytest.mark.parametrize('clifford_gate_name', ("I", "X", "Y", "Z", "H", "S", "CNOT", "CZ", "SWAP")) +def test_common_clifford_gate_caching(clifford_gate_name): + cache_name = f"_{clifford_gate_name}" + delattr(cirq.CliffordGate, cache_name) + assert not hasattr(cirq.CliffordGate, cache_name) + _ = getattr(cirq.CliffordGate, clifford_gate_name) + assert hasattr(cirq.CliffordGate, cache_name) + + def test_multi_qubit_clifford_pow(): assert cirq.CliffordGate.X ** -1 == cirq.CliffordGate.X assert cirq.CliffordGate.H ** -1 == cirq.CliffordGate.H