Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
5f77cc3
fix requirements-dev.txt
andrijapau Oct 20, 2025
e738313
fix _grad.py
andrijapau Oct 20, 2025
a1189a0
fix operation.py
andrijapau Oct 20, 2025
02f09c5
fix collect_resource_ops
andrijapau Oct 20, 2025
ff095e5
fix param midmeasure
andrijapau Oct 20, 2025
da55dec
fix qasm interpreter
andrijapau Oct 20, 2025
3f63f0e
fix controlled
andrijapau Oct 20, 2025
43a0e9f
fix condition
andrijapau Oct 20, 2025
019369c
fix insert_ops
andrijapau Oct 20, 2025
0f34efe
fix while_loop
andrijapau Oct 20, 2025
d9504f2
fix typing_util
andrijapau Oct 20, 2025
27712e2
fix control flow
andrijapau Oct 20, 2025
c5fbd40
fix base_interpreter
andrijapau Oct 20, 2025
74ab082
fix merge amp emb
andrijapau Oct 20, 2025
bd1100c
fix transform_dispatch
andrijapau Oct 20, 2025
f6c6878
fix allocation
andrijapau Oct 20, 2025
ac2ae41
more fixes
andrijapau Oct 20, 2025
0319050
fix mottonen.py
andrijapau Oct 20, 2025
c09e14a
fix more
andrijapau Oct 20, 2025
f8eaf36
whoops
andrijapau Oct 20, 2025
75023c1
counts
andrijapau Oct 20, 2025
c11518c
default_gaussian.py
andrijapau Oct 20, 2025
4612661
final?
andrijapau Oct 20, 2025
d498b55
whoops
andrijapau Oct 20, 2025
64eb0b1
Merge branch 'master' into sc-101550/bump-pylint-black-isort
andrijapau Oct 21, 2025
e79415a
fix _grad.py
andrijapau Oct 21, 2025
1fbba22
more tests fixes
andrijapau Oct 21, 2025
88e6f11
more tests fixes
andrijapau Oct 21, 2025
74a4929
Trigger CI
andrijapau Oct 21, 2025
094cde1
Merge branch 'master' into sc-101550/bump-pylint-black-isort
andrijapau Oct 21, 2025
2c7f81a
fix data
andrijapau Oct 21, 2025
cb1e07d
Merge branch 'master' into sc-101550/bump-pylint-black-isort
andrijapau Oct 21, 2025
f4ec6d4
upgrade black
andrijapau Oct 21, 2025
8e5a306
missing whitespace
andrijapau Oct 21, 2025
8004590
cl
andrijapau Oct 21, 2025
b8fae5d
Merge branch 'master' into sc-101550/bump-pylint-black-isort
JerryChen97 Oct 22, 2025
b4076a3
make phi optional
andrijapau Oct 22, 2025
7d60524
Apply suggestion from @andrijapau
andrijapau Oct 24, 2025
9b839a7
Apply suggestion from @andrijapau
andrijapau Oct 24, 2025
c8836f4
Apply suggestion from @andrijapau
andrijapau Oct 24, 2025
39560eb
Apply suggestion from @andrijapau
andrijapau Oct 24, 2025
baa1f64
Apply suggestion from @andrijapau
andrijapau Oct 24, 2025
a2a5437
Apply suggestion from @andrijapau
andrijapau Oct 24, 2025
1e385c8
Apply suggestion from @andrijapau
andrijapau Oct 24, 2025
5bb0e65
Merge branch 'master' into sc-101550/bump-pylint-black-isort
andrijapau Oct 24, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
:orphan:

# Release 0.44.0-dev (development release)

<h3>New features since last release</h3>
Expand Down Expand Up @@ -148,6 +146,9 @@

<h3>Internal changes ⚙️</h3>

* Update versions for `pylint`, `isort` and `black` in `format.yml`
[(#8506)](https://github.com/PennyLaneAI/pennylane/pull/8506)

* Reclassifies `registers` as a tertiary module for use with tach.
[(#8513)](https://github.com/PennyLaneAI/pennylane/pull/8513)

Expand Down
8 changes: 4 additions & 4 deletions pennylane/_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def _get_grad_prim():
grad_prim.prim_type = "higher_order"

@grad_prim.def_impl
def _(*args, argnums, jaxpr, n_consts, method, h):
def _grad_def_impl(*args, argnums, jaxpr, n_consts, method, h):
if method or h: # pragma: no cover
raise ValueError(f"Invalid values '{method=}' and '{h=}' without QJIT.")
consts = args[:n_consts]
Expand All @@ -63,7 +63,7 @@ def func(*inner_args):

# pylint: disable=unused-argument
@grad_prim.def_abstract_eval
def _(*args, argnums, jaxpr, n_consts, method, h):
def _grad_abstract_eval(*args, argnums, jaxpr, n_consts, method, h):
if len(jaxpr.outvars) != 1 or jaxpr.outvars[0].aval.shape != ():
raise TypeError("Grad only applies to scalar-output functions. Try jacobian.")
return tuple(args[i + n_consts] for i in argnums)
Expand All @@ -90,7 +90,7 @@ def _get_jacobian_prim():
jacobian_prim.prim_type = "higher_order"

@jacobian_prim.def_impl
def _(*args, argnums, jaxpr, n_consts, method, h):
def _jacobian_def_impl(*args, argnums, jaxpr, n_consts, method, h):
if method or h: # pragma: no cover
raise ValueError(f"Invalid values '{method=}' and '{h=}' without QJIT.")
consts = args[:n_consts]
Expand All @@ -103,7 +103,7 @@ def func(*inner_args):

# pylint: disable=unused-argument
@jacobian_prim.def_abstract_eval
def _(*args, argnums, jaxpr, n_consts, method, h):
def _jacobian_abstract_eval(*args, argnums, jaxpr, n_consts, method, h):
in_avals = tuple(args[i + n_consts] for i in argnums)
out_shapes = tuple(outvar.aval.shape for outvar in jaxpr.outvars)
return [
Expand Down
12 changes: 8 additions & 4 deletions pennylane/allocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,24 +49,28 @@ class AllocateState(StrEnum):
allocate_prim.multiple_results = True

@allocate_prim.def_impl
def _(*, num_wires, state: AllocateState = AllocateState.ZERO, restored=False):
def _allocate_primitive_impl(
*, num_wires, state: AllocateState = AllocateState.ZERO, restored=False
):
raise NotImplementedError("jaxpr containing qubit allocation cannot be executed.")

# pylint: disable=unused-argument
@allocate_prim.def_abstract_eval
def _(*, num_wires, state: AllocateState = AllocateState.ZERO, restored=False):
def _allocate_primitive_abstract_eval(
*, num_wires, state: AllocateState = AllocateState.ZERO, restored=False
):
return [jax.core.ShapedArray((), dtype=int) for _ in range(num_wires)]

deallocate_prim = QmlPrimitive("deallocate")
deallocate_prim.multiple_results = True

@deallocate_prim.def_impl
def _(*wires):
def _deallocate_primitive_impl(*wires):
raise NotImplementedError("jaxpr containing qubit deallocation cannot be executed.")

# pylint: disable=unused-argument
@deallocate_prim.def_abstract_eval
def _(*wires):
def _deallocate_primitive_abstract_eval(*wires):
return []


Expand Down
4 changes: 2 additions & 2 deletions pennylane/capture/base_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ def _(self, x, *dyn_shape, shape, broadcast_dimensions, sharding):

# pylint: disable=unused-argument
@PlxprInterpreter.register_primitive(jax.lax.iota_p)
def _(self, *dyn_shape, dimension, dtype, shape, sharding):
def _iota_primitive(self, *dyn_shape, dimension, dtype, shape, sharding):
"""Handle the iota primitive created by jnp.arange

>>> import jax
Expand Down Expand Up @@ -646,7 +646,7 @@ class FlattenedInterpreter(PlxprInterpreter):


@FlattenedInterpreter.register_primitive(pjit_p)
def _(self, *invals, jaxpr, **params):
def _pjit_primitive(self, *invals, jaxpr, **params):
if jax.config.jax_dynamic_shapes:
# just evaluate it so it doesn't throw dynamic shape errors
return copy(self).eval(jaxpr.jaxpr, jaxpr.consts, *invals)
Expand Down
6 changes: 4 additions & 2 deletions pennylane/control_flow/for_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,9 @@ def _get_for_loop_qfunc_prim():

# pylint: disable=too-many-arguments
@for_loop_prim.def_impl
def _(start, stop, step, *args, jaxpr_body_fn, consts_slice, args_slice, abstract_shapes_slice):
def _impl(
start, stop, step, *args, jaxpr_body_fn, consts_slice, args_slice, abstract_shapes_slice
):

consts = args[consts_slice]
init_state = args[args_slice]
Expand All @@ -296,7 +298,7 @@ def _(start, stop, step, *args, jaxpr_body_fn, consts_slice, args_slice, abstrac

# pylint: disable=unused-argument
@for_loop_prim.def_abstract_eval
def _(start, stop, step, *args, args_slice, abstract_shapes_slice, **_):
def _abstract_eval(start, stop, step, *args, args_slice, abstract_shapes_slice, **_):
return args[abstract_shapes_slice] + args[args_slice]

return for_loop_prim
Expand Down
4 changes: 2 additions & 2 deletions pennylane/control_flow/while_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def _get_while_loop_qfunc_prim():
register_custom_staging_rule(while_loop_prim, lambda params: params["jaxpr_body_fn"].outvars)

@while_loop_prim.def_impl
def _(
def _impl(
*args,
jaxpr_body_fn,
jaxpr_cond_fn,
Expand All @@ -253,7 +253,7 @@ def _(
return fn_res

@while_loop_prim.def_abstract_eval
def _(*args, args_slice, **__):
def _abstract_eval(*args, args_slice, **__):
return args[args_slice]

return while_loop_prim
Expand Down
3 changes: 3 additions & 0 deletions pennylane/data/base/typing_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ def get_type_str(cls: type | str | None) -> str: # pylint: disable=too-many-ret

Otherwise, returns the fully-qualified class name, including the module.
"""
# pylint: disable=unidiomatic-typecheck
# Keep this check as it ensures that get_type_str(type(None)) = 'None'
# rather than `NoneType`.
if cls is None or cls is type(None):
return "None"

Expand Down
6 changes: 4 additions & 2 deletions pennylane/decomposition/collect_resource_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ def interpret_operation(self, op):


@CollectResourceOps.register_primitive(adjoint_transform_prim)
def _(self, *invals, jaxpr, lazy, n_consts): # pylint: disable=unused-argument
def _adjoint_transform_prim(
self, *invals, jaxpr, lazy, n_consts
): # pylint: disable=unused-argument
"""Collect all operations in the base plxpr and create adjoint resource ops with them."""
consts = invals[:n_consts]
args = invals[n_consts:]
Expand All @@ -47,7 +49,7 @@ def _(self, *invals, jaxpr, lazy, n_consts): # pylint: disable=unused-argument


@CollectResourceOps.register_primitive(ctrl_transform_prim)
def _(self, *invals, n_control, jaxpr, n_consts, **params):
def _ctrl_transform_prim(self, *invals, n_control, jaxpr, n_consts, **params):
"""Collect all operations in the target plxpr and create controlled resource ops with them."""

consts = invals[:n_consts]
Expand Down
24 changes: 9 additions & 15 deletions pennylane/devices/default_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,33 +517,27 @@ def photon_number(cov, mu, params, hbar=2.0):
return ex, var


def homodyne(phi=None):
def homodyne(phi: float | None = None):
"""Function factory that returns the Homodyne expectation of a one mode state.

Args:
phi (float): the default phase space axis to perform the Homodyne measurement
phi (Optional[float]): the default phase space axis to perform the Homodyne measurement

Returns:
function: A function that accepts a single mode means vector, covariance matrix,
and phase space angle phi, and returns the quadrature expectation
value and variance.
"""
if phi is not None:

def _homodyne(cov, mu, params, hbar=2.0):
"""Arbitrary angle homodyne expectation."""
# pylint: disable=unused-argument
rot = rotation(phi)
muphi = rot.T @ mu
covphi = rot.T @ cov @ rot
return muphi[0], covphi[0, 0]
# pylint: disable=unused-argument
def _homodyne(cov, mu, params, hbar=2.0):
"""Calculates the arbitrary angle homodyne expectation."""

return _homodyne
# Use the fixed outer `phi` if it was provided,
# otherwise use the dynamic `phi` from the parameters.
measurement_phi = phi if phi is not None else params[0]

def _homodyne(cov, mu, params, hbar=2.0):
"""Arbitrary angle homodyne expectation."""
# pylint: disable=unused-argument
rot = rotation(params[0])
rot = rotation(measurement_phi)
muphi = rot.T @ mu
covphi = rot.T @ cov @ rot
return muphi[0], covphi[0, 0]
Expand Down
4 changes: 2 additions & 2 deletions pennylane/devices/qubit/dq_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def _(self, *invals, reset, postselect):


@DefaultQubitInterpreter.register_primitive(adjoint_transform_prim)
def _(self, *invals, jaxpr, n_consts, lazy=True):
def _adjoint_transform_prim(self, *invals, jaxpr, n_consts, lazy=True):
consts = invals[:n_consts]
args = invals[n_consts:]
recorder = CollectOpsandMeas()
Expand All @@ -251,7 +251,7 @@ def _(self, *invals, jaxpr, n_consts, lazy=True):

# pylint: disable=too-many-arguments
@DefaultQubitInterpreter.register_primitive(ctrl_transform_prim)
def _(self, *invals, n_control, jaxpr, control_values, work_wires, n_consts):
def _ctrl_transform_prim(self, *invals, n_control, jaxpr, control_values, work_wires, n_consts):
consts = invals[:n_consts]
control_wires = invals[-n_control:]
args = invals[n_consts:-n_control]
Expand Down
6 changes: 3 additions & 3 deletions pennylane/ftqc/parametric_midmeasure.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def _create_parametrized_mid_measure_primitive():
measure_in_basis_p = QmlPrimitive("measure_in_basis")

@measure_in_basis_p.def_impl
def _(wires, angle=0.0, plane="ZX", reset=False, postselect=None):
def _impl(wires, angle=0.0, plane="ZX", reset=False, postselect=None):
return _measure_impl(
wires,
measurement_class=ParametricMidMeasureMP,
Expand All @@ -66,7 +66,7 @@ def _(wires, angle=0.0, plane="ZX", reset=False, postselect=None):
)

@measure_in_basis_p.def_abstract_eval
def _(*_, **__):
def _abstract_eval(*_, **__):
return jax.core.ShapedArray((), jax.numpy.bool)

return measure_in_basis_p
Expand Down Expand Up @@ -581,7 +581,7 @@ def diagonalizing_gates(self):


@_add_operation_to_drawer.register
def _(op: ParametricMidMeasureMP, drawer, layer, _):
def _parametric_midmeasure(op: ParametricMidMeasureMP, drawer, layer, _):
if isinstance(op, XMidMeasureMP):
text = "X"
elif isinstance(op, YMidMeasureMP):
Expand Down
3 changes: 2 additions & 1 deletion pennylane/io/qasm_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def _eval_unary_op(operand: any, operator: str, line: int):
) # pragma: no cover


# pylint: disable = too-many-branches
def _eval_assignment(lhs: any, operator: str, value: any, line: int):
"""
Evaluates an assignment.
Expand Down Expand Up @@ -136,7 +137,7 @@ def _eval_assignment(lhs: any, operator: str, value: any, line: int):
return lhs


# pylint: disable=too-many-return-statements
# pylint: disable=too-many-return-statements, too-many-branches
def _eval_binary_op(lhs: any, operator: str, rhs: any, line: int):
"""
Evaluates a binary operator.
Expand Down
12 changes: 6 additions & 6 deletions pennylane/measurements/capture_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,13 @@ def create_measurement_obs_primitive(
primitive.prim_type = "measurement"

@primitive.def_impl
def _(obs, **kwargs):
def _impl(obs, **kwargs):
return type.__call__(measurement_type, obs=obs, **kwargs)

abstract_type = _get_abstract_measurement()

@primitive.def_abstract_eval
def _(*_, **__):
def _abstract_eval(*_, **__):
abstract_eval = measurement_type._abstract_eval # pylint: disable=protected-access
return abstract_type(abstract_eval, n_wires=None)

Expand Down Expand Up @@ -169,13 +169,13 @@ def create_measurement_mcm_primitive(
primitive.prim_type = "measurement"

@primitive.def_impl
def _(*mcms, single_mcm=True, **kwargs):
def _impl(*mcms, single_mcm=True, **kwargs):
return type.__call__(measurement_type, obs=mcms[0] if single_mcm else mcms, **kwargs)

abstract_type = _get_abstract_measurement()

@primitive.def_abstract_eval
def _(*mcms, **__):
def _abstract_eval(*mcms, **__):
abstract_eval = measurement_type._abstract_eval # pylint: disable=protected-access
return abstract_type(abstract_eval, n_wires=len(mcms))

Expand Down Expand Up @@ -205,7 +205,7 @@ def create_measurement_wires_primitive(
primitive.prim_type = "measurement"

@primitive.def_impl
def _(*args, has_eigvals=False, **kwargs):
def _impl(*args, has_eigvals=False, **kwargs):
if has_eigvals:
wires = Wires(tuple(w if is_abstract(w) else int(w) for w in args[:-1]))
kwargs["eigvals"] = args[-1]
Expand All @@ -217,7 +217,7 @@ def _(*args, has_eigvals=False, **kwargs):
abstract_type = _get_abstract_measurement()

@primitive.def_abstract_eval
def _(*args, has_eigvals=False, **_):
def _abstract_eval(*args, has_eigvals=False, **_):
abstract_eval = measurement_type._abstract_eval # pylint: disable=protected-access
n_wires = len(args) - 1 if has_eigvals else len(args)
return abstract_type(abstract_eval, n_wires=n_wires, has_eigvals=has_eigvals)
Expand Down
4 changes: 2 additions & 2 deletions pennylane/measurements/counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def _include_all_outcomes(self, outcome_counts: dict) -> None:
CountsMP._wires_primitive.multiple_results = True

@CountsMP._wires_primitive.def_impl
def _(*args, **kwargs):
def _impl(*args, **kwargs):
raise NotImplementedError("Counts has no execution implementation with program capture.")

def _keys_eval(n_wires=None, has_eigvals=False, shots=None, num_device_wires=0):
Expand All @@ -328,7 +328,7 @@ def _values_eval(n_wires=None, has_eigvals=False, shots=None, num_device_wires=0
abstract_mp = _get_abstract_measurement()

@CountsMP._wires_primitive.def_abstract_eval
def _(*args, has_eigvals=False, all_outcomes=False):
def _abstract_eval(*args, has_eigvals=False, all_outcomes=False):
if not all_outcomes:
warnings.warn(
"all_outcomes=False is unsupported with program capture and qjit. Using all_outcomes=True",
Expand Down
4 changes: 2 additions & 2 deletions pennylane/measurements/mid_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,11 @@ def _create_mid_measure_primitive():
mid_measure_p = QmlPrimitive("measure")

@mid_measure_p.def_impl
def _(wires, reset=False, postselect=None):
def _impl(wires, reset=False, postselect=None):
return _measure_impl(wires, reset=reset, postselect=postselect)

@mid_measure_p.def_abstract_eval
def _(*_, **__):
def _abstract_eval(*_, **__):
dtype = jax.numpy.int64 if jax.config.jax_enable_x64 else jax.numpy.int32
return jax.core.ShapedArray((), dtype)

Expand Down
5 changes: 4 additions & 1 deletion pennylane/noise/insert_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,10 @@ def stop_at(obj):

if req_ops:
for operation in req_ops:
if operation == type(circuit_op):
# Use `isinstance` rather than checking `operation == type(circuit_op)`
# circuit_op is an instance of an operation.
# operation is a type; either Operator or some subclass of Operator.
if isinstance(circuit_op, operation):
for w in circuit_op.wires:
sub_tape = make_qscript(op)(*op_args, wires=w)
new_operations.extend(sub_tape.operations)
Expand Down
4 changes: 2 additions & 2 deletions pennylane/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def create_operator_primitive(
primitive.prim_type = "operator"

@primitive.def_impl
def _(*args, **kwargs):
def _impl(*args, **kwargs):
if "n_wires" not in kwargs:
return type.__call__(operator_type, *args, **kwargs)
n_wires = kwargs.pop("n_wires")
Expand All @@ -357,7 +357,7 @@ def _(*args, **kwargs):
abstract_type = _get_abstract_operator()

@primitive.def_abstract_eval
def _(*_, **__):
def _abstract_eval(*_, **__):
return abstract_type()

return primitive
Expand Down
Loading