Skip to content

Commit 0840e93

Browse files
authored
Merge pull request #78 from qutech/feature/colored_bloch_trajectory
Colored Bloch sphere trajectory
2 parents 6dd4c2a + ad5acd2 commit 0840e93

File tree

2 files changed

+43
-21
lines changed

2 files changed

+43
-21
lines changed

filter_functions/plotting.py

Lines changed: 40 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,15 @@
4242
spectrum as an image.
4343
4444
"""
45+
from packaging import version
4546
from typing import Optional, Sequence, Union
4647
from unittest import mock
4748
from warnings import warn
4849

50+
import matplotlib
4951
import matplotlib.pyplot as plt
5052
import numpy as np
51-
from matplotlib import colors, lines # , collections
53+
from matplotlib import cm, collections, colors, lines
5254
from mpl_toolkits import axes_grid1
5355
from numpy import ndarray
5456

@@ -170,8 +172,11 @@ def plot_bloch_vector_evolution(
170172
psi0: Optional[State] = None,
171173
b: Optional[qt.Bloch] = None,
172174
n_samples: Optional[int] = None,
173-
cmap: Optional[Colormap] = None,
174-
show: bool = True, return_Bloch: bool = False,
175+
cmap: Colormap = 'winter',
176+
add_cbar: bool = False,
177+
show: bool = True,
178+
return_Bloch: bool = False,
179+
cbar_kwargs: Optional[dict] = None,
175180
**bloch_kwargs
176181
) -> Union[None, qt.Bloch]:
177182
r"""
@@ -192,8 +197,11 @@ def plot_bloch_vector_evolution(
192197
n_samples: int, optional
193198
The number of time points to be sampled.
194199
cmap: matplotlib colormap, optional
195-
The colormap for the trajectory.
196-
show**: bool, optional
200+
The colormap for the trajectory. Requires ``matplotlib >= 3.3.0``.
201+
add_cbar: bool, optional
202+
Add a colorbar encoding the time evolution to the figure.
203+
Default is false.
204+
show: bool, optional
197205
Whether to show the sphere (by calling :code:`b.make_sphere()`).
198206
return_Bloch: bool, optional
199207
Whether to return the :class:`qutip.bloch.Bloch` instance
@@ -220,12 +228,17 @@ def plot_bloch_vector_evolution(
220228
raise ValueError('Plotting Bloch sphere evolution only implemented for one-qubit case!')
221229

222230
# Parse default arguments
231+
figsize = bloch_kwargs.pop('figsize', [5, 5])
232+
view = bloch_kwargs.pop('view', [-60, 30])
223233
if b is None:
224-
figsize = bloch_kwargs.pop('figsize', [5, 5])
225-
view = bloch_kwargs.pop('view', [-60, 30])
226234
fig = plt.figure(figsize=figsize)
227235
axes = fig.add_subplot(projection='3d', azim=view[0], elev=view[1])
228236
b = init_bloch_sphere(fig=fig, axes=axes, **bloch_kwargs)
237+
else:
238+
if b.fig is None:
239+
b.fig = plt.figure(figsize=figsize)
240+
if b.axes is None:
241+
b.axes = b.fig.add_subplot(projection='3d', azim=view[0], elev=view[1])
229242

230243
if n_samples is None:
231244
# At least 100, at most 5000 points, default 10 points per smallest
@@ -235,21 +248,27 @@ def plot_bloch_vector_evolution(
235248
times = np.linspace(pulse.t[0], pulse.tau, n_samples)
236249
propagators = pulse.propagator_at_arb_t(times)
237250
points = get_bloch_vector(get_states_from_prop(propagators, psi0))
238-
b.add_points(points, meth='l')
239-
240-
# The following enables a color gradient for the trajectory, but only works
241-
# by patching matplotlib, see
242-
# https://github.com/matplotlib/matplotlib/issues/17755
243-
# points = get_bloch_vector(get_states_from_prop(propagators, psi0)).T.reshape(-1, 1, 3)
244-
# points[:, :, 1] *= -1 # qutip convention
245-
# segments = np.concatenate([points[:-1], points[1:]], axis=1)
246251

247-
# if cmap is None:
248-
# cmap = plt.get_cmap('winter')
249-
250-
# colors = cmap(np.linspace(0, 1, n_samples - 1))
251-
# lc = collections.LineCollection(segments[:, :, :2], colors=colors)
252-
# b.axes.add_collection3d(lc, zdir='z', zs=segments[:, :, 2])
252+
if version.parse(matplotlib.__version__) < version.parse('3.3.0'):
253+
# Colored trajectory not available.
254+
b.add_points(points, meth='l')
255+
else:
256+
points = points.T.reshape(-1, 1, 3)
257+
# Qutip convention: -x at +y, +y at +x
258+
copy = points.copy()
259+
points[:, :, 0] = copy[:, :, 1]
260+
points[:, :, 1] = -copy[:, :, 0]
261+
segments = np.concatenate([points[:-1], points[1:]], axis=1)
262+
263+
cmap = plt.get_cmap(cmap)
264+
segment_colors = cmap(np.linspace(0, 1, n_samples - 1))
265+
lc = collections.LineCollection(segments[:, :, :2], colors=segment_colors)
266+
b.axes.add_collection3d(lc, zdir='z', zs=segments[:, :, 2])
267+
268+
if add_cbar:
269+
default_cbar_kwargs = dict(shrink=2/3, pad=0.05, label=r'$t$ ($\tau$)', ticks=[0, 1])
270+
cbar_kwargs = {**default_cbar_kwargs, **(cbar_kwargs or {})}
271+
b.fig.colorbar(cm.ScalarMappable(cmap=cmap), **cbar_kwargs)
253272

254273
if show:
255274
b.make_sphere()

tests/test_plotting.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,9 @@ def test_plot_bloch_vector_evolution(self):
363363

364364
b = plotting.plot_bloch_vector_evolution(complicated_pulse)
365365

366+
# Test add_cbar kwarg
367+
b = plotting.plot_bloch_vector_evolution(simple_pulse, cmap='viridis', add_cbar=True)
368+
366369
# Check exceptions being raised
367370
with self.assertRaises(ValueError):
368371
plotting.plot_bloch_vector_evolution(two_qubit_pulse)

0 commit comments

Comments
 (0)