Skip to content

Commit ae1c89f

Browse files
committed
Fix sorting and add tests
1 parent 5976b3d commit ae1c89f

File tree

3 files changed

+97
-21
lines changed

3 files changed

+97
-21
lines changed

filter_functions/gradient.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def _control_matrix_at_timestep_derivative(
262262
The individual control matrices of all time steps
263263
ctrlmat_g_deriv: ndarray, shape (n_dt, n_nops, d**2, n_ctrl, n_omega)
264264
The corresponding derivative with respect to the control
265-
strength :math:`\frac{\partial\mathcal{B}_{\alpha j}^{(g)}(\omega)}
265+
strength :math:`\frac{\partial\mathcal{B}_{\alpha j}^{(g)}(\omega)}`
266266
267267
Notes
268268
-----
@@ -613,7 +613,9 @@ def infidelity_derivative(
613613
infid_deriv: ndarray, shape (n_nops, n_dt, n_ctrl)
614614
Array with the derivative of the infidelity for each noise
615615
source taken for each control direction at each time step
616-
:math:`\frac{\partial I_e}{\partial u_h(t_{g'})}`.
616+
:math:`\frac{\partial I_e}{\partial u_h(t_{g'})}`. Sorted in
617+
the same fashion as `n_coeffs_deriv` or, if not given,
618+
alphanumerically by the identifiers.
617619
618620
Notes
619621
-----

filter_functions/pulse_sequence.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -901,40 +901,42 @@ def get_filter_function_derivative(
901901
-------
902902
filter_function_deriv: ndarray, shape (n_nops, n_t, n_ctrl, n_omega)
903903
The regular filter functions' derivatives for variation in
904-
each control contribution.
904+
each control contribution. Sorted in the same fashion as
905+
`n_coeffs_deriv` or, if not given, alphanumerically by the
906+
identifiers.
905907
906908
"""
907-
# Distinction between control and drift operators and only
908-
# calculate the derivatives in control direction.
909-
# TODO 05/22: Is the extended error message necessary?
910-
try:
911-
c_idx = util.get_indices_from_identifiers(self.c_oper_identifiers, control_identifiers)
912-
except ValueError as err:
913-
raise ValueError('Given control identifiers have to be a subset of (drift+control) '
914-
+ 'Hamiltonian!') from err
915-
909+
c_idx = util.get_indices_from_identifiers(self.c_oper_identifiers, control_identifiers)
916910
n_idx = util.get_indices_from_identifiers(self.n_oper_identifiers, n_oper_identifiers)
917-
# Identifiers sorted the way they were passed (or as stored internally if not given)
918-
control_identifiers = self.c_oper_identifiers[c_idx]
919-
n_oper_identifiers = self.n_oper_identifiers[n_idx]
920911

921912
if n_coeffs_deriv is not None:
922913
# TODO 05/22: walrus once support for 3.7 is dropped.
923914
actual_shape = np.shape(n_coeffs_deriv)
924915
required_shape = (len(n_idx), len(c_idx), len(self))
925916
if actual_shape != required_shape:
926917
raise ValueError(f'Expected n_coeffs_deriv to be of shape {required_shape}, '
927-
f'not {actual_shape}')
918+
f'not {actual_shape}. Did you forget to specify identifiers?')
928919
else:
929-
# This would be so much cleaner with xarray :(
930-
n_coeffs_deriv = n_coeffs_deriv[np.argsort(n_oper_identifiers)[:, None],
931-
np.argsort(control_identifiers)]
920+
# Do nothing; n_coeffs_deriv specifies the sorting order. If identifiers
921+
# are given, we sort everything else by them. If not, n_coeffs_deriv is
922+
# expected to be sorted in accordance with the opers and coeffs.
923+
pass
932924

933-
control_matrix = self.get_control_matrix(omega, cache_intermediates=True)
925+
# Check if we can pass on intermediates.
926+
intermediates = dict()
927+
# TODO 05/22: walrus once support for 3.7 is dropped.
928+
n_opers_transformed = self._intermediates.get('n_opers_transformed')
929+
first_order_integral = self._intermediates.get('first_order_integral')
930+
if n_opers_transformed is not None:
931+
intermediates['n_opers_transformed'] = n_opers_transformed[n_idx]
932+
if first_order_integral is not None:
933+
intermediates['first_order_integral'] = first_order_integral
934+
935+
control_matrix = self.get_control_matrix(omega, cache_intermediates=True)[n_idx]
934936
control_matrix_deriv = gradient.calculate_derivative_of_control_matrix_from_scratch(
935937
omega, self.propagators, self.eigvals, self.eigvecs, self.basis, self.t, self.dt,
936938
self.n_opers[n_idx], self.n_coeffs[n_idx], self.c_opers[c_idx], n_coeffs_deriv,
937-
self._intermediates
939+
intermediates
938940
)
939941
return gradient.calculate_filter_function_derivative(control_matrix, control_matrix_deriv)
940942

tests/test_gradient.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,72 @@ def test_gradient_calculation_variable_noise_coefficients(self):
5050
)
5151
self.assertArrayAlmostEqual(ana_grad, fin_diff_grad, rtol=1e-6, atol=1e-10)
5252

53+
def test_n_coeffs_deriv_sorting(self):
54+
for _ in range(5):
55+
pulse = testutil.rand_pulse_sequence(testutil.rng.integers(2, 5),
56+
testutil.rng.integers(2, 11))
57+
omega = ff.util.get_sample_frequencies(pulse, n_samples=37)
58+
59+
# Not the correct derivative, but irrelevant for comparison between analytics
60+
n_coeffs_deriv = testutil.rng.normal(size=(len(pulse.n_opers),
61+
len(pulse.c_opers),
62+
len(pulse)))
63+
64+
# indices to sort sorted opers into a hypothetical unsorted original order.
65+
n_oper_unsort_idx = np.random.permutation(np.arange(len(pulse.n_opers)))
66+
c_oper_unsort_idx = np.random.permutation(np.arange(len(pulse.c_opers)))
67+
68+
# subset of c_opers and n_opers to compute the derivative for
69+
n_choice = np.random.choice(np.arange(len(pulse.n_opers)),
70+
testutil.rng.integers(1, len(pulse.n_opers) + 1),
71+
replace=False)
72+
c_choice = np.random.choice(np.arange(len(pulse.c_opers)),
73+
testutil.rng.integers(1, len(pulse.c_opers) + 1),
74+
replace=False)
75+
76+
grad = pulse.get_filter_function_derivative(
77+
omega,
78+
n_coeffs_deriv=n_coeffs_deriv
79+
)
80+
grad_as_given = pulse.get_filter_function_derivative(
81+
omega,
82+
n_oper_identifiers=pulse.n_oper_identifiers[n_oper_unsort_idx],
83+
control_identifiers=pulse.c_oper_identifiers[c_oper_unsort_idx],
84+
n_coeffs_deriv=n_coeffs_deriv[n_oper_unsort_idx[:, None], c_oper_unsort_idx]
85+
)
86+
grad_n_choice = pulse.get_filter_function_derivative(
87+
omega,
88+
n_oper_identifiers=pulse.n_oper_identifiers[n_choice],
89+
n_coeffs_deriv=n_coeffs_deriv[n_choice]
90+
)
91+
grad_c_choice = pulse.get_filter_function_derivative(
92+
omega,
93+
control_identifiers=pulse.c_oper_identifiers[c_choice],
94+
n_coeffs_deriv=n_coeffs_deriv[:, c_choice]
95+
)
96+
grad_nc_choice = pulse.get_filter_function_derivative(
97+
omega,
98+
control_identifiers=pulse.c_oper_identifiers[c_choice],
99+
n_oper_identifiers=pulse.n_oper_identifiers[n_choice],
100+
n_coeffs_deriv=n_coeffs_deriv[n_choice[:, None], c_choice]
101+
)
102+
self.assertArrayAlmostEqual(
103+
grad[np.ix_(n_oper_unsort_idx, np.arange(len(pulse)), c_oper_unsort_idx)],
104+
grad_as_given
105+
)
106+
self.assertArrayAlmostEqual(
107+
grad[np.ix_(n_choice, np.arange(len(pulse)))],
108+
grad_n_choice
109+
)
110+
self.assertArrayAlmostEqual(
111+
grad[np.ix_(np.arange(len(pulse.n_opers)), np.arange(len(pulse)), c_choice)],
112+
grad_c_choice
113+
)
114+
self.assertArrayAlmostEqual(
115+
grad[np.ix_(n_choice, np.arange(len(pulse)), c_choice)],
116+
grad_nc_choice
117+
)
118+
53119
def test_gradient_calculation_random_pulse(self):
54120

55121
for d, n_dt in zip(testutil.rng.integers(2, 5, 5), testutil.rng.integers(2, 8, 5)):
@@ -105,3 +171,9 @@ def test_raises(self):
105171
omega = ff.util.get_sample_frequencies(pulse, n_samples=13)
106172
with self.assertRaises(ValueError):
107173
ff.infidelity_derivative(pulse, 1/omega, omega, control_identifiers=['long string'])
174+
175+
with self.assertRaises(ValueError):
176+
pulse.get_filter_function_derivative(
177+
omega,
178+
n_coeffs_deriv=testutil.rng.normal(size=(2, 5, 10))
179+
)

0 commit comments

Comments
 (0)