Skip to content

Commit 8bf6159

Browse files
authored
Merge pull request #63 from qutech/hotfix/cache_intermediates_second_order
Fix a bug occuring when calculating second order with cached intermediates
2 parents 2d81b21 + 72a1a30 commit 8bf6159

File tree

1 file changed

+20
-13
lines changed

1 file changed

+20
-13
lines changed

filter_functions/numeric.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -942,7 +942,7 @@ def calculate_cumulant_function(
942942
frequency_shifts: Optional[ndarray] = None,
943943
show_progressbar: bool = False,
944944
memory_parsimonious: bool = False,
945-
cache_intermediates: bool = False
945+
cache_intermediates: Optional[bool] = None
946946
) -> ndarray:
947947
r"""Calculate the cumulant function :math:`\mathcal{K}(\tau)`.
948948
@@ -996,7 +996,7 @@ def calculate_cumulant_function(
996996
Keep and return intermediate terms of the calculation of the
997997
control matrix that can be reused in other computations (second
998998
order or gradients). Otherwise the sum is performed in-place.
999-
The default is False.
999+
Default is True if second_order=True, else False.
10001000
10011001
Returns
10021002
-------
@@ -1074,6 +1074,9 @@ def calculate_cumulant_function(
10741074
if which == 'correlations' and second_order:
10751075
raise ValueError('Cannot compute correlation cumulant function for second order terms')
10761076

1077+
if cache_intermediates is None:
1078+
cache_intermediates = second_order
1079+
10771080
if decay_amplitudes is None:
10781081
decay_amplitudes = calculate_decay_amplitudes(pulse, spectrum, omega, n_oper_identifiers,
10791082
which, show_progressbar, cache_intermediates,
@@ -1517,27 +1520,31 @@ def calculate_second_order_filter_function(
15171520
result = np.zeros(shape, dtype=complex)
15181521

15191522
# intermediate results from calculation of control matrix
1520-
if not intermediates:
1521-
# Require absolut times for calculation of control matrix at step g
1523+
if intermediates is None:
1524+
intermediates = dict()
1525+
1526+
# Work around possibly populated intermediates dict with missing keys
1527+
n_opers_transformed = intermediates.get('n_opers_transformed',
1528+
_transform_hamiltonian(eigvecs, n_opers, n_coeffs))
1529+
try:
1530+
basis_transformed_cache = intermediates['basis_transformed']
1531+
ctrlmat_step_cache = intermediates['control_matrix_step']
1532+
have_intermediates = True
1533+
except KeyError:
1534+
have_intermediates = False
1535+
# No cache. Precompute some things and perform the costly computations
1536+
# during each loop iteration below
15221537
t = np.concatenate(([0], np.asarray(dt).cumsum()))
1523-
# Cheap to precompute as these don't use a lot of memory
15241538
eigvecs_propagated = _propagate_eigenvectors(propagators[:-1], eigvecs)
1525-
n_opers_transformed = _transform_hamiltonian(eigvecs, n_opers, n_coeffs)
1526-
# These are populated anew during every iteration, so there is no need
1527-
# to keep every time step
15281539
basis_transformed = np.empty(basis.shape, dtype=complex)
15291540
ctrlmat_step = np.zeros((len(n_coeffs), len(basis), len(omega)), dtype=complex)
1530-
else:
1531-
n_opers_transformed = intermediates['n_opers_transformed']
1532-
basis_transformed_cache = intermediates['basis_transformed']
1533-
ctrlmat_step_cache = intermediates['control_matrix_step']
15341541

15351542
step_expr = oe.contract_expression('oijmn,akij,blmn->abklo', int_buf.shape,
15361543
*[(len(n_coeffs), len(basis), d, d)]*2,
15371544
optimize=[(0, 1), (0, 1)])
15381545
for g in util.progressbar_range(len(dt), show_progressbar=show_progressbar,
15391546
desc='Calculating second order FF'):
1540-
if not intermediates:
1547+
if not have_intermediates:
15411548
basis_transformed = _transform_by_unitary(eigvecs_propagated[g], basis,
15421549
out=basis_transformed)
15431550
# Need to compute G^(g) since no cache given. First initialize

0 commit comments

Comments
 (0)