54
54
from numpy import ndarray
55
55
56
56
from . import numeric , util
57
- from .types import (Axes , Coefficients , Colormap , Figure , FigureAxes ,
58
- FigureAxesLegend , FigureGrid , Grid , Operator , State )
57
+ from .types import (Axes , Coefficients , Colormap , Figure , FigureAxes , FigureAxesLegend , FigureGrid ,
58
+ Grid , Operator , State )
59
59
60
60
__all__ = ['plot_cumulant_function' , 'plot_infidelity_convergence' , 'plot_filter_function' ,
61
61
'plot_pulse_correlation_filter_function' , 'plot_pulse_train' ]
68
68
qt = mock .Mock ()
69
69
70
70
71
+ def _make_str_tex_compatible (s : str ) -> str :
72
+ """Escape incompatible characters in strings passed to TeX."""
73
+ if not plt .rcParams ['text.usetex' ]:
74
+ return s
75
+
76
+ s = str (s )
77
+ incompatible = ('_' ,)
78
+ for char in incompatible :
79
+ locs = [i for i , c in enumerate (s ) if c == char ]
80
+ # Loop backwards so as not to change locs when modifying s
81
+ for loc in locs [::- 1 ]:
82
+ # Check if math environment, if not add escape character
83
+ if not s .count ('$' , loc ) % 2 :
84
+ s = s [:loc ] + '\\ ' + s [loc :]
85
+
86
+ return s
87
+
88
+
71
89
def get_bloch_vector (states : Sequence [State ]) -> ndarray :
72
90
r"""
73
91
Get the Bloch vector from quantum states.
@@ -247,6 +265,7 @@ def plot_pulse_train(
247
265
c_oper_identifiers : Optional [Sequence [int ]] = None ,
248
266
fig : Optional [Figure ] = None ,
249
267
axes : Optional [Axes ] = None ,
268
+ cycler : Optional ['cycler.Cycler' ] = None ,
250
269
plot_kw : Optional [dict ] = {},
251
270
subplot_kw : Optional [dict ] = None ,
252
271
gridspec_kw : Optional [dict ] = None ,
@@ -267,6 +286,9 @@ def plot_pulse_train(
267
286
A matplotlib figure instance to plot in
268
287
axes: matplotlib axes, optional
269
288
A matplotlib axes instance to use for plotting.
289
+ cycler: cycler.Cycler, optional
290
+ A Cycler instance used to set the style cycle if multiple lines
291
+ are to be drawn
270
292
plot_kw: dict, optional
271
293
Dictionary with keyword arguments passed to the plot function
272
294
subplot_kw: dict, optional
@@ -307,10 +329,14 @@ def plot_pulse_train(
307
329
elif fig is None and axes is not None :
308
330
fig = axes .figure
309
331
332
+ if cycler is not None :
333
+ axes .set_prop_cycle (cycler )
334
+
310
335
handles = []
311
336
for i , c_coeffs in enumerate (pulse .c_coeffs [tuple (c_oper_inds ), ...]):
312
337
coeffs = np .insert (c_coeffs , 0 , c_coeffs [0 ])
313
- handles += axes .step (pulse .t , coeffs , label = c_oper_identifiers [i ], ** plot_kw )
338
+ handles += axes .step (pulse .t , coeffs ,
339
+ label = _make_str_tex_compatible (c_oper_identifiers [i ]), ** plot_kw )
314
340
315
341
axes .set_xlim (pulse .t [0 ], pulse .tau )
316
342
axes .set_xlabel (r'$t$ / a.u.' )
@@ -330,6 +356,7 @@ def plot_filter_function(
330
356
xscale : str = 'log' ,
331
357
yscale : str = 'linear' ,
332
358
omega_in_units_of_tau : bool = True ,
359
+ cycler : Optional ['cycler.Cycler' ] = None ,
333
360
plot_kw : dict = {},
334
361
subplot_kw : Optional [dict ] = None ,
335
362
gridspec_kw : Optional [dict ] = None ,
@@ -363,6 +390,9 @@ def plot_filter_function(
363
390
y-axis scaling. One of ('linear', 'log').
364
391
omega_in_units_of_tau: bool, optional
365
392
Plot :math:`\omega\tau` or just :math:`\omega` on x-axis.
393
+ cycler: cycler.Cycler, optional
394
+ A Cycler instance used to set the style cycle if multiple lines
395
+ are to be drawn
366
396
plot_kw: dict, optional
367
397
Dictionary with keyword arguments passed to the plot function
368
398
subplot_kw: dict, optional
@@ -409,6 +439,9 @@ def plot_filter_function(
409
439
elif fig is None and axes is not None :
410
440
fig = axes .figure
411
441
442
+ if cycler is not None :
443
+ axes .set_prop_cycle (cycler )
444
+
412
445
if omega_in_units_of_tau :
413
446
tau = np .ptp (pulse .t )
414
447
z = omega * tau
@@ -423,7 +456,8 @@ def plot_filter_function(
423
456
handles = []
424
457
for i , ind in enumerate (n_oper_inds ):
425
458
handles += axes .plot (z , filter_function [ind ],
426
- label = n_oper_identifiers [i ], ** plot_kw )
459
+ label = _make_str_tex_compatible (n_oper_identifiers [i ]),
460
+ ** plot_kw )
427
461
428
462
# Set the axis scales
429
463
axes .set_xscale (xscale )
@@ -452,6 +486,7 @@ def plot_pulse_correlation_filter_function(
452
486
xscale : str = 'log' ,
453
487
yscale : str = 'linear' ,
454
488
omega_in_units_of_tau : bool = True ,
489
+ cycler : Optional ['cycler.Cycler' ] = None ,
455
490
plot_kw : dict = {},
456
491
subplot_kw : Optional [dict ] = None ,
457
492
gridspec_kw : Optional [dict ] = None ,
@@ -483,6 +518,9 @@ def plot_pulse_correlation_filter_function(
483
518
y-axis scaling. One of ('linear', 'log').
484
519
omega_in_units_of_tau: bool, optional
485
520
Plot :math:`\omega\tau` or just :math:`\omega` on x-axis.
521
+ cycler: cycler.Cycler, optional
522
+ A Cycler instance used to set the style cycle if multiple lines
523
+ are to be drawn in one subplot. Used for all subplots.
486
524
plot_kw: dict, optional
487
525
Dictionary with keyword arguments passed to the plot function
488
526
subplot_kw: dict, optional
@@ -546,10 +584,13 @@ def plot_pulse_correlation_filter_function(
546
584
dashed_line = lines .Line2D ([], [], color = 'gray' , linestyle = '--' )
547
585
for i in range (n ):
548
586
for j in range (n ):
587
+ if cycler is not None :
588
+ axes [i , j ].set_prop_cycle (cycler )
589
+
549
590
handles = []
550
591
for k , ind in enumerate (n_oper_inds ):
551
592
handles += axes [i , j ].plot (z , F_pc [i , j , ind ].real ,
552
- label = n_oper_identifiers [k ],
593
+ label = _make_str_tex_compatible ( n_oper_identifiers [k ]) ,
553
594
** plot_kw )
554
595
if i != j :
555
596
axes [i , j ].plot (z , F_pc [i , j , ind ].imag , linestyle = '--' ,
@@ -566,7 +607,8 @@ def plot_pulse_correlation_filter_function(
566
607
567
608
if i == 0 and j == n - 1 :
568
609
handles += [transparent_line , solid_line , dashed_line ]
569
- labels = n_oper_identifiers .tolist () + ['' , r'$Re$' , r'$Im$' ]
610
+ labels = ([_make_str_tex_compatible (n ) for n in n_oper_identifiers ]
611
+ + ['' , r'$Re$' , r'$Im$' ])
570
612
legend = axes [i , j ].legend (handles = handles , labels = labels ,
571
613
bbox_to_anchor = (1.05 , 1 ), loc = 2 ,
572
614
borderaxespad = 0. , frameon = False )
@@ -628,11 +670,12 @@ def plot_cumulant_function(
628
670
omega : Optional [Coefficients ] = None ,
629
671
cumulant_function : Optional [ndarray ] = None ,
630
672
n_oper_identifiers : Optional [Sequence [int ]] = None ,
631
- basis_labels : Optional [Sequence [str ]] = None ,
632
673
colorscale : str = 'linear' ,
633
674
linthresh : Optional [float ] = None ,
634
- cbar_label : str = 'Cumulant Function' ,
675
+ basis_labels : Optional [ Sequence [ str ]] = None ,
635
676
basis_labelsize : Optional [int ] = None ,
677
+ cbar_label : str = 'Cumulant Function' ,
678
+ cbar_labelsize : Optional [int ] = None ,
636
679
fig : Optional [Figure ] = None ,
637
680
grid : Optional [Grid ] = None ,
638
681
cmap : Optional [Colormap ] = None ,
@@ -669,18 +712,20 @@ def plot_cumulant_function(
669
712
The identifiers of the noise operators for which the cumulant
670
713
function should be plotted. All identifiers can be accessed via
671
714
``pulse.n_oper_identifiers``. Defaults to all.
672
- basis_labels: array_like (str), optional
673
- Labels for the elements of the cumulant function (the basis
674
- elements).
675
715
colorscale: str, optional
676
716
The scale of the color code ('linear' or 'log' (default))
677
717
linthresh: float, optional
678
718
The threshold below which the colorscale will be linear (only
679
719
for 'log') colorscale
680
- cbar_label: str, optional
681
- The label for the colorbar. Default: 'Cumulant Function'.
720
+ basis_labels: array_like (str), optional
721
+ Labels for the elements of the cumulant function (the basis
722
+ elements).
682
723
basis_labelsize: int, optional
683
724
The size in points for the basis labels.
725
+ cbar_label: str, optional
726
+ The label for the colorbar. Default: 'Cumulant Function'.
727
+ cbar_labelsize: int, optional
728
+ The size in points for the colorbar label.
684
729
fig: matplotlib figure, optional
685
730
A matplotlib figure instance to plot in
686
731
grid: matplotlib ImageGrid, optional
@@ -752,6 +797,8 @@ def plot_cumulant_function(
752
797
if len (basis_labels ) != K .shape [- 1 ]:
753
798
raise ValueError ('Invalid number of basis_labels given' )
754
799
800
+ basis_labels = [_make_str_tex_compatible (bl ) for bl in basis_labels ]
801
+
755
802
if grid is None :
756
803
aspect_ratio = 2 / 3
757
804
n_rows = int (np .round (np .sqrt (aspect_ratio * len (n_oper_inds ))))
@@ -799,6 +846,7 @@ def plot_cumulant_function(
799
846
imshow_kw .setdefault ('norm' , norm )
800
847
801
848
basis_labelsize = basis_labelsize or 8
849
+ cbar_labelsize = cbar_labelsize or plt .rcParams ['axes.labelsize' ]
802
850
803
851
# Draw the images
804
852
for i , n_oper_identifier in enumerate (n_oper_identifiers ):
@@ -818,6 +866,6 @@ def plot_cumulant_function(
818
866
cbar_kw = cbar_kw or {}
819
867
cbar_kw .setdefault ('orientation' , 'vertical' )
820
868
cbar = fig .colorbar (im , cax = grid .cbar_axes [0 ], ** cbar_kw )
821
- cbar .set_label (cbar_label )
869
+ cbar .set_label (_make_str_tex_compatible ( cbar_label ), fontsize = cbar_labelsize )
822
870
823
871
return fig , grid
0 commit comments