Skip to content

Commit 65cd5ce

Browse files
committed
Merge branch 'master' into feature/update_examples_for_paper
2 parents 3aff9f4 + 3a3bdbc commit 65cd5ce

File tree

7 files changed

+95
-20
lines changed

7 files changed

+95
-20
lines changed

filter_functions/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,10 @@
2222

2323
from . import analytic, basis, numeric, pulse_sequence, superoperator, util
2424
from .basis import Basis
25-
25+
from .gradient import infidelity_derivative
2626
from .numeric import error_transfer_matrix, infidelity
2727
from .pulse_sequence import PulseSequence, concatenate, concatenate_periodic, extend, remap
2828
from .superoperator import liouville_representation
29-
from .gradient import infidelity_derivative
3029

3130
__all__ = ['Basis', 'PulseSequence', 'analytic', 'basis', 'concatenate', 'concatenate_periodic',
3231
'error_transfer_matrix', 'extend', 'infidelity', 'liouville_representation', 'numeric',

filter_functions/plotting.py

Lines changed: 62 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@
5454
from numpy import ndarray
5555

5656
from . import numeric, util
57-
from .types import (Axes, Coefficients, Colormap, Figure, FigureAxes,
58-
FigureAxesLegend, FigureGrid, Grid, Operator, State)
57+
from .types import (Axes, Coefficients, Colormap, Figure, FigureAxes, FigureAxesLegend, FigureGrid,
58+
Grid, Operator, State)
5959

6060
__all__ = ['plot_cumulant_function', 'plot_infidelity_convergence', 'plot_filter_function',
6161
'plot_pulse_correlation_filter_function', 'plot_pulse_train']
@@ -68,6 +68,24 @@
6868
qt = mock.Mock()
6969

7070

71+
def _make_str_tex_compatible(s: str) -> str:
72+
"""Escape incompatible characters in strings passed to TeX."""
73+
if not plt.rcParams['text.usetex']:
74+
return s
75+
76+
s = str(s)
77+
incompatible = ('_',)
78+
for char in incompatible:
79+
locs = [i for i, c in enumerate(s) if c == char]
80+
# Loop backwards so as not to change locs when modifying s
81+
for loc in locs[::-1]:
82+
# Check if math environment, if not add escape character
83+
if not s.count('$', loc) % 2:
84+
s = s[:loc] + '\\' + s[loc:]
85+
86+
return s
87+
88+
7189
def get_bloch_vector(states: Sequence[State]) -> ndarray:
7290
r"""
7391
Get the Bloch vector from quantum states.
@@ -247,6 +265,7 @@ def plot_pulse_train(
247265
c_oper_identifiers: Optional[Sequence[int]] = None,
248266
fig: Optional[Figure] = None,
249267
axes: Optional[Axes] = None,
268+
cycler: Optional['cycler.Cycler'] = None,
250269
plot_kw: Optional[dict] = {},
251270
subplot_kw: Optional[dict] = None,
252271
gridspec_kw: Optional[dict] = None,
@@ -267,6 +286,9 @@ def plot_pulse_train(
267286
A matplotlib figure instance to plot in
268287
axes: matplotlib axes, optional
269288
A matplotlib axes instance to use for plotting.
289+
cycler: cycler.Cycler, optional
290+
A Cycler instance used to set the style cycle if multiple lines
291+
are to be drawn
270292
plot_kw: dict, optional
271293
Dictionary with keyword arguments passed to the plot function
272294
subplot_kw: dict, optional
@@ -307,10 +329,14 @@ def plot_pulse_train(
307329
elif fig is None and axes is not None:
308330
fig = axes.figure
309331

332+
if cycler is not None:
333+
axes.set_prop_cycle(cycler)
334+
310335
handles = []
311336
for i, c_coeffs in enumerate(pulse.c_coeffs[tuple(c_oper_inds), ...]):
312337
coeffs = np.insert(c_coeffs, 0, c_coeffs[0])
313-
handles += axes.step(pulse.t, coeffs, label=c_oper_identifiers[i], **plot_kw)
338+
handles += axes.step(pulse.t, coeffs,
339+
label=_make_str_tex_compatible(c_oper_identifiers[i]), **plot_kw)
314340

315341
axes.set_xlim(pulse.t[0], pulse.tau)
316342
axes.set_xlabel(r'$t$ / a.u.')
@@ -330,6 +356,7 @@ def plot_filter_function(
330356
xscale: str = 'log',
331357
yscale: str = 'linear',
332358
omega_in_units_of_tau: bool = True,
359+
cycler: Optional['cycler.Cycler'] = None,
333360
plot_kw: dict = {},
334361
subplot_kw: Optional[dict] = None,
335362
gridspec_kw: Optional[dict] = None,
@@ -363,6 +390,9 @@ def plot_filter_function(
363390
y-axis scaling. One of ('linear', 'log').
364391
omega_in_units_of_tau: bool, optional
365392
Plot :math:`\omega\tau` or just :math:`\omega` on x-axis.
393+
cycler: cycler.Cycler, optional
394+
A Cycler instance used to set the style cycle if multiple lines
395+
are to be drawn
366396
plot_kw: dict, optional
367397
Dictionary with keyword arguments passed to the plot function
368398
subplot_kw: dict, optional
@@ -409,6 +439,9 @@ def plot_filter_function(
409439
elif fig is None and axes is not None:
410440
fig = axes.figure
411441

442+
if cycler is not None:
443+
axes.set_prop_cycle(cycler)
444+
412445
if omega_in_units_of_tau:
413446
tau = np.ptp(pulse.t)
414447
z = omega*tau
@@ -423,7 +456,8 @@ def plot_filter_function(
423456
handles = []
424457
for i, ind in enumerate(n_oper_inds):
425458
handles += axes.plot(z, filter_function[ind],
426-
label=n_oper_identifiers[i], **plot_kw)
459+
label=_make_str_tex_compatible(n_oper_identifiers[i]),
460+
**plot_kw)
427461

428462
# Set the axis scales
429463
axes.set_xscale(xscale)
@@ -452,6 +486,7 @@ def plot_pulse_correlation_filter_function(
452486
xscale: str = 'log',
453487
yscale: str = 'linear',
454488
omega_in_units_of_tau: bool = True,
489+
cycler: Optional['cycler.Cycler'] = None,
455490
plot_kw: dict = {},
456491
subplot_kw: Optional[dict] = None,
457492
gridspec_kw: Optional[dict] = None,
@@ -483,6 +518,9 @@ def plot_pulse_correlation_filter_function(
483518
y-axis scaling. One of ('linear', 'log').
484519
omega_in_units_of_tau: bool, optional
485520
Plot :math:`\omega\tau` or just :math:`\omega` on x-axis.
521+
cycler: cycler.Cycler, optional
522+
A Cycler instance used to set the style cycle if multiple lines
523+
are to be drawn in one subplot. Used for all subplots.
486524
plot_kw: dict, optional
487525
Dictionary with keyword arguments passed to the plot function
488526
subplot_kw: dict, optional
@@ -546,10 +584,13 @@ def plot_pulse_correlation_filter_function(
546584
dashed_line = lines.Line2D([], [], color='gray', linestyle='--')
547585
for i in range(n):
548586
for j in range(n):
587+
if cycler is not None:
588+
axes[i, j].set_prop_cycle(cycler)
589+
549590
handles = []
550591
for k, ind in enumerate(n_oper_inds):
551592
handles += axes[i, j].plot(z, F_pc[i, j, ind].real,
552-
label=n_oper_identifiers[k],
593+
label=_make_str_tex_compatible(n_oper_identifiers[k]),
553594
**plot_kw)
554595
if i != j:
555596
axes[i, j].plot(z, F_pc[i, j, ind].imag, linestyle='--',
@@ -566,7 +607,8 @@ def plot_pulse_correlation_filter_function(
566607

567608
if i == 0 and j == n-1:
568609
handles += [transparent_line, solid_line, dashed_line]
569-
labels = n_oper_identifiers.tolist() + ['', r'$Re$', r'$Im$']
610+
labels = ([_make_str_tex_compatible(n) for n in n_oper_identifiers]
611+
+ ['', r'$Re$', r'$Im$'])
570612
legend = axes[i, j].legend(handles=handles, labels=labels,
571613
bbox_to_anchor=(1.05, 1), loc=2,
572614
borderaxespad=0., frameon=False)
@@ -628,11 +670,12 @@ def plot_cumulant_function(
628670
omega: Optional[Coefficients] = None,
629671
cumulant_function: Optional[ndarray] = None,
630672
n_oper_identifiers: Optional[Sequence[int]] = None,
631-
basis_labels: Optional[Sequence[str]] = None,
632673
colorscale: str = 'linear',
633674
linthresh: Optional[float] = None,
634-
cbar_label: str = 'Cumulant Function',
675+
basis_labels: Optional[Sequence[str]] = None,
635676
basis_labelsize: Optional[int] = None,
677+
cbar_label: str = 'Cumulant Function',
678+
cbar_labelsize: Optional[int] = None,
636679
fig: Optional[Figure] = None,
637680
grid: Optional[Grid] = None,
638681
cmap: Optional[Colormap] = None,
@@ -669,18 +712,20 @@ def plot_cumulant_function(
669712
The identifiers of the noise operators for which the cumulant
670713
function should be plotted. All identifiers can be accessed via
671714
``pulse.n_oper_identifiers``. Defaults to all.
672-
basis_labels: array_like (str), optional
673-
Labels for the elements of the cumulant function (the basis
674-
elements).
675715
colorscale: str, optional
676716
The scale of the color code ('linear' or 'log' (default))
677717
linthresh: float, optional
678718
The threshold below which the colorscale will be linear (only
679719
for 'log') colorscale
680-
cbar_label: str, optional
681-
The label for the colorbar. Default: 'Cumulant Function'.
720+
basis_labels: array_like (str), optional
721+
Labels for the elements of the cumulant function (the basis
722+
elements).
682723
basis_labelsize: int, optional
683724
The size in points for the basis labels.
725+
cbar_label: str, optional
726+
The label for the colorbar. Default: 'Cumulant Function'.
727+
cbar_labelsize: int, optional
728+
The size in points for the colorbar label.
684729
fig: matplotlib figure, optional
685730
A matplotlib figure instance to plot in
686731
grid: matplotlib ImageGrid, optional
@@ -752,6 +797,8 @@ def plot_cumulant_function(
752797
if len(basis_labels) != K.shape[-1]:
753798
raise ValueError('Invalid number of basis_labels given')
754799

800+
basis_labels = [_make_str_tex_compatible(bl) for bl in basis_labels]
801+
755802
if grid is None:
756803
aspect_ratio = 2/3
757804
n_rows = int(np.round(np.sqrt(aspect_ratio*len(n_oper_inds))))
@@ -799,6 +846,7 @@ def plot_cumulant_function(
799846
imshow_kw.setdefault('norm', norm)
800847

801848
basis_labelsize = basis_labelsize or 8
849+
cbar_labelsize = cbar_labelsize or plt.rcParams['axes.labelsize']
802850

803851
# Draw the images
804852
for i, n_oper_identifier in enumerate(n_oper_identifiers):
@@ -818,6 +866,6 @@ def plot_cumulant_function(
818866
cbar_kw = cbar_kw or {}
819867
cbar_kw.setdefault('orientation', 'vertical')
820868
cbar = fig.colorbar(im, cax=grid.cbar_axes[0], **cbar_kw)
821-
cbar.set_label(cbar_label)
869+
cbar.set_label(_make_str_tex_compatible(cbar_label), fontsize=cbar_labelsize)
822870

823871
return fig, grid

filter_functions/pulse_sequence.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
from numpy import linalg as nla
5050
from numpy import ndarray
5151

52-
from . import numeric, util, gradient
52+
from . import gradient, numeric, util
5353
from .basis import Basis, equivalent_pauli_basis_elements, remap_pauli_basis_elements
5454
from .superoperator import liouville_representation
5555
from .types import Coefficients, Hamiltonian, Operator, PulseMapping
@@ -2250,7 +2250,8 @@ def extend(
22502250
raise ValueError(f'Expected additional noise operators to have dimensions {(d, d)}, ' +
22512251
f'not {add_n_opers.shape[1:]}.')
22522252
if any(n_oper_id in n_oper_identifiers for n_oper_id in add_n_oper_id):
2253-
raise ValueError('Found duplicate noise operator identifiers')
2253+
identifiers = set(n_oper_identifiers).intersection(add_n_oper_id)
2254+
raise ValueError(f'Found duplicate noise operator identifiers: {identifiers}')
22542255

22552256
n_opers.extend(add_n_opers)
22562257
n_coeffs.extend(add_n_coeffs)

tests/gradient_testutil.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
23
import filter_functions as ff
34

45
sigma_x = np.asarray([[0, 1], [1, 0]]) / 2

tests/test_extras.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,9 @@ class MissingExtrasTest(testutil.TestCase):
4040
'fancy_progressbar' in os.environ.get('INSTALL_EXTRAS', all_extras),
4141
reason='Skipping tests for missing fancy progressbar extra in build with requests') # noqa
4242
def test_fancy_progressbar_not_available(self):
43-
from filter_functions import util
4443
from tqdm import tqdm
44+
45+
from filter_functions import util
4546
self.assertEqual(util._NOTEBOOK_NAME, '')
4647
self.assertIs(tqdm, util._tqdm)
4748

tests/test_gradient.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import tests.gradient_testutil as grad_util
2727
from tests import testutil
2828

29-
3029
np.random.seed(0)
3130
initial_pulse = np.random.rand(grad_util.n_time_steps)
3231
initial_pulse = np.expand_dims(initial_pulse, 0)

tests/test_plotting.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
This module tests the plotting functionality of the package.
2323
"""
2424
import string
25+
from copy import copy
2526
from random import sample
2627

2728
import numpy as np
@@ -38,6 +39,7 @@
3839
reason='Skipping plotting tests for build without matplotlib')
3940
if plotting is not None:
4041
import matplotlib.pyplot as plt
42+
from matplotlib import cycler
4143

4244
simple_pulse = testutil.rand_pulse_sequence(2, 1, 1, 1, btype='Pauli')
4345
complicated_pulse = testutil.rand_pulse_sequence(2, 100, 3, 3)
@@ -67,6 +69,10 @@ def test_plot_pulse_train(self):
6769
c_oper_identifiers,
6870
fig=fig, axes=ax)
6971

72+
# Test cycler arg
73+
cycle = cycler(color=['r', 'g', 'b'])
74+
fig, ax, leg = plotting.plot_pulse_train(simple_pulse, cycler=cycle)
75+
7076
# invalid identifier
7177
with self.assertRaises(ValueError):
7278
plotting.plot_pulse_train(complicated_pulse,
@@ -113,6 +119,10 @@ def test_plot_filter_function(self):
113119
fig=fig, axes=ax, omega_in_units_of_tau=False
114120
)
115121

122+
# Test cycler arg
123+
cycle = cycler(color=['r', 'g', 'b'])
124+
fig, ax, leg = plotting.plot_filter_function(simple_pulse, cycler=cycle)
125+
116126
# invalid identifier
117127
with self.assertRaises(ValueError):
118128
plotting.plot_filter_function(complicated_pulse,
@@ -177,6 +187,11 @@ def test_plot_pulse_correlation_filter_function(self):
177187
omega_in_units_of_tau=False
178188
)
179189

190+
# Test cycler arg
191+
cycle = cycler(color=['r', 'g', 'b'])
192+
fig, ax, leg = plotting.plot_pulse_correlation_filter_function(concatenated_simple_pulse,
193+
cycler=cycle)
194+
180195
# invalid identifiers
181196
with self.assertRaises(ValueError):
182197
plotting.plot_pulse_correlation_filter_function(
@@ -299,6 +314,17 @@ def spectrum(omega):
299314
fig, ax = plotting.plot_infidelity_convergence(n, infids)
300315

301316

317+
class LaTeXRenderingTest(testutil.TestCase):
318+
319+
def test_plot_filter_function(self):
320+
pulse = copy(simple_pulse)
321+
pulse.c_oper_identifiers = np.array([f'B_{i}' for i in range(len(pulse.c_opers))])
322+
pulse.n_oper_identifiers = np.array([f'B_{i}' for i in range(len(pulse.n_opers))])
323+
with plt.rc_context(rc={'text.usetex': True}):
324+
_ = plotting.plot_pulse_train(pulse)
325+
_ = plotting.plot_filter_function(pulse)
326+
327+
302328
@pytest.mark.skipif(
303329
qutip is None,
304330
reason='Skipping bloch sphere visualization tests for build without qutip')

0 commit comments

Comments
 (0)