42
42
spectrum as an image.
43
43
44
44
"""
45
+ from packaging import version
45
46
from typing import Optional , Sequence , Union
46
47
from unittest import mock
47
48
from warnings import warn
48
49
50
+ import matplotlib
49
51
import matplotlib .pyplot as plt
50
52
import numpy as np
51
- from matplotlib import colors , lines # , collections
53
+ from matplotlib import cm , collections , colors , lines
52
54
from mpl_toolkits import axes_grid1
53
55
from numpy import ndarray
54
56
@@ -170,8 +172,11 @@ def plot_bloch_vector_evolution(
170
172
psi0 : Optional [State ] = None ,
171
173
b : Optional [qt .Bloch ] = None ,
172
174
n_samples : Optional [int ] = None ,
173
- cmap : Optional [Colormap ] = None ,
174
- show : bool = True , return_Bloch : bool = False ,
175
+ cmap : Colormap = 'winter' ,
176
+ add_cbar : bool = False ,
177
+ show : bool = True ,
178
+ return_Bloch : bool = False ,
179
+ cbar_kwargs : Optional [dict ] = None ,
175
180
** bloch_kwargs
176
181
) -> Union [None , qt .Bloch ]:
177
182
r"""
@@ -192,8 +197,11 @@ def plot_bloch_vector_evolution(
192
197
n_samples: int, optional
193
198
The number of time points to be sampled.
194
199
cmap: matplotlib colormap, optional
195
- The colormap for the trajectory.
196
- show**: bool, optional
200
+ The colormap for the trajectory. Requires ``matplotlib >= 3.3.0``.
201
+ add_cbar: bool, optional
202
+ Add a colorbar encoding the time evolution to the figure.
203
+ Default is false.
204
+ show: bool, optional
197
205
Whether to show the sphere (by calling :code:`b.make_sphere()`).
198
206
return_Bloch: bool, optional
199
207
Whether to return the :class:`qutip.bloch.Bloch` instance
@@ -220,12 +228,17 @@ def plot_bloch_vector_evolution(
220
228
raise ValueError ('Plotting Bloch sphere evolution only implemented for one-qubit case!' )
221
229
222
230
# Parse default arguments
231
+ figsize = bloch_kwargs .pop ('figsize' , [5 , 5 ])
232
+ view = bloch_kwargs .pop ('view' , [- 60 , 30 ])
223
233
if b is None :
224
- figsize = bloch_kwargs .pop ('figsize' , [5 , 5 ])
225
- view = bloch_kwargs .pop ('view' , [- 60 , 30 ])
226
234
fig = plt .figure (figsize = figsize )
227
235
axes = fig .add_subplot (projection = '3d' , azim = view [0 ], elev = view [1 ])
228
236
b = init_bloch_sphere (fig = fig , axes = axes , ** bloch_kwargs )
237
+ else :
238
+ if b .fig is None :
239
+ b .fig = plt .figure (figsize = figsize )
240
+ if b .axes is None :
241
+ b .axes = b .fig .add_subplot (projection = '3d' , azim = view [0 ], elev = view [1 ])
229
242
230
243
if n_samples is None :
231
244
# At least 100, at most 5000 points, default 10 points per smallest
@@ -235,21 +248,27 @@ def plot_bloch_vector_evolution(
235
248
times = np .linspace (pulse .t [0 ], pulse .tau , n_samples )
236
249
propagators = pulse .propagator_at_arb_t (times )
237
250
points = get_bloch_vector (get_states_from_prop (propagators , psi0 ))
238
- b .add_points (points , meth = 'l' )
239
-
240
- # The following enables a color gradient for the trajectory, but only works
241
- # by patching matplotlib, see
242
- # https://github.com/matplotlib/matplotlib/issues/17755
243
- # points = get_bloch_vector(get_states_from_prop(propagators, psi0)).T.reshape(-1, 1, 3)
244
- # points[:, :, 1] *= -1 # qutip convention
245
- # segments = np.concatenate([points[:-1], points[1:]], axis=1)
246
251
247
- # if cmap is None:
248
- # cmap = plt.get_cmap('winter')
249
-
250
- # colors = cmap(np.linspace(0, 1, n_samples - 1))
251
- # lc = collections.LineCollection(segments[:, :, :2], colors=colors)
252
- # b.axes.add_collection3d(lc, zdir='z', zs=segments[:, :, 2])
252
+ if version .parse (matplotlib .__version__ ) < version .parse ('3.3.0' ):
253
+ # Colored trajectory not available.
254
+ b .add_points (points , meth = 'l' )
255
+ else :
256
+ points = points .T .reshape (- 1 , 1 , 3 )
257
+ # Qutip convention: -x at +y, +y at +x
258
+ copy = points .copy ()
259
+ points [:, :, 0 ] = copy [:, :, 1 ]
260
+ points [:, :, 1 ] = - copy [:, :, 0 ]
261
+ segments = np .concatenate ([points [:- 1 ], points [1 :]], axis = 1 )
262
+
263
+ cmap = plt .get_cmap (cmap )
264
+ segment_colors = cmap (np .linspace (0 , 1 , n_samples - 1 ))
265
+ lc = collections .LineCollection (segments [:, :, :2 ], colors = segment_colors )
266
+ b .axes .add_collection3d (lc , zdir = 'z' , zs = segments [:, :, 2 ])
267
+
268
+ if add_cbar :
269
+ default_cbar_kwargs = dict (shrink = 2 / 3 , pad = 0.05 , label = r'$t$ ($\tau$)' , ticks = [0 , 1 ])
270
+ cbar_kwargs = {** default_cbar_kwargs , ** (cbar_kwargs or {})}
271
+ b .fig .colorbar (cm .ScalarMappable (cmap = cmap ), ** cbar_kwargs )
253
272
254
273
if show :
255
274
b .make_sphere ()
0 commit comments