From e62177d82c6305403c29b90703beee5063a7a77e Mon Sep 17 00:00:00 2001 From: Cody Wang Date: Wed, 6 Mar 2024 12:42:54 -0800 Subject: [PATCH] Improve controlled gate support (#164) Right now, a controlled gate is included in the supported gates as long as its control qubit count is supported. This change checks whether the base gate of the controlled gate (e.g. rx for crx) is supported as well. --- qiskit_braket_provider/providers/adapter.py | 78 ++++++++++++++------- tests/providers/test_adapter.py | 21 ++++-- 2 files changed, 67 insertions(+), 32 deletions(-) diff --git a/qiskit_braket_provider/providers/adapter.py b/qiskit_braket_provider/providers/adapter.py index a8f1014..57d63fa 100644 --- a/qiskit_braket_provider/providers/adapter.py +++ b/qiskit_braket_provider/providers/adapter.py @@ -69,10 +69,19 @@ } _CONTROLLED_GATES_BY_QUBIT_COUNT = { - 1: {"ch", "cs", "csdg", "csx", "crx", "cry", "crz", "ccz"}, - 3: {"c3sx"}, + 1: { + "ch": "h", + "cs": "s", + "csdg": "sdg", + "csx": "sx", + "crx": "rx", + "cry": "ry", + "crz": "rz", + "ccz": "cz", + }, + 3: {"c3sx": "sx"}, } -_ARBITRARY_CONTROLLED_GATES = {"mcx"} +_ARBITRARY_CONTROLLED_GATES = {"mcx": "cx"} _ADDITIONAL_U_GATES = {"u1", "u2", "u3"} @@ -122,16 +131,9 @@ } _QISKIT_CONTROLLED_GATE_NAMES_TO_BRAKET_GATES: dict[str, Callable] = { - "ch": braket_gates.H, - "cs": braket_gates.S, - "csdg": braket_gates.Si, - "csx": braket_gates.V, - "ccz": braket_gates.CZ, - "c3sx": braket_gates.V, - "mcx": braket_gates.CNot, - "crx": braket_gates.Rx, - "cry": braket_gates.Ry, - "crz": braket_gates.Rz, + controlled_gate: _GATE_NAME_TO_BRAKET_GATE[base_gate] + for gate_map in _CONTROLLED_GATES_BY_QUBIT_COUNT.values() + for controlled_gate, base_gate in gate_map.items() } _TRANSLATABLE_QISKIT_GATE_NAMES = ( @@ -197,23 +199,26 @@ def gateset_from_properties(properties: OpenQASMDeviceActionProperties) -> set[s for op in properties.supportedOperations if op.lower() in _BRAKET_TO_QISKIT_NAMES } + if "u" in gateset: + gateset.update(_ADDITIONAL_U_GATES) max_control = 0 for modifier in properties.supportedModifiers: if isinstance(modifier, Control): max_control = modifier.max_qubits break - gateset.update(_get_controlled_gateset(max_control)) - if "u" in gateset: - gateset.update(_ADDITIONAL_U_GATES) + gateset.update(_get_controlled_gateset(gateset, max_control)) return gateset -def _get_controlled_gateset(max_qubits: Optional[int] = None) -> set[str]: +def _get_controlled_gateset( + base_gateset: set[str], max_qubits: Optional[int] = None +) -> set[str]: """Returns the Qiskit gates expressible as controlled versions of existing Braket gates This set can be filtered by the maximum number of control qubits. Args: + base_gateset (set[str]): The base (without control modifiers) gates supported max_qubits (Optional[int]): The maximum number of control qubits that can be used to express the Qiskit gate as a controlled Braket gate. If `None`, then there is no limit to the number of control qubits. Default: `None`. @@ -222,11 +227,30 @@ def _get_controlled_gateset(max_qubits: Optional[int] = None) -> set[str]: set[str]: The names of the controlled gates. """ if max_qubits is None: - gateset = set().union(*[g for _, g in _CONTROLLED_GATES_BY_QUBIT_COUNT.items()]) + gateset = set().union( + [ + controlled_gate + for gate_map in _CONTROLLED_GATES_BY_QUBIT_COUNT.values() + for controlled_gate, base_gate in gate_map.items() + if base_gate in base_gateset + ] + ) + gateset.update( + [ + controlled_gate + for controlled_gate, base_gate in _ARBITRARY_CONTROLLED_GATES.items() + if base_gate in base_gateset + ] + ) gateset.update(_ARBITRARY_CONTROLLED_GATES) return gateset return set().union( - *[g for q, g in _CONTROLLED_GATES_BY_QUBIT_COUNT.items() if q <= max_qubits] + [ + controlled_gate + for control_count, gate_map in _CONTROLLED_GATES_BY_QUBIT_COUNT.items() + for controlled_gate, base_gate in gate_map.items() + if control_count <= max_qubits and base_gate in base_gateset + ] ) @@ -453,13 +477,15 @@ def to_braket( qubit_indices = [circuit.find_bit(qubit).index for qubit in qubits] params = _create_free_parameters(operation) if gate_name in _QISKIT_CONTROLLED_GATE_NAMES_TO_BRAKET_GATES: - gate = _QISKIT_CONTROLLED_GATE_NAMES_TO_BRAKET_GATES[gate_name](*params) - gate_qubit_count = gate.qubit_count - braket_circuit += Instruction( - operator=gate, - target=qubit_indices[-gate_qubit_count:], - control=qubit_indices[:-gate_qubit_count], - ) + for gate in _QISKIT_CONTROLLED_GATE_NAMES_TO_BRAKET_GATES[gate_name]( + *params + ): + gate_qubit_count = gate.qubit_count + braket_circuit += Instruction( + operator=gate, + target=qubit_indices[-gate_qubit_count:], + control=qubit_indices[:-gate_qubit_count], + ) else: for gate in _GATE_NAME_TO_BRAKET_GATE[gate_name](*params): braket_circuit += Instruction( diff --git a/tests/providers/test_adapter.py b/tests/providers/test_adapter.py index 5838744..ed0c14b 100644 --- a/tests/providers/test_adapter.py +++ b/tests/providers/test_adapter.py @@ -440,15 +440,24 @@ def test_invalid_ctrl_state(self, mock_transpile): def test_get_controlled_gateset(self): """Tests that the correct controlled gateset is returned for all maximum qubit counts.""" + full_gateset = {"h", "s", "sdg", "sx", "rx", "ry", "rz", "cz"} + restricted_gateset = {"rx", "cx", "sx"} max1 = {"ch", "cs", "csdg", "csx", "crx", "cry", "crz", "ccz"} max3 = max1.union({"c3sx"}) unlimited = max3.union({"mcx"}) - assert _get_controlled_gateset(0) == set() - assert _get_controlled_gateset(1) == max1 - assert _get_controlled_gateset(2) == max1 - assert _get_controlled_gateset(3) == max3 - assert _get_controlled_gateset(4) == max3 - assert _get_controlled_gateset() == unlimited + assert _get_controlled_gateset(full_gateset, 0) == set() + assert _get_controlled_gateset(full_gateset, 1) == max1 + assert _get_controlled_gateset(full_gateset, 2) == max1 + assert _get_controlled_gateset(full_gateset, 3) == max3 + assert _get_controlled_gateset(full_gateset, 4) == max3 + assert _get_controlled_gateset(full_gateset) == unlimited + assert _get_controlled_gateset(restricted_gateset, 3) == {"crx", "csx", "c3sx"} + assert _get_controlled_gateset(restricted_gateset) == { + "crx", + "csx", + "c3sx", + "mcx", + } class TestFromBraket(TestCase):