Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support custom gate defintions in QASM parser #6917

Merged
merged 21 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cirq-core/cirq/circuits/qasm_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,10 +366,10 @@ def on_stuck(bad_op):
if should_annotate:
output_line_gap(1)
if isinstance(main_op, ops.GateOperation):
x = str(main_op.gate).replace('\n', '\n //')
x = str(main_op.gate).replace('\n', '\n// ')
output(f'// Gate: {x!s}\n')
else:
x = str(main_op).replace('\n', '\n //')
x = str(main_op).replace('\n', '\n// ')
output(f'// Operation: {x!s}\n')

for qasm in qasms:
Expand Down
5 changes: 5 additions & 0 deletions cirq-core/cirq/contrib/qasm_import/_lexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(self):
'creg': 'CREG',
'measure': 'MEASURE',
'reset': 'RESET',
'gate': 'GATE',
'if': 'IF',
'->': 'ARROW',
'==': 'EQ',
Expand Down Expand Up @@ -120,6 +121,10 @@ def t_RESET(self, t):
r"""reset"""
return t

def t_GATE(self, t):
r"""gate"""
return t

def t_IF(self, t):
r"""if"""
return t
Expand Down
68 changes: 68 additions & 0 deletions cirq-core/cirq/contrib/qasm_import/_lexer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,74 @@ def test_creg():
assert token.value == ";"


def test_custom_gate():
lexer = QasmLexer()
lexer.input('gate name(param1,param2) q1, q2 {X(q1)}')
token = lexer.token()
assert token.type == "GATE"
assert token.value == "gate"

token = lexer.token()
assert token.type == "ID"
assert token.value == "name"

token = lexer.token()
assert token.type == "("
assert token.value == "("

token = lexer.token()
assert token.type == "ID"
assert token.value == "param1"

token = lexer.token()
assert token.type == ","
assert token.value == ","

token = lexer.token()
assert token.type == "ID"
assert token.value == "param2"

token = lexer.token()
assert token.type == ")"
assert token.value == ")"

token = lexer.token()
assert token.type == "ID"
assert token.value == "q1"

token = lexer.token()
assert token.type == ","
assert token.value == ","

token = lexer.token()
assert token.type == "ID"
assert token.value == "q2"

token = lexer.token()
assert token.type == "{"
assert token.value == "{"

token = lexer.token()
assert token.type == "ID"
assert token.value == "X"

token = lexer.token()
assert token.type == "("
assert token.value == "("

token = lexer.token()
assert token.type == "ID"
assert token.value == "q1"

token = lexer.token()
assert token.type == ")"
assert token.value == ")"

token = lexer.token()
assert token.type == "}"
assert token.value == "}"


def test_error():
lexer = QasmLexer()
lexer.input('θ')
Expand Down
131 changes: 122 additions & 9 deletions cirq-core/cirq/contrib/qasm_import/_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import functools
import operator
from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Union, TYPE_CHECKING
from typing import (
Any,
Callable,
cast,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
Union,
TYPE_CHECKING,
)

import numpy as np
import sympy
from ply import yacc

from cirq import ops, Circuit, NamedQubit, CX
from cirq import ops, value, Circuit, CircuitOperation, CX, FrozenCircuit, NamedQubit
from cirq.circuits.qasm_output import QasmUGate
from cirq.contrib.qasm_import._lexer import QasmLexer
from cirq.contrib.qasm_import.exception import QasmException
Expand Down Expand Up @@ -87,15 +100,15 @@ def _validate_args(self, args: List[List[ops.Qid]], lineno: int):
f"got: {len(args)}, at line {lineno}"
)

def _validate_params(self, params: List[float], lineno: int):
def _validate_params(self, params: List[value.TParamVal], lineno: int):
if len(params) != self.num_params:
raise QasmException(
f"{self.qasm_gate} takes {self.num_params} parameter(s), "
f"got: {len(params)}, at line {lineno}"
)

def on(
self, params: List[float], args: List[List[ops.Qid]], lineno: int
self, params: List[value.TParamVal], args: List[List[ops.Qid]], lineno: int
) -> Iterable[ops.Operation]:
self._validate_args(args, lineno)
self._validate_params(params, lineno)
Expand Down Expand Up @@ -132,6 +145,28 @@ def on(
yield final_gate.on(*qubits)


@dataclasses.dataclass
class CustomGate:
daxfohl marked this conversation as resolved.
Show resolved Hide resolved
name: str
circuit: FrozenCircuit
params: Tuple[str, ...]
qubits: Tuple[ops.Qid, ...]

def on(
self, params: List[value.TParamVal], args: List[List[ops.Qid]], lineno: int
) -> ops.Operation:
if len(params) != len(self.params):
raise QasmException(f'Wrong number of params for "{self.name}" at line {lineno}')
qubits = [q for qs in args for q in qs]
if len(qubits) != len(self.qubits):
raise QasmException(f'Wrong number of qregs for "{self.name}" at line {lineno}')
return CircuitOperation(
self.circuit,
param_resolver={k: v for k, v in zip(self.params, params)},
qubit_map={k: v for k, v in zip(self.qubits, qubits)},
)


class QasmParser:
"""Parser for QASM strings.

Expand All @@ -146,6 +181,10 @@ def __init__(self) -> None:
self.circuit = Circuit()
self.qregs: Dict[str, int] = {}
self.cregs: Dict[str, int] = {}
self.all_gates: Dict[str, Union[CustomGate, QasmGateStatement]] = {**self.basic_gates}
self.custom_gate_scoped_params: Set[str] = set()
self.custom_gate_scoped_qubits: Dict[str, ops.Qid] = {}
daxfohl marked this conversation as resolved.
Show resolved Hide resolved
self.in_custom_gate_scope = False
self.qelibinc = False
self.lexer = QasmLexer()
self.supported_format = False
Expand Down Expand Up @@ -270,8 +309,6 @@ def __init__(self) -> None:
'tdg': QasmGateStatement(qasm_gate='tdg', num_params=0, num_args=1, cirq_gate=ops.T**-1),
}

all_gates = {**basic_gates, **qelib_gates}

tokens = QasmLexer.tokens
start = 'start'

Expand All @@ -296,11 +333,13 @@ def p_qasm_no_format_specified_error(self, p):
def p_qasm_include(self, p):
"""qasm : qasm QELIBINC"""
self.qelibinc = True
self.all_gates |= self.qelib_gates
p[0] = Qasm(self.supported_format, self.qelibinc, self.qregs, self.cregs, self.circuit)

def p_qasm_include_stdgates(self, p):
"""qasm : qasm STDGATESINC"""
self.qelibinc = True
self.all_gates |= self.qelib_gates
p[0] = Qasm(self.supported_format, self.qelibinc, self.qregs, self.cregs, self.circuit)

def p_qasm_circuit(self, p):
Expand Down Expand Up @@ -338,6 +377,10 @@ def p_circuit_empty(self, p):
"""circuit : empty"""
p[0] = self.circuit

def p_circuit_gate_def(self, p):
"""circuit : gate_def"""
p[0] = self.circuit

# qreg and creg

def p_new_reg(self, p):
Expand Down Expand Up @@ -382,9 +425,9 @@ def p_gate_op_with_params(self, p):
self._resolve_gate_operation(args=p[5], gate=p[1], p=p, params=p[3])

def _resolve_gate_operation(
self, args: List[List[ops.Qid]], gate: str, p: Any, params: List[float]
self, args: List[List[ops.Qid]], gate: str, p: Any, params: List[value.TParamVal]
):
gate_set = self.basic_gates if not self.qelibinc else self.all_gates
gate_set = self.all_gates
if gate not in gate_set.keys():
tip = ", did you forget to include qelib1.inc?" if not self.qelibinc else ""
msg = f'Unknown gate "{gate}" at line {p.lineno(1)}{tip}'
Expand All @@ -404,14 +447,23 @@ def p_params_single(self, p):
p[0] = [p[1]]

# expr : term
# | func '(' expression ')' """
# | ID
# | func '(' expression ')'
# | binary_op
# | unary_op

def p_expr_term(self, p):
"""expr : term"""
p[0] = p[1]

def p_expr_identifier(self, p):
"""expr : ID"""
if not self.in_custom_gate_scope:
raise QasmException(f'Parameter "{p[1]}" in line {p.lineno(1)} not supported')
if p[1] not in self.custom_gate_scoped_params:
raise QasmException(f'Undefined parameter "{p[1]}" in line {p.lineno(1)}')
p[0] = sympy.Symbol(p[1])

def p_expr_parens(self, p):
"""expr : '(' expr ')'"""
p[0] = p[2]
Expand Down Expand Up @@ -464,6 +516,11 @@ def p_args_single(self, p):
def p_quantum_arg_register(self, p):
"""qarg : ID"""
reg = p[1]
if self.in_custom_gate_scope:
if reg not in self.custom_gate_scoped_qubits:
raise QasmException(f'Undefined quantum register "{reg}" at line {p.lineno(1)}')
p[0] = [self.custom_gate_scoped_qubits[reg]]
return
if reg not in self.qregs.keys():
raise QasmException(f'Undefined quantum register "{reg}" at line {p.lineno(1)}')
qubits = []
Expand Down Expand Up @@ -492,6 +549,8 @@ def p_quantum_arg_bit(self, p):
"""qarg : ID '[' NATURAL_NUMBER ']'"""
reg = p[1]
idx = p[3]
if self.in_custom_gate_scope:
raise QasmException(f'Unsupported indexed qreg "{reg}[{idx}]" at line {p.lineno(1)}')
arg_name = self.make_name(idx, reg)
if reg not in self.qregs.keys():
raise QasmException(f'Undefined quantum register "{reg}" at line {p.lineno(1)}')
Expand Down Expand Up @@ -570,6 +629,60 @@ def p_if(self, p):
ops.ClassicallyControlledOperation(conditions=conditions, sub_operation=tuple(p[7])[0])
]

def p_gate_params_multiple(self, p):
"""gate_params : ID ',' gate_params"""
self.p_gate_params_single(p)
p[0] += p[3]

def p_gate_params_single(self, p):
"""gate_params : ID"""
self.in_custom_gate_scope = True
self.custom_gate_scoped_params.add(p[1])
p[0] = [p[1]]

def p_gate_qubits_multiple(self, p):
"""gate_qubits : ID ',' gate_qubits"""
self.p_gate_qubits_single(p)
p[0] += p[3]

def p_gate_qubits_single(self, p):
"""gate_qubits : ID"""
self.in_custom_gate_scope = True
q = NamedQubit(p[1])
self.custom_gate_scoped_qubits[p[1]] = q
p[0] = [q]

def p_gate_ops(self, p):
"""gate_ops : gate_op gate_ops"""
p[0] = [p[1]] + p[2]

def p_gate_ops_empty(self, p):
"""gate_ops : empty"""
self.in_custom_gate_scope = True
p[0] = []

def p_gate_def_parameterized(self, p):
"""gate_def : GATE ID '(' gate_params ')' gate_qubits '{' gate_ops '}'"""
self._gate_def(p, has_params=True)

def p_gate_def(self, p):
"""gate_def : GATE ID gate_qubits '{' gate_ops '}'"""
self._gate_def(p, has_params=False)

def _gate_def(self, p: List[Any], *, has_params: bool):
name = p[2]
gate_params = tuple(p[4]) if has_params else ()
offset = 3 if has_params else 0
gate_qubits = tuple(p[3 + offset])
gate_ops = p[5 + offset]
circuit = Circuit(gate_ops).freeze()
gate_def = CustomGate(name, circuit, gate_params, gate_qubits)
self.all_gates[name] = gate_def
self.custom_gate_scoped_params.clear()
self.custom_gate_scoped_qubits.clear()
self.in_custom_gate_scope = False
p[0] = gate_def

def p_error(self, p):
if p is None:
raise QasmException('Unexpected end of file')
Expand Down
Loading