Skip to content

Colored Bloch sphere trajectory #78

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Feb 17, 2022
61 changes: 40 additions & 21 deletions filter_functions/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,15 @@
spectrum as an image.

"""
from packaging import version
from typing import Optional, Sequence, Union
from unittest import mock
from warnings import warn

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import colors, lines # , collections
from matplotlib import cm, collections, colors, lines
from mpl_toolkits import axes_grid1
from numpy import ndarray

Expand Down Expand Up @@ -170,8 +172,11 @@ def plot_bloch_vector_evolution(
psi0: Optional[State] = None,
b: Optional[qt.Bloch] = None,
n_samples: Optional[int] = None,
cmap: Optional[Colormap] = None,
show: bool = True, return_Bloch: bool = False,
cmap: Colormap = 'winter',
add_cbar: bool = False,
show: bool = True,
return_Bloch: bool = False,
cbar_kwargs: Optional[dict] = None,
**bloch_kwargs
) -> Union[None, qt.Bloch]:
r"""
Expand All @@ -192,8 +197,11 @@ def plot_bloch_vector_evolution(
n_samples: int, optional
The number of time points to be sampled.
cmap: matplotlib colormap, optional
The colormap for the trajectory.
show**: bool, optional
The colormap for the trajectory. Requires ``matplotlib >= 3.3.0``.
add_cbar: bool, optional
Add a colorbar encoding the time evolution to the figure.
Default is false.
show: bool, optional
Whether to show the sphere (by calling :code:`b.make_sphere()`).
return_Bloch: bool, optional
Whether to return the :class:`qutip.bloch.Bloch` instance
Expand All @@ -220,12 +228,17 @@ def plot_bloch_vector_evolution(
raise ValueError('Plotting Bloch sphere evolution only implemented for one-qubit case!')

# Parse default arguments
figsize = bloch_kwargs.pop('figsize', [5, 5])
view = bloch_kwargs.pop('view', [-60, 30])
if b is None:
figsize = bloch_kwargs.pop('figsize', [5, 5])
view = bloch_kwargs.pop('view', [-60, 30])
fig = plt.figure(figsize=figsize)
axes = fig.add_subplot(projection='3d', azim=view[0], elev=view[1])
b = init_bloch_sphere(fig=fig, axes=axes, **bloch_kwargs)
else:
if b.fig is None:
b.fig = plt.figure(figsize=figsize)
if b.axes is None:
b.axes = b.fig.add_subplot(projection='3d', azim=view[0], elev=view[1])

if n_samples is None:
# At least 100, at most 5000 points, default 10 points per smallest
Expand All @@ -235,21 +248,27 @@ def plot_bloch_vector_evolution(
times = np.linspace(pulse.t[0], pulse.tau, n_samples)
propagators = pulse.propagator_at_arb_t(times)
points = get_bloch_vector(get_states_from_prop(propagators, psi0))
b.add_points(points, meth='l')

# The following enables a color gradient for the trajectory, but only works
# by patching matplotlib, see
# https://github.com/matplotlib/matplotlib/issues/17755
# points = get_bloch_vector(get_states_from_prop(propagators, psi0)).T.reshape(-1, 1, 3)
# points[:, :, 1] *= -1 # qutip convention
# segments = np.concatenate([points[:-1], points[1:]], axis=1)

# if cmap is None:
# cmap = plt.get_cmap('winter')

# colors = cmap(np.linspace(0, 1, n_samples - 1))
# lc = collections.LineCollection(segments[:, :, :2], colors=colors)
# b.axes.add_collection3d(lc, zdir='z', zs=segments[:, :, 2])
if version.parse(matplotlib.__version__) < version.parse('3.3.0'):
# Colored trajectory not available.
b.add_points(points, meth='l')
else:
points = points.T.reshape(-1, 1, 3)
# Qutip convention: -x at +y, +y at +x
copy = points.copy()
points[:, :, 0] = copy[:, :, 1]
points[:, :, 1] = -copy[:, :, 0]
segments = np.concatenate([points[:-1], points[1:]], axis=1)

cmap = plt.get_cmap(cmap)
segment_colors = cmap(np.linspace(0, 1, n_samples - 1))
lc = collections.LineCollection(segments[:, :, :2], colors=segment_colors)
b.axes.add_collection3d(lc, zdir='z', zs=segments[:, :, 2])

if add_cbar:
default_cbar_kwargs = dict(shrink=2/3, pad=0.05, label=r'$t$ ($\tau$)', ticks=[0, 1])
cbar_kwargs = {**default_cbar_kwargs, **(cbar_kwargs or {})}
b.fig.colorbar(cm.ScalarMappable(cmap=cmap), **cbar_kwargs)

if show:
b.make_sphere()
Expand Down
3 changes: 3 additions & 0 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,9 @@ def test_plot_bloch_vector_evolution(self):

b = plotting.plot_bloch_vector_evolution(complicated_pulse)

# Test add_cbar kwarg
b = plotting.plot_bloch_vector_evolution(simple_pulse, cmap='viridis', add_cbar=True)

# Check exceptions being raised
with self.assertRaises(ValueError):
plotting.plot_bloch_vector_evolution(two_qubit_pulse)
Expand Down