diff --git a/cirq-core/cirq/experiments/single_qubit_readout_calibration.py b/cirq-core/cirq/experiments/single_qubit_readout_calibration.py index 0e60653ba15..34c18fcbdb6 100644 --- a/cirq-core/cirq/experiments/single_qubit_readout_calibration.py +++ b/cirq-core/cirq/experiments/single_qubit_readout_calibration.py @@ -18,6 +18,10 @@ import sympy import numpy as np +import matplotlib.pyplot as plt +import cirq.vis.heatmap as cirq_heatmap +import cirq.vis.histogram as cirq_histogram +from cirq.devices import grid_qubit from cirq import circuits, ops, study if TYPE_CHECKING: @@ -51,6 +55,124 @@ def _json_dict_(self) -> Dict[str, Any]: 'timestamp': self.timestamp, } + def plot_heatmap( + self, + axs: Optional[tuple[plt.Axes, plt.Axes]] = None, + annotation_format: str = '0.1%', + **plot_kwargs: Any, + ) -> tuple[plt.Axes, plt.Axes]: + """Plot a heatmap of the readout errors. If qubits are not cirq.GridQubits, throws an error. + + Args: + axs: a tuple of the plt.Axes to plot on. If not given, a new figure is created, + plotted on, and shown. + annotation_format: The format string for the numbers in the heatmap. + **plot_kwargs: Arguments to be passed to 'cirq.Heatmap.plot()'. + Returns: + The two plt.Axes containing the plot. + + Raises: + ValueError: axs does not contain two plt.Axes + TypeError: qubits are not cirq.GridQubits + """ + + if axs is None: + _, axs = plt.subplots(1, 2, dpi=200, facecolor='white', figsize=(12, 4)) + + else: + if ( + not isinstance(axs, (tuple, list, np.ndarray)) + or len(axs) != 2 + or type(axs[0]) != plt.Axes + or type(axs[1]) != plt.Axes + ): # pragma: no cover + raise ValueError('axs should be a length-2 tuple of plt.Axes') # pragma: no cover + for ax, title, data in zip( + axs, + ['$|0\\rangle$ errors', '$|1\\rangle$ errors'], + [self.zero_state_errors, self.one_state_errors], + ): + data_with_grid_qubit_keys = {} + for qubit in data: + if type(qubit) != grid_qubit.GridQubit: + raise TypeError(f'{qubit} must be of type cirq.GridQubit') # pragma: no cover + data_with_grid_qubit_keys[qubit] = data[qubit] # just for typecheck + _, _ = cirq_heatmap.Heatmap(data_with_grid_qubit_keys).plot( + ax, annotation_format=annotation_format, title=title, **plot_kwargs + ) + return axs[0], axs[1] + + def plot_integrated_histogram( + self, + ax: Optional[plt.Axes] = None, + cdf_on_x: bool = False, + axis_label: str = 'Readout error rate', + semilog: bool = True, + median_line: bool = True, + median_label: Optional[str] = 'median', + mean_line: bool = False, + mean_label: Optional[str] = 'mean', + show_zero: bool = False, + title: Optional[str] = None, + **kwargs, + ) -> plt.Axes: + """Plot the readout errors using cirq.integrated_histogram(). + + Args: + ax: The axis to plot on. If None, we generate one. + cdf_on_x: If True, flip the axes compared the above example. + axis_label: Label for x axis (y-axis if cdf_on_x is True). + semilog: If True, force the x-axis to be logarithmic. + median_line: If True, draw a vertical line on the median value. + median_label: If drawing median line, optional label for it. + mean_line: If True, draw a vertical line on the mean value. + mean_label: If drawing mean line, optional label for it. + title: Title of the plot. If None, we assign "N={len(data)}". + show_zero: If True, moves the step plot up by one unit by prepending 0 + to the data. + **kwargs: Kwargs to forward to `ax.step()`. Some examples are + color: Color of the line. + linestyle: Linestyle to use for the plot. + lw: linewidth for integrated histogram. + ms: marker size for a histogram trace. + Returns: + The axis that was plotted on. + """ + + ax = cirq_histogram.integrated_histogram( + data=self.zero_state_errors, + ax=ax, + cdf_on_x=cdf_on_x, + semilog=semilog, + median_line=median_line, + median_label=median_label, + mean_line=mean_line, + mean_label=mean_label, + show_zero=show_zero, + color='C0', + label='$|0\\rangle$ errors', + **kwargs, + ) + ax = cirq_histogram.integrated_histogram( + data=self.one_state_errors, + ax=ax, + cdf_on_x=cdf_on_x, + axis_label=axis_label, + semilog=semilog, + median_line=median_line, + median_label=median_label, + mean_line=mean_line, + mean_label=mean_label, + show_zero=show_zero, + title=title, + color='C1', + label='$|1\\rangle$ errors', + **kwargs, + ) + ax.legend(loc='best') + ax.set_ylabel('Percentile') + return ax + @classmethod def _from_json_dict_( cls, zero_state_errors, one_state_errors, repetitions, timestamp, **kwargs diff --git a/cirq-core/cirq/experiments/single_qubit_readout_calibration_test.py b/cirq-core/cirq/experiments/single_qubit_readout_calibration_test.py index 3ca5260a676..5fc63dcb3ec 100644 --- a/cirq-core/cirq/experiments/single_qubit_readout_calibration_test.py +++ b/cirq-core/cirq/experiments/single_qubit_readout_calibration_test.py @@ -87,7 +87,7 @@ def test_estimate_single_qubit_readout_errors_with_noise(): def test_estimate_parallel_readout_errors_no_noise(): - qubits = cirq.LineQubit.range(10) + qubits = [cirq.GridQubit(i, 0) for i in range(10)] sampler = cirq.Simulator() repetitions = 1000 result = cirq.estimate_parallel_single_qubit_readout_errors( @@ -97,6 +97,8 @@ def test_estimate_parallel_readout_errors_no_noise(): assert result.one_state_errors == {q: 0 for q in qubits} assert result.repetitions == repetitions assert isinstance(result.timestamp, float) + _ = result.plot_integrated_histogram() + _, _ = result.plot_heatmap() def test_estimate_parallel_readout_errors_all_zeros():