Skip to content

Commit 26c755c

Browse files
authored
Merge pull request #102 from qutech/feature/small_improvements
Small improvements
2 parents 5a9219d + 06e672a commit 26c755c

File tree

9 files changed

+217
-188
lines changed

9 files changed

+217
-188
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ infidelity = ff.infidelity(hadamard, spectrum, omega)
6363
## Installation
6464
To install the package from PyPI, run `pip install filter_functions`. If you require the optional features provided by QuTiP (visualizing Bloch sphere trajectories), it is recommended to install QuTiP before by following the [instructions on their website](http://qutip.org/docs/latest/installation.html) rather than installing it through `pip`. To install the package from source run `python setup.py develop` to install using symlinks or `python setup.py install` without.
6565

66-
To install dependencies of optional extras (`ipynbname` for a fancy progress bar in Jupyter notebooks, `matplotlib` for plotting, `QuTiP` for Bloch sphere visualization), run `pip install -e .[extra]` where `extra` is one or more of `fancy_progressbar`, `plotting`, `bloch_sphere_visualization` from the root directory. To install all dependencies, including those needed to build the documentation and run the tests, use the extra `all`.
66+
To install dependencies of optional extras (`matplotlib` for plotting, `QuTiP` for Bloch sphere visualization), run `pip install -e .[extra]` where `extra` is one or more of `plotting`, `bloch_sphere_visualization` from the root directory. To install all dependencies, including those needed to build the documentation and run the tests, use the extra `all`.
6767

6868
## Documentation
6969
You can find the documentation on [Readthedocs](https://filter-functions.readthedocs.io/en/latest/). It is built from Jupyter notebooks that can also be run interactively and are located [here](doc/source/examples). The notebooks explain how to use the package and thus make sense to follow chronologically as a first step. Furthermore, there are also a few example scripts in the [examples](examples) folder.

filter_functions/basis.py

Lines changed: 116 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
Gell-Mann basis
4040
4141
"""
42-
42+
from functools import cached_property
4343
from itertools import product
4444
from typing import Optional, Sequence, Union
4545
from warnings import warn
@@ -197,12 +197,6 @@ def __array_finalize__(self, basis: 'Basis') -> None:
197197
self.btype = getattr(basis, 'btype', 'Custom')
198198
self.labels = getattr(basis, 'labels', [f'$C_{{{i}}}$' for i in range(len(basis))])
199199
self.d = getattr(basis, 'd', basis.shape[-1])
200-
self._sparse = None
201-
self._four_element_traces = None
202-
self._isherm = None
203-
self._isorthonorm = None
204-
self._istraceless = None
205-
self._iscomplete = None
206200
self._eps = np.finfo(complex).eps
207201
self._atol = self._eps*self.d**3
208202
self._rtol = 0
@@ -243,78 +237,77 @@ def _print_checks(self) -> None:
243237
for check in checks:
244238
print(check, ':\t', getattr(self, check))
245239

246-
@property
240+
def _invalidate_cached_properties(self):
241+
for attr in {'isherm', 'isnorm', 'isorthogonal', 'istraceless', 'iscomplete'}:
242+
try:
243+
delattr(self, attr)
244+
except AttributeError:
245+
pass
246+
247+
@cached_property
247248
def isherm(self) -> bool:
248249
"""Returns True if all basis elements are hermitian."""
249-
if self._isherm is None:
250-
self._isherm = (self.H == self)
251-
252-
return self._isherm
250+
return self.H == self
251+
252+
@cached_property
253+
def isnorm(self) -> bool:
254+
"""Returns True if all basis elements are normalized."""
255+
return self.normalize(copy=True) == self
256+
257+
@cached_property
258+
def isorthogonal(self) -> bool:
259+
"""Returns True if all basis elements are mutually orthogonal."""
260+
if self.ndim == 2 or len(self) == 1:
261+
return True
262+
263+
# The basis is orthogonal iff the matrix consisting of all d**2
264+
# elements written as d**2-dimensional column vectors is
265+
# orthogonal.
266+
dim = self.shape[0]
267+
U = self.reshape((dim, -1))
268+
actual = U.conj() @ U.T
269+
atol = self._eps*(self.d**2)**3
270+
mask = np.identity(dim, dtype=bool)
271+
return np.allclose(actual[..., ~mask].view(np.ndarray), 0, atol=atol, rtol=self._rtol)
253272

254273
@property
255274
def isorthonorm(self) -> bool:
256275
"""Returns True if basis is orthonormal."""
257-
if self._isorthonorm is None:
258-
# All the basis is orthonormal iff the matrix consisting of all
259-
# d**2 elements written as d**2-dimensional column vectors is
260-
# unitary.
261-
if self.ndim == 2 or len(self) == 1:
262-
# Only one basis element
263-
self._isorthonorm = True
264-
else:
265-
# Size of the result after multiplication
266-
dim = self.shape[0]
267-
U = self.reshape((dim, -1))
268-
actual = U.conj() @ U.T
269-
target = np.identity(dim)
270-
atol = self._eps*(self.d**2)**3
271-
self._isorthonorm = np.allclose(actual.view(np.ndarray), target,
272-
atol=atol, rtol=self._rtol)
276+
return self.isorthogonal and self.isnorm
273277

274-
return self._isorthonorm
275-
276-
@property
278+
@cached_property
277279
def istraceless(self) -> bool:
278280
"""
279281
Returns True if basis is traceless except for possibly the identity.
280282
"""
281-
if self._istraceless is None:
282-
trace = np.einsum('...jj', self)
283-
trace = util.remove_float_errors(trace, self.d**2)
284-
nonzero = np.atleast_1d(trace).nonzero()
285-
if nonzero[0].size == 0:
286-
self._istraceless = True
287-
elif nonzero[0].size == 1:
288-
# Single element has nonzero trace, check if (proportional to)
289-
# identity
290-
if self.ndim == 3:
291-
elem = self[nonzero][0].view(np.ndarray)
292-
else:
293-
elem = self.view(np.ndarray)
294-
offdiag_nonzero = elem[~np.eye(self.d, dtype=bool)].nonzero()
295-
diag_equal = np.diag(elem) == elem[0, 0]
296-
if diag_equal.all() and not offdiag_nonzero[0].any():
297-
# Element is (proportional to) the identity, this we define
298-
# as 'traceless' since a complete basis cannot have only
299-
# traceless elems.
300-
self._istraceless = True
301-
else:
302-
# Element not the identity, therefore not traceless
303-
self._istraceless = False
283+
trace = np.einsum('...jj', self)
284+
trace = util.remove_float_errors(trace, self.d**2)
285+
nonzero = np.atleast_1d(trace).nonzero()
286+
if nonzero[0].size == 0:
287+
return True
288+
elif nonzero[0].size == 1:
289+
# Single element has nonzero trace, check if (proportional to)
290+
# identity
291+
elem = self[nonzero][0].view(np.ndarray) if self.ndim == 3 else self.view(np.ndarray)
292+
offdiag_nonzero = elem[~np.eye(self.d, dtype=bool)].nonzero()
293+
diag_equal = np.diag(elem) == elem[0, 0]
294+
if diag_equal.all() and not offdiag_nonzero[0].any():
295+
# Element is (proportional to) the identity, this we define
296+
# as 'traceless' since a complete basis cannot have only
297+
# traceless elems.
298+
return True
304299
else:
305-
self._istraceless = False
306-
307-
return self._istraceless
300+
# Element not the identity, therefore not traceless
301+
return False
302+
else:
303+
return False
308304

309-
@property
305+
@cached_property
310306
def iscomplete(self) -> bool:
311307
"""Returns True if basis is complete."""
312-
if self._iscomplete is None:
313-
A = self.reshape(self.shape[0], -1)
314-
rank = np.linalg.matrix_rank(A)
315-
self._iscomplete = rank == self.d**2
316-
317-
return self._iscomplete
308+
A = self.reshape(self.shape[0], -1)
309+
rank = np.linalg.matrix_rank(A)
310+
return rank == self.d**2
318311

319312
@property
320313
def H(self) -> 'Basis':
@@ -329,48 +322,61 @@ def T(self) -> 'Basis':
329322

330323
return self
331324

332-
@property
325+
@cached_property
333326
def sparse(self) -> COO:
334327
"""Return the basis as a sparse COO array"""
335-
if self._sparse is None:
336-
self._sparse = COO.from_numpy(self)
337-
338-
return self._sparse
328+
return COO.from_numpy(self)
339329

340-
@property
330+
@cached_property
341331
def four_element_traces(self) -> COO:
342332
r"""
343333
Return all traces of the form
344334
:math:`\mathrm{tr}(C_i C_j C_k C_l)` as a sparse COO array for
345335
:math:`i,j,k,l > 0` (i.e. excluding the identity).
346336
"""
347-
if self._four_element_traces is None:
348-
# Most of the traces are zero, therefore store the result in a
349-
# sparse array. For GGM bases, which are inherently sparse, it
350-
# makes sense for any dimension to also calculate with sparse
351-
# arrays. For Pauli bases, which are very dense, this is not so
352-
# efficient but unavoidable for d > 12.
353-
path = [(0, 1), (0, 1), (0, 1)]
354-
if self.btype == 'Pauli' and self.d <= 12:
355-
# For d == 12, the result is ~270 MB.
356-
self._four_element_traces = COO.from_numpy(oe.contract('iab,jbc,kcd,lda->ijkl',
357-
*(self,)*4, optimize=path))
358-
else:
359-
self._four_element_traces = oe.contract('iab,jbc,kcd,lda->ijkl', *(self.sparse,)*4,
360-
backend='sparse', optimize=path)
337+
# Most of the traces are zero, therefore store the result in a
338+
# sparse array. For GGM bases, which are inherently sparse, it
339+
# makes sense for any dimension to also calculate with sparse
340+
# arrays. For Pauli bases, which are very dense, this is not so
341+
# efficient but unavoidable for d > 12.
342+
path = [(0, 1), (0, 1), (0, 1)]
343+
if self.btype == 'Pauli' and self.d <= 12:
344+
# For d == 12, the result is ~270 MB.
345+
return COO.from_numpy(oe.contract('iab,jbc,kcd,lda->ijkl', *(self,)*4, optimize=path))
346+
else:
347+
return oe.contract('iab,jbc,kcd,lda->ijkl', *(self.sparse,)*4, backend='sparse',
348+
optimize=path)
361349

362-
return self._four_element_traces
350+
def expand(self, M: np.ndarray, hermitian: bool = False, traceless: bool = False,
351+
tidyup: bool = False) -> np.ndarray:
352+
"""Expand matrices M in this basis.
363353
364-
@four_element_traces.setter
365-
def four_element_traces(self, traces):
366-
self._four_element_traces = traces
354+
Parameters
355+
----------
356+
M: array_like
357+
The square matrix (d, d) or array of square matrices (..., d, d)
358+
to be expanded in *basis*
359+
hermitian: bool (default: False)
360+
If M is hermitian along its last two axes, the result will be
361+
real.
362+
tidyup: bool {False}
363+
Whether to set values below the floating point eps to zero.
364+
365+
See Also
366+
--------
367+
expand : The function corresponding to this method.
368+
"""
369+
if self.btype == 'GGM' and self.iscomplete:
370+
return ggm_expand(M, traceless, hermitian, tidyup)
371+
return expand(M, self, self.isnorm, hermitian, tidyup)
367372

368373
def normalize(self, copy: bool = False) -> Union[None, 'Basis']:
369374
"""Normalize the basis."""
370375
if copy:
371376
return normalize(self)
372377

373378
self /= _norm(self)
379+
self._invalidate_cached_properties()
374380

375381
def tidyup(self, eps_scale: Optional[float] = None) -> None:
376382
"""Wraps util.remove_float_errors."""
@@ -382,6 +388,8 @@ def tidyup(self, eps_scale: Optional[float] = None) -> None:
382388
self.real[np.abs(self.real) <= atol] = 0
383389
self.imag[np.abs(self.imag) <= atol] = 0
384390

391+
self._invalidate_cached_properties()
392+
385393
@classmethod
386394
def pauli(cls, n: int) -> 'Basis':
387395
r"""
@@ -549,15 +557,14 @@ def _full_from_partial(elems: Sequence, traceless: bool, labels: Sequence[str])
549557
if not elems.isherm:
550558
warn("(Some) elems not hermitian! The resulting basis also won't be.")
551559

552-
if not elems.isorthonorm:
553-
raise ValueError("The basis elements are not orthonormal!")
560+
if not elems.isorthogonal:
561+
raise ValueError("The basis elements are not orthogonal!")
554562

555563
if traceless is None:
556564
traceless = elems.istraceless
557-
else:
558-
if traceless and not elems.istraceless:
559-
raise ValueError("The basis elements are not traceless (up to an identity element) "
560-
+ "but a traceless basis was requested!")
565+
elif traceless and not elems.istraceless:
566+
raise ValueError("The basis elements are not traceless (up to an identity element) "
567+
+ "but a traceless basis was requested!")
561568

562569
if labels is not None and len(labels) not in (len(elems), elems.d**2):
563570
raise ValueError(f'Got {len(labels)} labels but expected {len(elems)} or {elems.d**2}')
@@ -566,12 +573,13 @@ def _full_from_partial(elems: Sequence, traceless: bool, labels: Sequence[str])
566573
# properties hermiticity and orthonormality, and therefore also linear
567574
# combinations, ie basis expansions, of it will). Split off the identity so
568575
# that for traceless bases we can put it in the front.
576+
ggm = Basis.ggm(elems.d)
577+
coeffs = ggm.expand(elems, traceless=traceless, hermitian=elems.isherm, tidyup=True)
578+
569579
if traceless:
570-
Id, ggm = np.split(Basis.ggm(elems.d), [1])
571-
else:
572-
ggm = Basis.ggm(elems.d)
580+
Id, ggm = np.split(ggm, [1])
581+
coeffs = coeffs[..., 1:]
573582

574-
coeffs = expand(elems, ggm, hermitian=elems.isherm, tidyup=True)
575583
# Throw out coefficient vectors that are all zero (should only happen for
576584
# the identity)
577585
coeffs = coeffs[(coeffs != 0).any(axis=-1)]
@@ -636,7 +644,7 @@ def normalize(b: Basis) -> Basis:
636644
Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15
637645
638646
"""
639-
return (b/_norm(b)).squeeze().view(Basis)
647+
return (b/_norm(b)).squeeze().reshape(b.shape).view(Basis)
640648

641649

642650
def expand(M: Union[np.ndarray, Basis], basis: Union[np.ndarray, Basis],
@@ -691,7 +699,7 @@ def cast(arr):
691699

692700

693701
def ggm_expand(M: Union[np.ndarray, Basis], traceless: bool = False,
694-
hermitian: bool = False) -> np.ndarray:
702+
hermitian: bool = False, tidyup: bool = False) -> np.ndarray:
695703
r"""
696704
Expand the matrix *M* in a Generalized Gell-Mann basis [Bert08]_.
697705
This function makes use of the explicit construction prescription of
@@ -712,6 +720,8 @@ def ggm_expand(M: Union[np.ndarray, Basis], traceless: bool = False,
712720
hermitian: bool (default: False)
713721
If M is hermitian along its last two axes, the result will be
714722
real.
723+
tidyup: bool {False}
724+
Whether to set values below the floating point eps to zero.
715725
716726
Returns
717727
-------
@@ -759,7 +769,7 @@ def cast(arr):
759769
coeffs = np.zeros((*M.shape[:-2], d**2), dtype=float if hermitian else complex)
760770
if not traceless:
761771
# First element is proportional to the trace of M
762-
coeffs[..., 0] = cast(np.einsum('...jj', M))/np.sqrt(d)
772+
coeffs[..., 0] = cast(M.trace(0, -1, -2))/np.sqrt(d)
763773

764774
# Elements proportional to the symmetric GGMs
765775
coeffs[..., sym_rng] = cast(M[triu_idx] + M[tril_idx])/np.sqrt(2)
@@ -770,7 +780,11 @@ def cast(arr):
770780
- diag_rng*M[diag_idx_shifted])
771781
coeffs[..., diag_rng + 2*n_sym] /= cast(np.sqrt(diag_rng*(diag_rng + 1)))
772782

773-
return coeffs.squeeze() if square else coeffs
783+
if square:
784+
coeffs = coeffs.squeeze()
785+
if tidyup:
786+
coeffs = util.remove_float_errors(coeffs)
787+
return coeffs
774788

775789

776790
def equivalent_pauli_basis_elements(idx: Union[Sequence[int], int], N: int) -> np.ndarray:

0 commit comments

Comments
 (0)