Skip to content

Commit

Permalink
Fix typing in numpy function arguments (#5657)
Browse files Browse the repository at this point in the history
- Assure mypy that np.isclose receives a concrete scalar argument
- Turn off type check for TParamValComplex used in numpy expressions
  • Loading branch information
pavoljuhas authored Jul 6, 2022
1 parent 3e34f5d commit 7a845a5
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 7 deletions.
10 changes: 5 additions & 5 deletions cirq-core/cirq/ops/global_phase_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""A no-qubit global phase operation."""
from typing import AbstractSet, Any, Dict, Sequence, Tuple, TYPE_CHECKING, Union

from typing import AbstractSet, Any, cast, Dict, Sequence, Tuple, Union

import numpy as np
import sympy

import cirq
from cirq import value, protocols
from cirq.ops import raw_types
from cirq.type_workarounds import NotImplementedType

if TYPE_CHECKING:
import cirq


@value.value_equality(approximate=True)
class GlobalPhaseGate(raw_types.Gate):
Expand Down Expand Up @@ -57,7 +56,8 @@ def _apply_unitary_(
) -> Union[np.ndarray, NotImplementedType]:
if not self._has_unitary_():
return NotImplemented
args.target_tensor *= self.coefficient
assert not cirq.is_parameterized(self)
args.target_tensor *= cast(np.generic, self.coefficient)
return args.target_tensor

def _has_stabilizer_effect_(self) -> bool:
Expand Down
8 changes: 8 additions & 0 deletions cirq-core/cirq/ops/pauli_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import numpy as np
import sympy

import cirq
from cirq import value, protocols, linalg, qis
from cirq._doc import document
from cirq._import import LazyLoader
Expand Down Expand Up @@ -498,9 +499,15 @@ def matrix(self, qubits: Optional[Iterable[TKey]] = None) -> np.ndarray:
in which the matrix representation of the Pauli string is to
be computed. Qubits absent from `self.qubits` are acted on by
the identity. Defaults to `self.qubits`.
Raises:
NotImplementedError: If this PauliString is parameterized.
"""
qubits = self.qubits if qubits is None else qubits
factors = [self.get(q, default=identity.I) for q in qubits]
if cirq.is_parameterized(self):
raise NotImplementedError('Cannot express as matrix when parameterized')
assert isinstance(self.coefficient, complex)
return linalg.kron(self.coefficient, *[protocols.unitary(f) for f in factors])

def _has_unitary_(self) -> bool:
Expand All @@ -516,6 +523,7 @@ def _unitary_(self) -> Optional[np.ndarray]:
def _apply_unitary_(self, args: 'protocols.ApplyUnitaryArgs'):
if not self._has_unitary_():
return None
assert isinstance(self.coefficient, complex)
if self.coefficient != 1:
args.target_tensor *= self.coefficient
return protocols.apply_unitaries([self[q].on(q) for q in self.qubits], self.qubits, args)
Expand Down
2 changes: 2 additions & 0 deletions cirq-core/cirq/ops/pauli_string_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1985,6 +1985,8 @@ def test_parameterization():
pst.expectation_from_state_vector(np.array([]), {})
with pytest.raises(NotImplementedError, match='parameterized'):
pst.expectation_from_density_matrix(np.array([]), {})
with pytest.raises(NotImplementedError, match='as matrix when parameterized'):
pst.matrix()
assert pst**1 == pst
assert pst**-1 == pst.with_coefficient(1.0 / t)
assert (-pst) ** 1 == -pst
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/ops/phased_iswap_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
"""ISWAPPowGate conjugated by tensor product Rz(phi) and Rz(-phi)."""

from typing import AbstractSet, Any, Dict, List, Optional, Sequence, Tuple, Union
from typing import AbstractSet, Any, cast, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import sympy
Expand Down Expand Up @@ -173,7 +173,7 @@ def _pauli_expansion_(self) -> value.LinearDict[str]:
return NotImplemented
expansion = protocols.pauli_expansion(self._iswap)
assert set(expansion.keys()).issubset({'II', 'XX', 'YY', 'ZZ'})
assert np.isclose(expansion['XX'], expansion['YY'])
assert np.isclose(cast(np.generic, expansion['XX']), cast(np.generic, expansion['YY']))

v = (expansion['XX'] + expansion['YY']) / 2
phase_angle = np.pi * self.phase_exponent
Expand Down

0 comments on commit 7a845a5

Please sign in to comment.