Skip to content

Commit

Permalink
Improve unit test coverage to 100% (qiskit-community#179)
Browse files Browse the repository at this point in the history
* Improve coverage to 100%

* Improve coverage to 100%

* Add tests

* Improve coverage to 100%

* Improve coverage to 100%

* Improve coverage to 100%

* Fix region retrieval

* Lint import

* Fix black

* Fix pylint

* Fix pylint

* Fix pylint

* Fix pylint

* Fix pylint

* Bump code coverage condition

* improve tests

* improve tests
  • Loading branch information
WingCode authored Jun 6, 2024
1 parent 0bb0086 commit 99bad32
Show file tree
Hide file tree
Showing 7 changed files with 367 additions and 5 deletions.
35 changes: 35 additions & 0 deletions tests/providers/mocks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Mocks for testing."""

import copy
import enum
import uuid
from collections import Counter
from typing import Dict
Expand All @@ -12,6 +13,8 @@
from braket.tasks import GateModelQuantumTaskResult
from braket.tasks.local_quantum_task import LocalQuantumTask

from qiskit_braket_provider.providers.braket_backend import BraketBackend

RIGETTI_ARN = "arn:aws:braket:::device/qpu/rigetti/Aspen-10"
RIGETTI_ASPEN_ARN = "arn:aws:braket:::device/qpu/rigetti/Aspen-M-3"
SV1_ARN = "arn:aws:braket:::device/quantum-simulator/amazon/sv1"
Expand Down Expand Up @@ -158,3 +161,35 @@
)

MOCK_LOCAL_QUANTUM_TASK = LocalQuantumTask(MOCK_GATE_MODEL_QUANTUM_TASK_RESULT)


class MockBraketBackend(BraketBackend):
"""
Mock class for BraketBackend.
"""

@property
def target(self):
pass

@property
def max_circuits(self):
pass

@classmethod
def _default_options(cls):
pass

def run(self, run_input, **kwargs):
"""
Mock method for run.
"""
pass


class MockMeasLevelEnum(enum.Enum):
"""
Mock class for MeasLevelEnum.
"""

LEVEL_TWO = 2
24 changes: 24 additions & 0 deletions tests/providers/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,15 @@ def test_measure(self):

self.assertEqual(braket_circuit, expected_braket_circuit)

def test_reset(self):
"""Tests if NotImplementedError is raised for reset operation."""

qiskit_circuit = QuantumCircuit(1, 1)
qiskit_circuit.reset(0)

with self.assertRaises(NotImplementedError):
to_braket(qiskit_circuit)

def test_measure_different_indices(self):
"""
Tests the translation of a measure instruction.
Expand Down Expand Up @@ -657,6 +666,21 @@ def test_power(self):

self.assertEqual(qiskit_circuit, expected_qiskit_circuit)

def test_unsupported_braket_gate(self):
"""Tests if TypeError is raised for unsupported Braket gate."""

gate = getattr(Gate, "CNot")
op = gate()
instr = Instruction(op, range(2))
circuit = Circuit().add_instruction(instr)

with self.assertRaises(TypeError):
with patch.dict(
"qiskit_braket_provider.providers.adapter._GATE_NAME_TO_QISKIT_GATE",
{"cnot": None},
):
to_qiskit(circuit)

def test_measure_subset(self):
"""Tests the measure instruction conversion from braket to qiskit"""
braket_circuit = Circuit().h(0).cnot(0, 1).measure(0)
Expand Down
189 changes: 189 additions & 0 deletions tests/providers/test_braket_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@

import numpy as np
from botocore import errorfactory
from braket.aws import AwsDevice, AwsQuantumTaskBatch
from braket.aws.queue_information import QueueDepthInfo, QueueType
from braket.device_schema import DeviceActionType
from braket.tasks.local_quantum_task import LocalQuantumTask
from qiskit import QuantumCircuit, transpile
from qiskit.circuit import Instruction as QiskitInstruction
from qiskit.circuit.library import TwoLocal
from qiskit.circuit.random import random_circuit
from qiskit.primitives import BackendEstimator
Expand All @@ -21,9 +25,12 @@
from qiskit_braket_provider import AWSBraketProvider, exception, version
from qiskit_braket_provider.providers import BraketAwsBackend, BraketLocalBackend
from qiskit_braket_provider.providers.adapter import aws_device_to_target
from qiskit_braket_provider.providers.braket_backend import AWSBraketBackend
from tests.providers.mocks import (
RIGETTI_MOCK_GATE_MODEL_QPU_CAPABILITIES,
RIGETTI_MOCK_M_3_QPU_CAPABILITIES,
MockBraketBackend,
MockMeasLevelEnum,
)


Expand Down Expand Up @@ -53,6 +60,20 @@ def combine_dicts(
return combined_dict


class TestBraketBackend(TestCase):
"""Test class for BraketBackend."""

def test_repr(self):
"""Test the repr method of BraketBackend."""
backend = BraketLocalBackend(name="default")
self.assertEqual(repr(backend), "BraketBackend[default]")

def test_invalid_device(self):
"""Test the device method of BraketBackend."""
with self.assertRaises(NotImplementedError):
_ = MockBraketBackend(name="default")._device


class TestBraketAwsBackend(TestCase):
"""Tests BraketBackend."""

Expand All @@ -66,6 +87,12 @@ def test_device_backend(self):
self.assertIsNone(backend.max_circuits)
user_agent = f"QiskitBraketProvider/" f"{version.__version__}"
device.aws_session.add_braket_user_agent.assert_called_with(user_agent)
with self.assertRaises(NotImplementedError):
backend.dtm()
with self.assertRaises(NotImplementedError):
backend.meas_map()
with self.assertRaises(NotImplementedError):
backend.qubit_properties(0)
with self.assertRaises(NotImplementedError):
backend.drive_channel(0)
with self.assertRaises(NotImplementedError):
Expand All @@ -75,12 +102,26 @@ def test_device_backend(self):
with self.assertRaises(NotImplementedError):
backend.control_channel([0, 1])

def test_invalid_identifiers(self):
"""Test the invalid identifiers of BraketAwsBackend."""
with self.assertRaises(ValueError):
BraketAwsBackend()

with self.assertRaises(ValueError):
BraketAwsBackend(arn="some_arn", device="some_device")

def test_local_backend(self):
"""Tests local backend."""
backend = BraketLocalBackend(name="default")
self.assertTrue(backend)
self.assertIsInstance(backend.target, Target)
self.assertIsNone(backend.max_circuits)
with self.assertRaises(NotImplementedError):
backend.dtm()
with self.assertRaises(NotImplementedError):
backend.meas_map()
with self.assertRaises(NotImplementedError):
backend.qubit_properties(0)
with self.assertRaises(NotImplementedError):
backend.drive_channel(0)
with self.assertRaises(NotImplementedError):
Expand Down Expand Up @@ -140,6 +181,77 @@ def test_local_backend_circuit_shots0(self):
)
)

def test_deprecation_warning_on_init(self):
"""Test that a deprecation warning is raised when initializing AWSBraketBackend"""
mock_aws_device = Mock(spec=AwsDevice)
mock_aws_device.properties = RIGETTI_MOCK_GATE_MODEL_QPU_CAPABILITIES

with self.assertWarns(DeprecationWarning):
AWSBraketBackend(device=mock_aws_device)

def test_deprecation_warning_on_subclass(self):
"""Test that a deprecation warning is raised when subclassing AWSBraketBackend"""

with self.assertWarns(DeprecationWarning):

class SubclassAWSBraketBackend(
AWSBraketBackend
): # pylint: disable=unused-variable
"""A subclass of AWSBraketBackend for testing purposes"""

pass

def test_run_multiple_circuits(self):
"""Tests run with multiple circuits"""
device = Mock()
device.properties = RIGETTI_MOCK_GATE_MODEL_QPU_CAPABILITIES
backend = BraketAwsBackend(device=device)
mock_task_1 = Mock(spec=LocalQuantumTask)
mock_task_1.id = "0"
mock_task_2 = Mock(spec=LocalQuantumTask)
mock_task_2.id = "1"
mock_batch = Mock(spec=AwsQuantumTaskBatch)
mock_batch.tasks = [mock_task_1, mock_task_2]
backend._device.run_batch = Mock(return_value=mock_batch)
circuit = QuantumCircuit(1)
circuit.h(0)

backend.run([circuit, circuit], shots=0, meas_level=2)

def test_run_invalid_run_input(self):
"""Tests run with invalid input to run"""
device = Mock()
device.properties = RIGETTI_MOCK_GATE_MODEL_QPU_CAPABILITIES
backend = BraketAwsBackend(device=device)
with self.assertRaises(exception.QiskitBraketException):
backend.run(1, shots=0)

@patch(
"braket.devices.LocalSimulator.run",
side_effect=[
Mock(return_value=Mock(id="0", spec=LocalQuantumTask)),
Exception("Mock exception"),
],
)
def test_local_backend_run_exception(self, braket_devices_run):
"""Tests local backend with exception thrown during second run"""
backend = BraketLocalBackend(name="default")

circuit = QuantumCircuit(1)
circuit.h(0)

with self.assertRaises(Exception):
backend.run([circuit, circuit], shots=0) # First run should pass
braket_devices_run.assert_called()

def test_meas_level_enum(self):
"""Check that enum meas level can be successfully accessed without error"""
backend = BraketLocalBackend(name="default")
circuit = QuantumCircuit(1, 1)
circuit.h(0)
circuit.measure(0, 0)
backend.run(circuit, shots=10, meas_level=MockMeasLevelEnum.LEVEL_TWO)

def test_meas_level_2(self):
"""Check that there's no error for asking for classified measurement results."""
backend = BraketLocalBackend(name="default")
Expand Down Expand Up @@ -226,6 +338,30 @@ def test_random_circuits(self):
f"and absolute difference {abs_diff}. Original values {values}",
)

@patch("qiskit_braket_provider.providers.braket_backend.AwsQuantumTask")
@patch("qiskit_braket_provider.providers.braket_backend.BraketQuantumTask")
def test_retrieve_job_task_ids(
self, mock_braket_quantum_task, mock_aws_quantum_task
):
"""Test method for retrieving job task IDs."""
device = Mock()
device.properties = RIGETTI_MOCK_GATE_MODEL_QPU_CAPABILITIES
backend = BraketAwsBackend(device=device)
task_id = "task1;task2;task3"
expected_task_ids = task_id.split(";")

backend.retrieve_job(task_id)

# Assert
mock_aws_quantum_task.assert_any_call(arn=expected_task_ids[0])
mock_aws_quantum_task.assert_any_call(arn=expected_task_ids[1])
mock_aws_quantum_task.assert_any_call(arn=expected_task_ids[2])
mock_braket_quantum_task.assert_called_once_with(
task_id=task_id,
backend=backend,
tasks=[mock_aws_quantum_task(arn=task_id) for task_id in expected_task_ids],
)

@unittest.skip("Call to external resources.")
def test_retrieve_job(self):
"""Tests retrieve task by id."""
Expand Down Expand Up @@ -319,3 +455,56 @@ def test_target(self):
self.assertEqual(len(target.operations), 2)
self.assertEqual(len(target.instructions), 60)
self.assertIn("Target for Amazon Braket QPU", target.description)

def test_target_invalid_device(self):
"""Tests target."""
mock_device = Mock()
mock_device.properties = None

with self.assertRaises(exception.QiskitBraketException):
aws_device_to_target(mock_device)

def test_fully_connected(self):
"""Tests if instruction_props is correctly populated for fully connected topology."""
mock_device = Mock()
mock_device.properties = RIGETTI_MOCK_GATE_MODEL_QPU_CAPABILITIES.copy(
deep=True
)
mock_device.properties.paradigm.connectivity.fullyConnected = True
mock_device.properties.paradigm.qubitCount = 2
mock_device.properties.action.get(
DeviceActionType.OPENQASM
).supportedOperations = ["CNOT"]

instruction_props = aws_device_to_target(mock_device)

cx_instruction = QiskitInstruction(
name="cx", num_qubits=2, num_clbits=0, params=[]
)
measure_instruction = QiskitInstruction(
name="measure", num_qubits=1, num_clbits=1, params=[]
)

expected_instruction_props = [
(cx_instruction, (0, 1)),
(cx_instruction, (1, 0)),
(measure_instruction, (0,)),
(measure_instruction, (1,)),
]
for index, instruction in enumerate(instruction_props.instructions):
self.assertEqual(
instruction[0].num_qubits,
expected_instruction_props[index][0].num_qubits,
)
self.assertEqual(
instruction[0].num_clbits,
expected_instruction_props[index][0].num_clbits,
)
self.assertEqual(
instruction[0].params, expected_instruction_props[index][0].params
)
self.assertEqual(
instruction[0].name, expected_instruction_props[index][0].name
)

self.assertEqual(instruction[1], expected_instruction_props[index][1])
29 changes: 29 additions & 0 deletions tests/providers/test_braket_job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Tests for Braket jobs"""

from unittest import TestCase

from qiskit_braket_provider.providers.braket_job import AmazonBraketTask, AWSBraketJob


class TestAmazonBraketTask(TestCase):
"""Tests Amazon Braket Task"""

def test_deprecation_warning_on_init(self):
"""Test to check if a deprecation warning is raised when initializing AmazonBraketTask"""
with self.assertWarns(DeprecationWarning):

class SubAmazonBraketTask(
AmazonBraketTask
): # pylint: disable=unused-variable
"""Subclass of AmazonBraketTask for testing"""


class TestAWSBraketJob(TestCase):
"""Tests Amazon Braket Job"""

def test_deprecation_warning_on_init(self):
"""Test to check if a deprecation warning is raised when initializing AWSBraketJob"""
with self.assertWarns(DeprecationWarning):

class SubAwsBraketJob(AWSBraketJob): # pylint: disable=unused-variable
"""Subclass of AWSBraketJob for testing"""
Loading

0 comments on commit 99bad32

Please sign in to comment.