Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Program Capture] Capture & execute qml.grad in plxpr #6120

Merged
merged 42 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
206ca6a
first prototype
dwierichs Aug 21, 2024
a466a83
first tests
dwierichs Aug 21, 2024
8b402f7
cleanup
dwierichs Aug 21, 2024
9e8f4c9
-a
dwierichs Aug 21, 2024
cfae262
changelog, lint
dwierichs Aug 21, 2024
ed8f515
add tests...
dwierichs Aug 21, 2024
79fe790
move primitive
dwierichs Aug 22, 2024
b90fa2b
import
dwierichs Aug 22, 2024
b9b6536
lint
dwierichs Aug 22, 2024
ff54b86
lint
dwierichs Aug 22, 2024
fde341f
parshift test
dwierichs Aug 22, 2024
3cd33f1
prepare jac
dwierichs Aug 22, 2024
a087e53
make grad differentiable
dwierichs Aug 22, 2024
393aa35
nested grad test
dwierichs Aug 22, 2024
f1a6b4d
lint
dwierichs Aug 22, 2024
b86619b
Merge branch 'master' into capture-grad
dwierichs Aug 22, 2024
f678a91
lint
dwierichs Aug 22, 2024
15ed932
Merge branch 'master' into capture-grad
dwierichs Aug 22, 2024
692db13
Merge branch 'master' into capture-grad
dwierichs Aug 23, 2024
0e407c3
Merge branch 'master' into capture-grad
dwierichs Aug 26, 2024
93bda87
method and h allowed in capture
dwierichs Aug 26, 2024
5bb2900
higher order primitive tests
dwierichs Aug 26, 2024
10b2899
Merge branch 'master' into capture-grad
dwierichs Aug 27, 2024
c838b4a
Apply suggestions from code review
dwierichs Aug 27, 2024
67bdeb8
Merge branch 'master' into capture-grad
dwierichs Aug 27, 2024
cb43c96
Merge branch 'master' into capture-grad
dwierichs Aug 28, 2024
aa3876e
merge
dwierichs Aug 29, 2024
389c22d
merge more
dwierichs Aug 29, 2024
de782b8
lint more
dwierichs Aug 29, 2024
a9a4472
add file
dwierichs Aug 29, 2024
769eb98
[skip ci]
dwierichs Aug 29, 2024
9a4c580
Merge branch 'master' into capture-grad
dwierichs Sep 3, 2024
e79d4a8
merge
dwierichs Sep 5, 2024
552417b
while_loop
dwierichs Sep 5, 2024
81ab600
import fix
dwierichs Sep 5, 2024
1af5b97
lint
dwierichs Sep 5, 2024
c3fbd78
fix import
dwierichs Sep 5, 2024
e73f2fb
import and skip order
dwierichs Sep 6, 2024
3d2cb89
Merge branch 'master' into capture-grad
dwierichs Sep 6, 2024
724675f
[skip ci]
dwierichs Sep 6, 2024
9fda181
Merge branch 'master' into capture-grad
dwierichs Sep 9, 2024
db99812
lint
dwierichs Sep 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,11 @@
* `qml.for_loop` now supports `range`-like syntax with default `step=1`.
[(#6068)](https://github.com/PennyLaneAI/pennylane/pull/6068)

* Differentiation of hybrid programs via `qml.grad` can now be captured into plxpr.
When evaluating a captured `qml.grad` instruction, it will dispatch to `jax.grad`,
which differs from the Autograd implementation of `qml.grad` itself.
[(#6120)](https://github.com/PennyLaneAI/pennylane/pull/6120)

* Applying `adjoint` and `ctrl` to a quantum function can now be captured into plxpr.
Furthermore, the `qml.cond` function can be captured into plxpr.
[(#5966)](https://github.com/PennyLaneAI/pennylane/pull/5966)
Expand Down
27 changes: 24 additions & 3 deletions pennylane/_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,39 @@
This module contains the autograd wrappers :class:`grad` and :func:`jacobian`
"""
import warnings
from functools import partial, wraps

from autograd import jacobian as _jacobian
from autograd.core import make_vjp as _make_vjp
from autograd.extend import vspace
from autograd.numpy.numpy_boxes import ArrayBox
from autograd.wrap_util import unary_to_nary

from pennylane.capture import create_grad_primitive, enabled
from pennylane.compiler import compiler
from pennylane.compiler.compiler import CompileError

make_vjp = unary_to_nary(_make_vjp)


def _capture_diff(func, argnum=None, diff_prim=None):
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
"""Capture-compatible gradient computation."""
import jax # pylint: disable=import-outside-toplevel

if isinstance(argnum, int):
argnum = [argnum]
if argnum is None:
argnum = [0]
dwierichs marked this conversation as resolved.
Show resolved Hide resolved

@wraps(func)
def new_func(*args, **kwargs):
jaxpr = jax.make_jaxpr(partial(func, **kwargs))(*args)
prim_kwargs = {"argnum": argnum, "jaxpr": jaxpr.jaxpr, "n_consts": len(jaxpr.consts)}
return diff_prim.bind(*jaxpr.consts, *args, **prim_kwargs)

return new_func


class grad:
"""Returns the gradient as a callable function of hybrid quantum-classical functions.
:func:`~.qjit` and Autograd compatible.
Expand Down Expand Up @@ -97,9 +117,10 @@ def __new__(cls, func, argnum=None, method=None, h=None):
return ops_loader.grad(func, method=method, h=h, argnums=argnum)

if method or h: # pragma: no cover
raise ValueError(
f"Invalid values for 'method={method}' and 'h={h}' in interpreted mode"
)
raise ValueError(f"Invalid values '{method=}' and '{h=}' without QJIT.")
albi3ro marked this conversation as resolved.
Show resolved Hide resolved

if enabled():
return _capture_diff(func, argnum, create_grad_primitive())

return super().__new__(cls)

Expand Down
2 changes: 2 additions & 0 deletions pennylane/capture/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def _(*args, **kwargs):
from .switches import disable, enable, enabled
from .capture_meta import CaptureMeta, ABCCaptureMeta
from .primitives import (
create_grad_primitive,
create_operator_primitive,
create_measurement_obs_primitive,
create_measurement_wires_primitive,
Expand Down Expand Up @@ -172,6 +173,7 @@ def __getattr__(key):
"create_measurement_obs_primitive",
"create_measurement_wires_primitive",
"create_measurement_mcm_primitive",
"create_grad_primitive",
"qnode_call",
"AbstractOperator",
"AbstractMeasurement",
Expand Down
7 changes: 7 additions & 0 deletions pennylane/capture/capture_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
has_jax = True
try:
import jax
from jax.interpreters import ad

except ImportError:
has_jax = False

Expand Down Expand Up @@ -87,6 +89,11 @@ def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts):
mps = qfunc_jaxpr.outvars
return _get_shapes_for(*mps, shots=shots, num_device_wires=len(device.wires))

def _qnode_jvp(*args_and_tangents, **impl_kwargs):
return jax.jvp(partial(qnode_prim.impl, **impl_kwargs), *args_and_tangents)
dwierichs marked this conversation as resolved.
Show resolved Hide resolved

ad.primitive_jvps[qnode_prim] = _qnode_jvp

return qnode_prim


Expand Down
56 changes: 55 additions & 1 deletion pennylane/capture/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,31 @@ def __hash__(self):
return AbstractMeasurement


@lru_cache
def create_non_jvp_primitive():
"""Create a primitive type ``NonJVPPrimitive``, which binds to JAX's JVPTrace
like a standard Python function and otherwise behaves like jax.core.Primitive.
"""

if not has_jax: # pragma: no cover
return None

# pylint: disable=too-few-public-methods
class NonJVPPrimitive(jax.core.Primitive):
"""A subclass to JAX's Primitive that works like a Python function
when evaluating JVPTracers."""

def bind_with_trace(self, trace, args, params):
"""Bind the ``NonJVPPrimitive`` with a trace. If the trace is a ``JVPTrace``,
binding falls back to a standard Python function call. Otherwise, the
bind call of JAX's standard Primitive is used."""
if isinstance(trace, jax.interpreters.ad.JVPTrace):
return self.impl(*args, **params)
return super().bind_with_trace(trace, args, params)

return NonJVPPrimitive
dwierichs marked this conversation as resolved.
Show resolved Hide resolved


def create_operator_primitive(
operator_type: Type["qml.operation.Operator"],
) -> Optional["jax.core.Primitive"]:
Expand All @@ -182,7 +207,7 @@ def create_operator_primitive(
if not has_jax:
return None

primitive = jax.core.Primitive(operator_type.__name__)
primitive = create_non_jvp_primitive()(operator_type.__name__)

@primitive.def_impl
def _(*args, **kwargs):
Expand Down Expand Up @@ -318,3 +343,32 @@ def _(*args, has_eigvals=False, **_):
return abstract_type(abstract_eval, n_wires=n_wires, has_eigvals=has_eigvals)

return primitive


@lru_cache
def create_grad_primitive():
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
"""Create a primitive for gradient computations.
This primitive is used when capturing ``qml.grad``.
"""
grad_prim = create_non_jvp_primitive()("grad")
grad_prim.multiple_results = True # pylint: disable=attribute-defined-outside-init

# pylint: disable=too-many-arguments
@grad_prim.def_impl
def _(*args, argnum, jaxpr, n_consts):
consts = args[:n_consts]
args = args[n_consts:]

def func(*inner_args):
return jax.core.eval_jaxpr(jaxpr, consts, *inner_args)[0]

return jax.grad(func, argnums=argnum)(*args)

# pylint: disable=unused-argument
@grad_prim.def_abstract_eval
def _(*args, argnum, jaxpr, n_consts):
if len(jaxpr.outvars) != 1 or jaxpr.outvars[0].aval.shape != ():
raise TypeError("Grad only applies to scalar-output functions. Try jacobian or egrad.")
return tuple(jaxpr.invars[i].aval for i in argnum)

return grad_prim
Loading
Loading