Skip to content

Commit

Permalink
[Program Capture] Capture & execute qml.grad in plxpr (#6120)
Browse files Browse the repository at this point in the history
**Context:**
The new `qml.capture` module does not support differentiation yet.

**Description of the Change:**
This PR takes the first step towards differentiability in plxpr.
It adds the capability of capturing `qml.grad` as a "nested jaxpr"
primitive.
When executing the captured program, `qml.grad` is essentially changed
to `jax.grad`, because executing Autograd autodifferentiation within the
Jaxpr ecosystem is not sensible.

**Benefits:**
Capture first differentiation instructions

**Possible Drawbacks:**
The current implementation requires a `jvp` construction for every
evaluation of a QNode gradient. This means that this JVP function is
reconstructed for every evaluation call, if I'm not mistaken, making the
code significantly less performant with `capture` than without. Of
course, the longer term plan is to process the plxpr into lower-level
code by lowering the `grad` primitive itself, in which case this problem
goes away.
A similar redundancy is implemented in `QNode`: Whenever a `qnode`
primitive is evaluated, a new `QNode` is created (and only ever
evaluated once). This disables caching, for example, unless a cache is
passed around explicitly.

**Related GitHub Issues:**

[sc-71858]

---------

Co-authored-by: Christina Lee <christina@xanadu.ai>
  • Loading branch information
2 people authored and mudit2812 committed Sep 10, 2024
1 parent 1904af5 commit 0ecc115
Show file tree
Hide file tree
Showing 17 changed files with 609 additions and 12 deletions.
12 changes: 10 additions & 2 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,16 @@
# Release 0.39.0-dev (development release)

<h3>New features since last release</h3>

<h3>Improvements 🛠</h3>

<h4>Capturing and representing hybrid programs</h4>

* Differentiation of hybrid programs via `qml.grad` can now be captured into plxpr.
When evaluating a captured `qml.grad` instruction, it will dispatch to `jax.grad`,
which differs from the Autograd implementation of `qml.grad` itself.
[(#6120)](https://github.com/PennyLaneAI/pennylane/pull/6120)

* Improve unit testing for capturing of nested control flows.
[(#6111)](https://github.com/PennyLaneAI/pennylane/pull/6111)

Expand Down Expand Up @@ -71,4 +78,5 @@ Utkarsh Azad,
Lillian M. A. Frederiksen,
Christina Lee,
William Maxwell,
Lee J. O'Riordan,
Lee J. O'Riordan,
David Wierichs,
28 changes: 25 additions & 3 deletions pennylane/_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,40 @@
This module contains the autograd wrappers :class:`grad` and :func:`jacobian`
"""
import warnings
from functools import partial, wraps

from autograd import jacobian as _jacobian
from autograd.core import make_vjp as _make_vjp
from autograd.extend import vspace
from autograd.numpy.numpy_boxes import ArrayBox
from autograd.wrap_util import unary_to_nary

from pennylane.capture import enabled
from pennylane.capture.capture_diff import _get_grad_prim
from pennylane.compiler import compiler
from pennylane.compiler.compiler import CompileError

make_vjp = unary_to_nary(_make_vjp)


def _capture_diff(func, argnum=None, diff_prim=None, method=None, h=None):
"""Capture-compatible gradient computation."""
import jax # pylint: disable=import-outside-toplevel

if isinstance(argnum, int):
argnum = [argnum]
if argnum is None:
argnum = [0]

@wraps(func)
def new_func(*args, **kwargs):
jaxpr = jax.make_jaxpr(partial(func, **kwargs))(*args)
prim_kwargs = {"argnum": argnum, "jaxpr": jaxpr.jaxpr, "n_consts": len(jaxpr.consts)}
return diff_prim.bind(*jaxpr.consts, *args, **prim_kwargs, method=method, h=h)

return new_func


class grad:
"""Returns the gradient as a callable function of hybrid quantum-classical functions.
:func:`~.qjit` and Autograd compatible.
Expand Down Expand Up @@ -96,10 +117,11 @@ def __new__(cls, func, argnum=None, method=None, h=None):
ops_loader = available_eps[active_jit]["ops"].load()
return ops_loader.grad(func, method=method, h=h, argnums=argnum)

if enabled():
return _capture_diff(func, argnum, _get_grad_prim(), method=method, h=h)

if method or h: # pragma: no cover
raise ValueError(
f"Invalid values for 'method={method}' and 'h={h}' in interpreted mode"
)
raise ValueError(f"Invalid values '{method=}' and '{h=}' without QJIT.")

return super().__new__(cls)

Expand Down
84 changes: 84 additions & 0 deletions pennylane/capture/capture_diff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright 2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This submodule offers differentiation-related primitives and types for
the PennyLane capture module.
"""
from functools import lru_cache

has_jax = True
try:
import jax
except ImportError:
has_jax = False


@lru_cache
def create_non_jvp_primitive():
"""Create a primitive type ``NonJVPPrimitive``, which binds to JAX's JVPTrace
like a standard Python function and otherwise behaves like jax.core.Primitive.
"""

if not has_jax: # pragma: no cover
return None

# pylint: disable=too-few-public-methods
class NonJVPPrimitive(jax.core.Primitive):
"""A subclass to JAX's Primitive that works like a Python function
when evaluating JVPTracers."""

def bind_with_trace(self, trace, args, params):
"""Bind the ``NonJVPPrimitive`` with a trace. If the trace is a ``JVPTrace``,
binding falls back to a standard Python function call. Otherwise, the
bind call of JAX's standard Primitive is used."""
if isinstance(trace, jax.interpreters.ad.JVPTrace):
return self.impl(*args, **params)
return super().bind_with_trace(trace, args, params)

return NonJVPPrimitive


@lru_cache
def _get_grad_prim():
"""Create a primitive for gradient computations.
This primitive is used when capturing ``qml.grad``.
"""
if not has_jax: # pragma: no cover
return None

grad_prim = create_non_jvp_primitive()("grad")
grad_prim.multiple_results = True # pylint: disable=attribute-defined-outside-init

# pylint: disable=too-many-arguments
@grad_prim.def_impl
def _(*args, argnum, jaxpr, n_consts, method, h):
if method or h: # pragma: no cover
raise ValueError(f"Invalid values '{method=}' and '{h=}' without QJIT.")

consts = args[:n_consts]
args = args[n_consts:]

def func(*inner_args):
return jax.core.eval_jaxpr(jaxpr, consts, *inner_args)[0]

return jax.grad(func, argnums=argnum)(*args)

# pylint: disable=unused-argument
@grad_prim.def_abstract_eval
def _(*args, argnum, jaxpr, n_consts, method, h):
if len(jaxpr.outvars) != 1 or jaxpr.outvars[0].aval.shape != ():
raise TypeError("Grad only applies to scalar-output functions. Try jacobian.")
return tuple(jaxpr.invars[i].aval for i in argnum)

return grad_prim
4 changes: 3 additions & 1 deletion pennylane/capture/capture_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

import pennylane as qml

from .capture_diff import create_non_jvp_primitive

has_jax = True
try:
import jax
Expand Down Expand Up @@ -101,7 +103,7 @@ def create_operator_primitive(
if not has_jax:
return None

primitive = jax.core.Primitive(operator_type.__name__)
primitive = create_non_jvp_primitive()(operator_type.__name__)

@primitive.def_impl
def _(*args, **kwargs):
Expand Down
7 changes: 7 additions & 0 deletions pennylane/capture/capture_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
has_jax = True
try:
import jax
from jax.interpreters import ad

except ImportError:
has_jax = False

Expand Down Expand Up @@ -80,6 +82,11 @@ def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts):
mps = qfunc_jaxpr.outvars
return _get_shapes_for(*mps, shots=shots, num_device_wires=len(device.wires))

def _qnode_jvp(*args_and_tangents, **impl_kwargs):
return jax.jvp(partial(qnode_prim.impl, **impl_kwargs), *args_and_tangents)

ad.primitive_jvps[qnode_prim] = _qnode_jvp

return qnode_prim


Expand Down
4 changes: 3 additions & 1 deletion pennylane/capture/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
It has a jax dependency and should be located in a standard import path.
"""

from pennylane.compiler.qjit_api import _get_for_loop_qfunc_prim, _get_while_loop_qfunc_prim
from pennylane.ops.op_math.adjoint import _get_adjoint_qfunc_prim
from pennylane.ops.op_math.condition import _get_cond_qfunc_prim
from pennylane.ops.op_math.controlled import _get_ctrl_qfunc_prim

from .capture_diff import _get_grad_prim
from .capture_measurements import _get_abstract_measurement
from .capture_operators import _get_abstract_operator
from .capture_qnode import _get_qnode_prim
Expand All @@ -31,6 +31,7 @@
AbstractMeasurement = _get_abstract_measurement()
adjoint_transform_prim = _get_adjoint_qfunc_prim()
ctrl_transform_prim = _get_ctrl_qfunc_prim()
grad_prim = _get_grad_prim()
qnode_prim = _get_qnode_prim()
cond_prim = _get_cond_qfunc_prim()
for_loop_prim = _get_for_loop_qfunc_prim()
Expand All @@ -42,6 +43,7 @@
"AbstractMeasurement",
"adjoint_transform_prim",
"ctrl_transform_prim",
"grad_prim",
"qnode_prim",
"cond_prim",
"for_loop_prim",
Expand Down
5 changes: 3 additions & 2 deletions pennylane/compiler/qjit_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from collections.abc import Callable

import pennylane as qml
from pennylane.capture.capture_diff import create_non_jvp_primitive
from pennylane.capture.flatfn import FlatFn

from .compiler import (
Expand Down Expand Up @@ -406,7 +407,7 @@ def _get_while_loop_qfunc_prim():

import jax # pylint: disable=import-outside-toplevel

while_loop_prim = jax.core.Primitive("while_loop")
while_loop_prim = create_non_jvp_primitive()("while_loop")
while_loop_prim.multiple_results = True

@while_loop_prim.def_impl
Expand Down Expand Up @@ -621,7 +622,7 @@ def _get_for_loop_qfunc_prim():

import jax # pylint: disable=import-outside-toplevel

for_loop_prim = jax.core.Primitive("for_loop")
for_loop_prim = create_non_jvp_primitive()("for_loop")
for_loop_prim.multiple_results = True

@for_loop_prim.def_impl
Expand Down
3 changes: 2 additions & 1 deletion pennylane/ops/op_math/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Callable, overload

import pennylane as qml
from pennylane.capture.capture_diff import create_non_jvp_primitive
from pennylane.compiler import compiler
from pennylane.math import conj, moveaxis, transpose
from pennylane.operation import Observable, Operation, Operator
Expand Down Expand Up @@ -192,7 +193,7 @@ def _get_adjoint_qfunc_prim():
# if capture is enabled, jax should be installed
import jax # pylint: disable=import-outside-toplevel

adjoint_prim = jax.core.Primitive("adjoint_transform")
adjoint_prim = create_non_jvp_primitive()("adjoint_transform")
adjoint_prim.multiple_results = True

@adjoint_prim.def_impl
Expand Down
3 changes: 2 additions & 1 deletion pennylane/ops/op_math/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import pennylane as qml
from pennylane import QueuingManager
from pennylane.capture.capture_diff import create_non_jvp_primitive
from pennylane.capture.flatfn import FlatFn
from pennylane.compiler import compiler
from pennylane.measurements import MeasurementValue
Expand Down Expand Up @@ -688,7 +689,7 @@ def _get_cond_qfunc_prim():

import jax # pylint: disable=import-outside-toplevel

cond_prim = jax.core.Primitive("cond")
cond_prim = create_non_jvp_primitive()("cond")
cond_prim.multiple_results = True

@cond_prim.def_impl
Expand Down
3 changes: 2 additions & 1 deletion pennylane/ops/op_math/controlled.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import pennylane as qml
from pennylane import math as qmlmath
from pennylane import operation
from pennylane.capture.capture_diff import create_non_jvp_primitive
from pennylane.compiler import compiler
from pennylane.operation import Operator
from pennylane.wires import Wires
Expand Down Expand Up @@ -231,7 +232,7 @@ def _get_ctrl_qfunc_prim():
# if capture is enabled, jax should be installed
import jax # pylint: disable=import-outside-toplevel

ctrl_prim = jax.core.Primitive("ctrl_transform")
ctrl_prim = create_non_jvp_primitive()("ctrl_transform")
ctrl_prim.multiple_results = True

@ctrl_prim.def_impl
Expand Down
37 changes: 37 additions & 0 deletions tests/capture/test_capture_cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,43 @@ def test_func(pred):
result = test_func(selector)(arg)
assert np.allclose(result, expected), f"Expected {expected}, but got {result}"

@pytest.mark.parametrize(
"selector, arg, expected",
[
(1, 10.0, 2),
(0, 10.0, 3),
],
)
def test_gradient(self, testing_functions, selector, arg, expected, decorator):
"""Test the gradient of the conditional."""
from pennylane.capture.primitives import grad_prim

true_fn, false_fn, _, _, _, _ = testing_functions

def func(pred):
if decorator:
conditional = qml.cond(pred > 0)(true_fn)
conditional.otherwise(false_fn)
return conditional

return qml.cond(
pred > 0,
true_fn,
false_fn,
)

test_func = qml.grad(func(selector))
correct_func = jax.grad(func(selector))
assert np.allclose(correct_func(arg), expected)
assert np.allclose(test_func(arg), correct_func(arg))

jaxpr = jax.make_jaxpr(test_func)(arg)
assert len(jaxpr.eqns) == 1
assert jaxpr.eqns[0].primitive == grad_prim

manual_res = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, arg)
assert np.allclose(manual_res, correct_func(arg))

@pytest.mark.parametrize(
"selector, arg, expected",
[
Expand Down
Loading

0 comments on commit 0ecc115

Please sign in to comment.