Skip to content

Commit e24163d

Browse files
authored
Merge pull request #85 from qutech/feature/small_improvements
Small improvements
2 parents 2ebe601 + 7e724c2 commit e24163d

17 files changed

+408
-269
lines changed

filter_functions/basis.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -553,8 +553,8 @@ def _full_from_partial(elems: Sequence, traceless: bool, labels: Sequence[str])
553553
traceless = elems.istraceless
554554
else:
555555
if traceless and not elems.istraceless:
556-
raise ValueError("The basis elements are not traceless (up to an identity element) " +
557-
"but a traceless basis was requested!")
556+
raise ValueError("The basis elements are not traceless (up to an identity element) "
557+
+ "but a traceless basis was requested!")
558558

559559
if labels is not None and len(labels) not in (len(elems), elems.d**2):
560560
raise ValueError(f'Got {len(labels)} labels but expected {len(elems)} or {elems.d**2}')
@@ -677,7 +677,8 @@ def expand(M: Union[ndarray, Basis], basis: Union[ndarray, Basis],
677677
678678
"""
679679

680-
def cast(arr): return arr.real if hermitian and basis.isherm else arr
680+
def cast(arr):
681+
return arr.real if hermitian and basis.isherm else arr
681682

682683
coefficients = cast(np.tensordot(M, basis, axes=[(-2, -1), (-1, -2)]))
683684
if not normalized:
@@ -724,7 +725,8 @@ def ggm_expand(M: Union[ndarray, Basis], traceless: bool = False,
724725
if M.shape[-1] != M.shape[-2]:
725726
raise ValueError('M should be square in its last two axes')
726727

727-
def cast(arr): return arr.real if hermitian else arr
728+
def cast(arr):
729+
return arr.real if hermitian else arr
728730

729731
# Squeeze out an extra dimension to be shape agnostic
730732
square = M.ndim < 3

filter_functions/gradient.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@
5555

5656
import numpy as np
5757
import opt_einsum as oe
58-
from opt_einsum.contract import ContractExpression
5958
from numpy import ndarray
59+
from opt_einsum.contract import ContractExpression
6060

6161
from . import numeric, superoperator, util
6262
from .basis import Basis
@@ -656,13 +656,13 @@ def infidelity_derivative(
656656
Section A: General, Atomic and Solid State Physics, 303(4), 249–252.
657657
https://doi.org/10.1016/S0375-9601(02)01272-0
658658
"""
659-
spectrum = numeric._parse_spectrum(spectrum, omega, range(len(pulse.n_opers)))
659+
spectrum = util.parse_spectrum(spectrum, omega, range(len(pulse.n_opers)))
660660
filter_function_deriv = pulse.get_filter_function_derivative(omega,
661661
control_identifiers,
662662
n_oper_identifiers,
663663
n_coeffs_deriv)
664664

665-
integrand = np.einsum('ao,atho->atho', spectrum, filter_function_deriv)
665+
integrand = np.einsum('...o,...tho->...tho', spectrum, filter_function_deriv)
666666
infid_deriv = util.integrate(integrand, omega) / (2*np.pi*pulse.d)
667667

668668
return infid_deriv

filter_functions/numeric.py

Lines changed: 22 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -249,24 +249,6 @@ def _second_order_integral(E: ndarray, eigvals: ndarray, dt: float,
249249
return int_buf
250250

251251

252-
def _parse_spectrum(spectrum: Sequence, omega: Sequence, idx: Sequence) -> ndarray:
253-
spectrum = np.asanyarray(spectrum)
254-
error = 'Spectrum should be of shape {}, not {}.'
255-
shape = (len(idx),)*(spectrum.ndim - 1) + (len(omega),)
256-
if spectrum.shape != shape and spectrum.ndim <= 3:
257-
raise ValueError(error.format(shape, spectrum.shape))
258-
259-
if spectrum.ndim == 1:
260-
# As we broadcast over the noise operators
261-
spectrum = spectrum[None, ...]
262-
if spectrum.ndim == 3 and not np.allclose(spectrum, spectrum.conj().swapaxes(0, 1)):
263-
raise ValueError('Cross-spectra given but not Hermitian along first two axes')
264-
elif spectrum.ndim > 3:
265-
raise ValueError(f'Expected spectrum to have < 4 dimensions, not {spectrum.ndim}')
266-
267-
return spectrum
268-
269-
270252
def _get_integrand(
271253
spectrum: ndarray,
272254
omega: ndarray,
@@ -330,7 +312,7 @@ def _get_integrand(
330312
# Everything simpler if noise operators always on 2nd-to-last axes
331313
filter_function = np.moveaxis(filter_function, source=[-5, -4], destination=[-3, -2])
332314

333-
spectrum = _parse_spectrum(spectrum, omega, idx)
315+
spectrum = util.parse_spectrum(spectrum, omega, idx)
334316
if spectrum.ndim in (1, 2):
335317
if filter_function is not None:
336318
integrand = (filter_function[..., tuple(idx), tuple(idx), :]*spectrum)
@@ -342,17 +324,17 @@ def _get_integrand(
342324
# R is not None
343325
if which_pulse == 'correlations':
344326
if which_FF == 'fidelity':
345-
einsum_str = 'gako,ao,hako->ghao'
327+
einsum_str = 'g...ko,...o,h...ko->gh...o'
346328
else:
347329
# which_FF == 'generalized'
348-
einsum_str = 'gako,ao,halo->ghaklo'
330+
einsum_str = 'g...ko,...o,h...lo->gh...klo'
349331
else:
350332
# which_pulse == 'total'
351333
if which_FF == 'fidelity':
352-
einsum_str = 'ako,ao,ako->ao'
334+
einsum_str = '...ko,...o,...ko->...o'
353335
else:
354336
# which_FF == 'generalized'
355-
einsum_str = 'ako,ao,alo->aklo'
337+
einsum_str = '...ko,...o,...lo->...klo'
356338

357339
integrand = np.einsum(einsum_str,
358340
ctrl_left[..., idx, :, :], spectrum, ctrl_right[..., idx, :, :])
@@ -699,7 +681,8 @@ def calculate_control_matrix_from_atomic(
699681
control_matrix = np.zeros(control_matrix_atomic.shape, dtype=complex)
700682
for g in util.progressbar_range(n, show_progressbar=show_progressbar,
701683
desc='Calculating control matrix'):
702-
control_matrix[g] = expr(phases[g]*control_matrix_atomic[g], propagators_liouville[g])
684+
control_matrix[g] = expr(phases[g]*control_matrix_atomic[g], propagators_liouville[g],
685+
out=control_matrix[g])
703686

704687
return control_matrix
705688

@@ -1077,8 +1060,8 @@ def calculate_cumulant_function(
10771060
N, d = pulse.basis.shape[:2]
10781061
if spectrum is None and omega is None:
10791062
if decay_amplitudes is None or (frequency_shifts is None and second_order):
1080-
raise ValueError('Require either spectrum and frequencies or precomputed ' +
1081-
'decay amplitudes (frequency shifts)')
1063+
raise ValueError('Require either spectrum and frequencies or precomputed '
1064+
+ 'decay amplitudes (frequency shifts)')
10821065

10831066
if which == 'correlations' and second_order:
10841067
raise ValueError('Cannot compute correlation cumulant function for second order terms')
@@ -1251,8 +1234,8 @@ def calculate_decay_amplitudes(
12511234
# which == 'correlations'
12521235
if pulse.is_cached('omega'):
12531236
if not np.array_equal(pulse.omega, omega):
1254-
raise ValueError('Pulse correlation decay amplitudes requested but omega not ' +
1255-
'equal to cached frequencies.')
1237+
raise ValueError('Pulse correlation decay amplitudes requested but omega not '
1238+
+ 'equal to cached frequencies.')
12561239

12571240
if pulse.is_cached('filter_function_pc_gen'):
12581241
control_matrix = None
@@ -1819,8 +1802,8 @@ def error_transfer_matrix(
18191802
"""
18201803
if cumulant_function is None:
18211804
if pulse is None or spectrum is None or omega is None:
1822-
raise ValueError('Require either precomputed cumulant function ' +
1823-
'or pulse, spectrum, and omega as arguments.')
1805+
raise ValueError('Require either precomputed cumulant function '
1806+
+ 'or pulse, spectrum, and omega as arguments.')
18241807

18251808
cumulant_function = calculate_cumulant_function(pulse, spectrum, omega,
18261809
n_oper_identifiers, 'total', second_order,
@@ -2020,8 +2003,8 @@ def infidelity(
20202003
try:
20212004
omega_IR = omega.get('omega_IR', 2*np.pi/pulse.tau*1e-2)
20222005
except AttributeError:
2023-
raise TypeError('omega should be dictionary with parameters ' +
2024-
'when test_convergence == True.')
2006+
raise TypeError('omega should be dictionary with parameters '
2007+
+ 'when test_convergence == True.')
20252008

20262009
omega_UV = omega.get('omega_UV', 2*np.pi/pulse.tau*1e+2)
20272010
spacing = omega.get('spacing', 'linear')
@@ -2058,8 +2041,8 @@ def infidelity(
20582041
# but trace tensor plays a role, cf eq. (39). For traceless bases,
20592042
# the trace tensor term reduces to delta_ij.
20602043
traces = pulse.basis.four_element_traces
2061-
traces_diag = (sparse.diagonal(traces, axis1=2, axis2=3).sum(-1) -
2062-
sparse.diagonal(traces, axis1=1, axis2=3).sum(-1)).todense()
2044+
traces_diag = (sparse.diagonal(traces, axis1=2, axis2=3).sum(-1)
2045+
- sparse.diagonal(traces, axis1=1, axis2=3).sum(-1)).todense()
20632046

20642047
control_matrix = pulse.get_control_matrix(omega, show_progressbar, cache_intermediates)
20652048
filter_function = np.einsum('ako,blo,kl->abo',
@@ -2070,14 +2053,9 @@ def infidelity(
20702053
cache_intermediates=cache_intermediates)
20712054
else:
20722055
# which == 'correlations'
2073-
if not pulse.basis.istraceless:
2074-
warn('Calculating pulse correlation fidelities with non-' +
2075-
'traceless basis. The results will be off.')
2076-
2077-
if pulse.is_cached('omega'):
2078-
if not np.array_equal(pulse.omega, omega):
2079-
raise ValueError('Pulse correlation infidelities requested ' +
2080-
'but omega not equal to cached frequencies.')
2056+
if pulse.is_cached('omega') and not np.array_equal(pulse.omega, omega):
2057+
raise ValueError('Pulse correlation infidelities requested '
2058+
+ 'but omega not equal to cached frequencies.')
20812059

20822060
filter_function = pulse.get_pulse_correlation_filter_function()
20832061

@@ -2087,8 +2065,8 @@ def infidelity(
20872065

20882066
if return_smallness:
20892067
if spectrum.ndim > 2:
2090-
raise NotImplementedError('Smallness parameter only implemented ' +
2091-
'for uncorrelated noise sources')
2068+
raise NotImplementedError('Smallness parameter only implemented '
2069+
+ 'for uncorrelated noise sources')
20922070

20932071
T1 = util.integrate(spectrum, omega)/(2*np.pi)
20942072
T2 = (pulse.dt*pulse.n_coeffs[idx]).sum(axis=-1)**2

filter_functions/plotting.py

Lines changed: 24 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@
5555
from numpy import ndarray
5656

5757
from . import numeric, util
58-
from .types import (Axes, Coefficients, Colormap, Figure, FigureAxes, FigureAxesLegend, FigureGrid,
59-
Grid, Operator, State)
58+
from .types import (Axes, Coefficients, Colormap, Cycler, Figure, FigureAxes, FigureAxesLegend,
59+
FigureGrid, Grid, Operator, State)
6060

6161
__all__ = ['plot_cumulant_function', 'plot_infidelity_convergence', 'plot_filter_function',
6262
'plot_pulse_correlation_filter_function', 'plot_pulse_train']
@@ -129,9 +129,7 @@ def init_bloch_sphere(**bloch_kwargs) -> qt.Bloch:
129129
return b
130130

131131

132-
@util.parse_optional_parameters(prop=('total', 'piecewise'))
133-
def get_states_from_prop(U: Sequence[Operator], psi0: Optional[State] = None,
134-
prop: str = 'total') -> ndarray:
132+
def get_states_from_prop(U: Sequence[Operator], psi0: Optional[State] = None) -> ndarray:
135133
r"""
136134
Get the the quantum state at time t from the propagator and the
137135
inital state:
@@ -140,31 +138,18 @@ def get_states_from_prop(U: Sequence[Operator], psi0: Optional[State] = None,
140138
141139
|\psi(t)\rangle = U(t, 0)|\psi(0)\rangle
142140
143-
If *prop* is 'piecewise', then it is assumed that *U* is the
144-
propagator of a piecewise-constant control:
145-
146-
.. math::
147-
|\psi(t)\rangle = \prod_{l=1}^n U(t_l, t_{l-1})|\psi(0)\rangle
148-
149-
with :math:`t_0\equiv 0` and :math:`t_n\equiv t`.
150-
151141
"""
152142
if psi0 is None:
153-
psi0 = np.c_[1:-1:-1] # |0>
154-
155-
psi0 = psi0.full() if hasattr(psi0, 'full') else psi0 # qutip.Qobj
156-
d = max(psi0.shape)
157-
states = np.empty((len(U), d, 1), dtype=complex)
158-
if prop == 'total':
159-
for j in range(len(U)):
160-
states[j] = U[j] @ psi0
161-
else:
162-
# prop == 'piecewise'
163-
states[0] = U[0] @ psi0
164-
for j in range(1, len(U)):
165-
states[j] = U[j] @ states[j-1]
143+
# default to |0>
144+
psi0 = np.c_[1:-1:-1]
145+
elif hasattr(psi0, 'full'):
146+
# qutip.Qobj
147+
psi0 = psi0.full()
148+
149+
if psi0.shape[-2:] != (2, 1):
150+
raise ValueError('Initial state should be shape (..., 2, 1)')
166151

167-
return states
152+
return U @ psi0
168153

169154

170155
def plot_bloch_vector_evolution(
@@ -289,7 +274,7 @@ def plot_pulse_train(
289274
c_oper_identifiers: Optional[Sequence[int]] = None,
290275
fig: Optional[Figure] = None,
291276
axes: Optional[Axes] = None,
292-
cycler: Optional['cycler.Cycler'] = None,
277+
cycler: Optional[Cycler] = None,
293278
plot_kw: Optional[dict] = {},
294279
subplot_kw: Optional[dict] = None,
295280
gridspec_kw: Optional[dict] = None,
@@ -380,7 +365,7 @@ def plot_filter_function(
380365
xscale: str = 'log',
381366
yscale: str = 'linear',
382367
omega_in_units_of_tau: bool = True,
383-
cycler: Optional['cycler.Cycler'] = None,
368+
cycler: Optional[Cycler] = None,
384369
plot_kw: dict = {},
385370
subplot_kw: Optional[dict] = None,
386371
gridspec_kw: Optional[dict] = None,
@@ -510,7 +495,7 @@ def plot_pulse_correlation_filter_function(
510495
xscale: str = 'log',
511496
yscale: str = 'linear',
512497
omega_in_units_of_tau: bool = True,
513-
cycler: Optional['cycler.Cycler'] = None,
498+
cycler: Optional[Cycler] = None,
514499
plot_kw: dict = {},
515500
subplot_kw: Optional[dict] = None,
516501
gridspec_kw: Optional[dict] = None,
@@ -731,7 +716,7 @@ def plot_cumulant_function(
731716
732717
Parameters
733718
----------
734-
pulse: 'PulseSequence'
719+
pulse: PulseSequence
735720
The pulse sequence.
736721
spectrum: ndarray
737722
The two-sided noise spectrum.
@@ -801,13 +786,15 @@ def plot_cumulant_function(
801786
n_oper_identifiers = [f'$B_{{{i}}}$' for i in range(len(n_oper_inds))]
802787
else:
803788
if len(n_oper_identifiers) != len(K):
804-
raise ValueError('Both precomputed cumulant function and n_oper_identifiers ' +
805-
f'given but not same len: {len(K)} != {len(n_oper_identifiers)}')
789+
raise ValueError(
790+
'Both precomputed cumulant function and n_oper_identifiers '
791+
+ f'given but not same len: {len(K)} != {len(n_oper_identifiers)}'
792+
)
806793

807794
else:
808795
if pulse is None or spectrum is None or omega is None:
809-
raise ValueError('Require either precomputed cumulant function ' +
810-
'or pulse, spectrum, and omega as arguments.')
796+
raise ValueError('Require either precomputed cumulant function '
797+
+ 'or pulse, spectrum, and omega as arguments.')
811798

812799
n_oper_inds = util.get_indices_from_identifiers(pulse.n_oper_identifiers,
813800
n_oper_identifiers)
@@ -852,8 +839,8 @@ def plot_cumulant_function(
852839
grid = axes_grid1.ImageGrid(fig, **grid_kw)
853840
else:
854841
if len(grid) != len(n_oper_inds):
855-
raise ValueError('Size of supplied ImageGrid instance does not ' +
856-
'match the number of n_oper_identifiers given!')
842+
raise ValueError('Size of supplied ImageGrid instance does not '
843+
+ 'match the number of n_oper_identifiers given!')
857844

858845
fig = grid[0].get_figure()
859846

0 commit comments

Comments
 (0)