Skip to content

Commit b79c7a3

Browse files
authored
Merge pull request #53 from qutech/feature/improve_gradient_performance
Improve gradient performance
2 parents 3659145 + 29f637d commit b79c7a3

File tree

12 files changed

+649
-577
lines changed

12 files changed

+649
-577
lines changed

filter_functions/basis.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -173,21 +173,11 @@ def __new__(cls, basis_array: Sequence, traceless: Optional[bool] = None,
173173
except AttributeError:
174174
pass
175175

176-
basis = np.empty((len(basis_array), *basis_array[0].shape), dtype=complex)
176+
basis = util.parse_operators(basis_array, 'basis_array')
177177
if basis.shape[0] > np.product(basis.shape[1:]):
178178
raise ValueError('Given overcomplete set of basis matrices. '
179179
'Not linearly independent.')
180180

181-
for i, elem in enumerate(basis_array):
182-
if isinstance(elem, ndarray): # numpy array
183-
basis[i] = elem
184-
elif hasattr(elem, 'full'): # qutip.Qobj
185-
basis[i] = elem.full()
186-
elif hasattr(elem, 'todense'): # sparse array
187-
basis[i] = elem.todense()
188-
else:
189-
raise TypeError('At least one element invalid type!')
190-
191181
basis = basis.view(cls)
192182
basis.btype = btype or 'Custom'
193183
basis.d = basis.shape[-1]

filter_functions/gradient.py

Lines changed: 380 additions & 369 deletions
Large diffs are not rendered by default.

filter_functions/numeric.py

Lines changed: 79 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
frequencies
6666
"""
6767
from collections import deque
68-
from itertools import accumulate, repeat
68+
from itertools import accumulate, repeat, zip_longest
6969
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
7070
from warnings import warn
7171

@@ -93,39 +93,49 @@ def _propagate_eigenvectors(propagators, eigvecs):
9393
return propagators.transpose(0, 2, 1).conj() @ eigvecs
9494

9595

96-
def _transform_noise_operators(n_coeffs, n_opers, eigvecs):
97-
r"""
98-
Transform noise operators into the eigenspaces spanned by eigvecs.
96+
def _transform_hamiltonian(eigvecs, opers, coeffs=None):
97+
r"""Transform a Hamiltonian into the eigenspaces spanned by eigvecs.
98+
9999
I.e., the following transformation is performed:
100100
101101
.. math::
102102
103-
B_\alpha\rightarrow s_\alpha^{(g)}V^{(g)}B_\alpha V^{(g)\dagger}
103+
s_\alpha^{(g)} B_\alpha\rightarrow
104+
s_\alpha^{(g)} V^{(g)}B_\alpha V^{(g)\dagger}
105+
106+
where :math:`s_\alpha^{(g)}` are coefficients of the operator
107+
:math:`B_\alpha`.
104108
105109
"""
106-
assert len(n_opers) == len(n_coeffs)
107-
n_opers_transformed = np.empty((len(n_opers), *eigvecs.shape), dtype=complex)
108-
for j, (n_coeff, n_oper) in enumerate(zip(n_coeffs, n_opers)):
109-
n_opers_transformed[j] = n_oper @ eigvecs
110-
n_opers_transformed[j] = eigvecs.conj().transpose(0, 2, 1) @ n_opers_transformed[j]
111-
n_opers_transformed[j] *= n_coeff[:, None, None]
110+
if coeffs is None:
111+
coeffs = []
112+
else:
113+
assert len(opers) == len(coeffs)
112114

113-
return n_opers_transformed
115+
opers_transformed = np.empty((len(opers), *eigvecs.shape), dtype=complex)
116+
for j, (coeff, oper) in enumerate(zip_longest(coeffs, opers, fillvalue=None)):
117+
opers_transformed[j] = _transform_by_unitary(eigvecs, oper, out=opers_transformed[j])
118+
if coeff is not None:
119+
opers_transformed[j] *= coeff[:, None, None]
114120

121+
return opers_transformed
115122

116-
def _transform_basis(basis, eigvecs_propagated, out=None):
117-
r"""
118-
Transform the basis into the eigenspace spanned by V propagated by Q
123+
124+
def _transform_by_unitary(unitary, oper, out=None):
125+
r"""Transform the operators by a unitary. Uses broadcasting.
119126
120127
I.e., the following transformation is performed:
121128
122129
.. math::
123130
124-
C_k\rightarrow Q_{g-1}V^{(g)\dagger}C_k V^{(g)}Q_{g-1}^\dagger.
131+
C_k\rightarrow U C_k U^\dagger.
125132
126133
"""
127-
out = np.matmul(basis, eigvecs_propagated, out=out)
128-
out = np.matmul(eigvecs_propagated.conj().T, out, out=out)
134+
if out is None:
135+
out = np.empty(oper.shape, dtype=oper.dtype)
136+
137+
out = np.matmul(oper, unitary, out=out)
138+
out = np.matmul(unitary.conj().swapaxes(-1, -2), out, out=out)
129139
return out
130140

131141

@@ -147,7 +157,7 @@ def _first_order_integral(E: ndarray, eigvals: ndarray, dt: float,
147157
int_buf.imag = np.add.outer(E, dE, out=int_buf.imag)
148158

149159
# Catch zero-division
150-
mask = (int_buf.imag != 0)
160+
mask = (np.abs(int_buf.imag) > 1e-7)
151161
exp_buf = util.cexp(int_buf.imag*dt, out=exp_buf, where=mask)
152162
exp_buf = np.subtract(exp_buf, 1, out=exp_buf, where=mask)
153163
int_buf = np.divide(exp_buf, int_buf, out=int_buf, where=mask)
@@ -463,7 +473,8 @@ def calculate_noise_operators_from_scratch(
463473
n_coeffs: Sequence[Coefficients],
464474
dt: Coefficients,
465475
t: Optional[Coefficients] = None,
466-
show_progressbar: bool = False
476+
show_progressbar: bool = False,
477+
cache_intermediates: bool = False
467478
) -> ndarray:
468479
r"""
469480
Calculate the noise operators in interaction picture from scratch.
@@ -563,30 +574,44 @@ def calculate_noise_operators_from_scratch(
563574
n_coeffs = np.asarray(n_coeffs)
564575

565576
# Precompute noise opers transformed to eigenbasis of each pulse
566-
# segment and Q^\dagger @ V
577+
# segment and V^\dagger @ Q
567578
eigvecs_propagated = _propagate_eigenvectors(eigvecs, propagators[:-1])
568-
n_opers_transformed = _transform_noise_operators(n_coeffs, n_opers, eigvecs)
579+
n_opers_transformed = _transform_hamiltonian(eigvecs, n_opers, n_coeffs)
569580

570581
# Allocate memory
571582
exp_buf, int_buf = np.empty((2, len(omega), d, d), dtype=complex)
572-
intermediate = np.empty((len(omega), len(n_opers), d, d), dtype=complex)
573583
noise_operators = np.zeros((len(omega), len(n_opers), d, d), dtype=complex)
574584

585+
if cache_intermediates:
586+
sum_cache = np.empty((len(dt), len(omega), len(n_opers), d, d), dtype=complex)
587+
else:
588+
sum_buf = np.empty((len(omega), len(n_opers), d, d), dtype=complex)
589+
575590
# Set up reusable expressions
576591
expr_1 = oe.contract_expression('akl,okl->oakl',
577592
n_opers_transformed[:, 0].shape, int_buf.shape)
578593
expr_2 = oe.contract_expression('ji,...jk,kl',
579-
eigvecs_propagated[0].shape, intermediate.shape,
594+
eigvecs_propagated[0].shape, (len(omega), len(n_opers), d, d),
580595
eigvecs_propagated[0].shape, optimize=[(0, 1), (0, 1)])
581596

582597
for g in util.progressbar_range(len(dt), show_progressbar=show_progressbar,
583598
desc='Calculating noise operators'):
599+
if cache_intermediates:
600+
# Assign references to the locations in the cache for the quantities
601+
# that should be stored
602+
sum_buf = sum_cache[g]
603+
584604
int_buf = _first_order_integral(omega, eigvals[g], dt[g], exp_buf, int_buf)
585-
intermediate = expr_1(n_opers_transformed[:, g],
586-
util.cexp(omega[:, None, None]*t[g])*int_buf, out=intermediate)
605+
sum_buf = expr_1(n_opers_transformed[:, g], util.cexp(omega*t[g])[:, None, None]*int_buf,
606+
out=sum_buf)
587607

588-
noise_operators += expr_2(eigvecs_propagated[g].conj(), intermediate,
589-
eigvecs_propagated[g])
608+
noise_operators += expr_2(eigvecs_propagated[g].conj(), sum_buf, eigvecs_propagated[g],
609+
out=sum_buf)
610+
611+
if cache_intermediates:
612+
intermediates = dict(n_opers_transformed=n_opers_transformed,
613+
noise_operators_step=sum_cache)
614+
return noise_operators, intermediates
590615

591616
return noise_operators
592617

@@ -715,7 +740,7 @@ def calculate_control_matrix_from_scratch(
715740
cache_intermediates: bool, optional
716741
Keep and return intermediate terms
717742
:math:`\mathcal{G}^{(g)}(\omega)` of the sum so that
718-
:math:`\mathcal{R}(\omega)=\sum_g\mathcal{G}^{(g)}(\omega)`.
743+
:math:`\mathcal{B}(\omega)=\sum_g\mathcal{G}^{(g)}(\omega)`.
719744
Otherwise the sum is performed in-place.
720745
out: ndarray, optional
721746
A location into which the result is stored. See
@@ -771,51 +796,58 @@ def calculate_control_matrix_from_scratch(
771796
# Precompute noise opers transformed to eigenbasis of each pulse segment
772797
# and Q^\dagger @ V
773798
eigvecs_propagated = _propagate_eigenvectors(propagators[:-1], eigvecs)
774-
n_opers_transformed = _transform_noise_operators(n_coeffs, n_opers, eigvecs)
799+
n_opers_transformed = _transform_hamiltonian(eigvecs, n_opers, n_coeffs)
775800

776801
# Allocate result and buffers for intermediate arrays
777-
exp_buf, int_buf = np.empty((2, len(omega), d, d), dtype=complex)
778-
802+
exp_buf = np.empty((len(omega), d, d), dtype=complex)
779803
if out is None:
780-
control_matrix = np.zeros((len(n_opers), len(basis), len(omega)), dtype=complex)
781-
else:
782-
control_matrix = out
804+
out = np.zeros((len(n_opers), len(basis), len(omega)), dtype=complex)
783805

784806
if cache_intermediates:
785807
basis_transformed_cache = np.empty((len(dt), *basis.shape), dtype=complex)
808+
phase_factors_cache = np.empty((len(dt), len(omega)), dtype=complex)
809+
int_cache = np.empty((len(dt), len(omega), d, d), dtype=complex)
786810
sum_cache = np.empty((len(dt), len(n_opers), len(basis), len(omega)), dtype=complex)
787811
else:
788812
basis_transformed = np.empty(basis.shape, dtype=complex)
813+
phase_factors = np.empty(len(omega), dtype=complex)
814+
int_buf = np.empty((len(omega), d, d), dtype=complex)
789815
sum_buf = np.empty((len(n_opers), len(basis), len(omega)), dtype=complex)
790816

791817
# Optimize the contraction path dynamically since it differs for different
792818
# values of d
793819
expr = oe.contract_expression('o,jmn,omn,knm->jko',
794820
omega.shape, n_opers_transformed[:, 0].shape,
795-
int_buf.shape, basis.shape, optimize=True)
821+
exp_buf.shape, basis.shape, optimize=True)
796822
for g in util.progressbar_range(len(dt), show_progressbar=show_progressbar,
797823
desc='Calculating control matrix'):
798824

799825
if cache_intermediates:
800826
# Assign references to the locations in the cache for the quantities
801827
# that should be stored
802828
basis_transformed = basis_transformed_cache[g]
829+
phase_factors = phase_factors_cache[g]
830+
int_buf = int_cache[g]
803831
sum_buf = sum_cache[g]
804832

805-
basis_transformed = _transform_basis(basis, eigvecs_propagated[g], out=basis_transformed)
833+
basis_transformed = _transform_by_unitary(eigvecs_propagated[g], basis,
834+
out=basis_transformed)
835+
phase_factors = util.cexp(omega*t[g], out=phase_factors)
806836
int_buf = _first_order_integral(omega, eigvals[g], dt[g], exp_buf, int_buf)
807-
sum_buf = expr(util.cexp(omega*t[g]), n_opers_transformed[:, g], int_buf,
837+
sum_buf = expr(phase_factors, n_opers_transformed[:, g], int_buf,
808838
basis_transformed, out=sum_buf)
809839

810-
control_matrix += sum_buf
840+
out += sum_buf
811841

812842
if cache_intermediates:
813843
intermediates = dict(n_opers_transformed=n_opers_transformed,
814844
basis_transformed=basis_transformed_cache,
845+
phase_factors=phase_factors_cache,
846+
first_order_integral=int_cache,
815847
control_matrix_step=sum_cache)
816-
return control_matrix, intermediates
848+
return out, intermediates
817849

818-
return control_matrix
850+
return out
819851

820852

821853
def calculate_control_matrix_periodic(phases: ndarray, control_matrix: ndarray,
@@ -1180,7 +1212,7 @@ def calculate_decay_amplitudes(
11801212
"""
11811213
# TODO: Replace infidelity() by this?
11821214
# Noise operator indices
1183-
idx = util.get_indices_from_identifiers(pulse, n_oper_identifiers, 'noise')
1215+
idx = util.get_indices_from_identifiers(pulse.n_oper_identifiers, n_oper_identifiers)
11841216
if which == 'total':
11851217
# Faster to use filter function instead of control matrix
11861218
if pulse.is_cached('filter_function_gen'):
@@ -1297,7 +1329,7 @@ def calculate_frequency_shifts(
12971329
pulse_sequence.concatenate: Concatenate ``PulseSequence`` objects.
12981330
calculate_pulse_correlation_filter_function
12991331
"""
1300-
idx = util.get_indices_from_identifiers(pulse, n_oper_identifiers, 'noise')
1332+
idx = util.get_indices_from_identifiers(pulse.n_oper_identifiers, n_oper_identifiers)
13011333
filter_function_2 = pulse.get_filter_function(omega, order=2,
13021334
show_progressbar=show_progressbar)
13031335
integrand = _get_integrand(spectrum, omega, idx, which_pulse='total', which_FF='generalized',
@@ -1478,7 +1510,7 @@ def calculate_second_order_filter_function(
14781510
t = np.concatenate(([0], np.asarray(dt).cumsum()))
14791511
# Cheap to precompute as these don't use a lot of memory
14801512
eigvecs_propagated = _propagate_eigenvectors(propagators[:-1], eigvecs)
1481-
n_opers_transformed = _transform_noise_operators(n_coeffs, n_opers, eigvecs)
1513+
n_opers_transformed = _transform_hamiltonian(eigvecs, n_opers, n_coeffs)
14821514
# These are populated anew during every iteration, so there is no need
14831515
# to keep every time step
14841516
basis_transformed = np.empty(basis.shape, dtype=complex)
@@ -1494,8 +1526,8 @@ def calculate_second_order_filter_function(
14941526
for g in util.progressbar_range(len(dt), show_progressbar=show_progressbar,
14951527
desc='Calculating second order FF'):
14961528
if not intermediates:
1497-
basis_transformed = _transform_basis(basis, eigvecs_propagated[g],
1498-
out=basis_transformed)
1529+
basis_transformed = _transform_by_unitary(eigvecs_propagated[g], basis,
1530+
out=basis_transformed)
14991531
# Need to compute G^(g) since no cache given. First initialize
15001532
# buffer to zero. There is a probably lots of overhead computing
15011533
# this individually for every time step.
@@ -1930,7 +1962,7 @@ def infidelity(pulse: 'PulseSequence', spectrum: Union[Coefficients, Callable],
19301962
plotting.plot_infidelity_convergence: Convenience function to plot results.
19311963
"""
19321964
# Noise operator indices
1933-
idx = util.get_indices_from_identifiers(pulse, n_oper_identifiers, 'noise')
1965+
idx = util.get_indices_from_identifiers(pulse.n_oper_identifiers, n_oper_identifiers)
19341966

19351967
if test_convergence:
19361968
if not callable(spectrum):

filter_functions/plotting.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ def plot_pulse_train(
315315
ValueError
316316
If an invalid number of c_oper_labels were given
317317
"""
318-
c_oper_inds = util.get_indices_from_identifiers(pulse, c_oper_identifiers, 'control')
318+
c_oper_inds = util.get_indices_from_identifiers(pulse.c_oper_identifiers, c_oper_identifiers)
319319
c_oper_identifiers = pulse.c_oper_identifiers[c_oper_inds]
320320

321321
if fig is None and axes is None:
@@ -425,7 +425,7 @@ def plot_filter_function(
425425
else:
426426
omega = pulse.omega
427427

428-
n_oper_inds = util.get_indices_from_identifiers(pulse, n_oper_identifiers, 'noise')
428+
n_oper_inds = util.get_indices_from_identifiers(pulse.n_oper_identifiers, n_oper_identifiers)
429429
n_oper_identifiers = pulse.n_oper_identifiers[n_oper_inds]
430430

431431
if fig is None and axes is None:
@@ -548,7 +548,7 @@ def plot_pulse_correlation_filter_function(
548548
If the pulse correlation filter function was not computed during
549549
concatenation.
550550
"""
551-
n_oper_inds = util.get_indices_from_identifiers(pulse, n_oper_identifiers, 'noise')
551+
n_oper_inds = util.get_indices_from_identifiers(pulse.n_oper_identifiers, n_oper_identifiers)
552552
n_oper_identifiers = pulse.n_oper_identifiers[n_oper_inds]
553553
diag_idx = np.arange(len(pulse.n_opers))
554554
F_pc = pulse.get_pulse_correlation_filter_function()
@@ -777,7 +777,8 @@ def plot_cumulant_function(
777777
raise ValueError('Require either precomputed cumulant function ' +
778778
'or pulse, spectrum, and omega as arguments.')
779779

780-
n_oper_inds = util.get_indices_from_identifiers(pulse, n_oper_identifiers, 'noise')
780+
n_oper_inds = util.get_indices_from_identifiers(pulse.n_oper_identifiers,
781+
n_oper_identifiers)
781782
n_oper_identifiers = pulse.n_oper_identifiers[n_oper_inds]
782783
K = numeric.calculate_cumulant_function(pulse, spectrum, omega, n_oper_identifiers,
783784
'total', second_order)

0 commit comments

Comments
 (0)