Skip to content

Commit

Permalink
Autograph automatic conversion of in-place operator-based array updat…
Browse files Browse the repository at this point in the history
…es (#1143)

**Context:** 

#717 added support for
converting in-place array updates (`arr[i] = x`) into the equivalent JAX
traceable code (`arr.at[i].set(x)`). This PR extends that support to
operator assignment array updates.

**Description of the Change:**

- Add new Autograph converter to map `AugAssign` ast nodes assigning to
a single index or a slice subscript to calls to `update_item_with_op`
- Implement `update_item_with_op` method that map to the corresponding
`jax.numpy.ndarray.at` equivalent methods for JAX arrays and the normal
Python operator assignment otherwise
- Overload `transform_ast` in `CatalystTransformer` to invoke the new
converter

**Benefits:** We can use `arr[i] += x` instead of `arr.at[i].add(x)`.

**Possible Drawbacks:** It would be cleaner to have the new converter
live in the DiastaticMalt project.

**Related GitHub Issues:**
#757

**Based on the solution presented in this PR:**
#769
Note that this PR was originally implemented externally by
#769. This PR aims to
revisit that PR.

---------

Co-authored-by: Spencer Comin <scomin@me.com>
  • Loading branch information
mehrdad2m and Spencer-Comin authored Sep 20, 2024
1 parent 6232296 commit 621c027
Show file tree
Hide file tree
Showing 7 changed files with 567 additions and 6 deletions.
42 changes: 40 additions & 2 deletions doc/dev/autograph.rst
Original file line number Diff line number Diff line change
Expand Up @@ -948,8 +948,8 @@ Notice that ``autograph=True`` must be set in order to process the
``autograph_include`` list, otherwise an error will be reported.


In-place JAX array assignments
------------------------------
In-place JAX array updates
--------------------------

To update array values when using JAX, the `JAX syntax for array assignment
<https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#array-updates-x-at-idx-set-y>`__
Expand Down Expand Up @@ -987,3 +987,41 @@ of standard Python array assignment syntax:
Array([25., 2., 0., 2.75, 0., 3.5, 0., 4.25, 0., 5., 0.], dtype=float64)

Under the hood, Catalyst converts anything coming in the latter notation into the former one.

Similarly, to update array values with an operation when using JAX, the JAX syntax for array
update (which uses the array `at` and the `add`, `multiply`, etc. methods) must be used:

>>> @qjit(autograph=True)
... def f(x):
... first_dim = x.shape[0]
... result = jnp.copy(x)
...
... for i in range(first_dim):
... result = result.at[i].multiply(2)
...
... return result

Again, if updating a single index or slice of the array, then Autograph supports conversion of
standard Python array operator assignment syntax for the equivalent in-place expressions
listed in the `JAX documentation for jax.numpy.ndarray.at
<https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at>`__:

>>> @qjit(autograph=True)
... def f(x):
... first_dim = x.shape[0]
... result = jnp.copy(x)
...
... for i in range(first_dim):
... result[i] *= 2
...
... return result

Under the hood, Catalyst converts anything coming in the latter notation into the former one.

The list of supported operators includes:
- ``=`` (set)
- ``+=`` (add)
- ``-=`` (add with negation)
- ``*=`` (multiply)
- ``/=`` (divide)
- ``**=`` (power)
29 changes: 29 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,34 @@
Available MLIR passes are now documented and available within the
[catalyst.passes module documentation](https://docs.pennylane.ai/projects/catalyst/en/stable/code/__init__.html#module-catalyst.passes).

* Catalyst Autograph now supports updating a single index or a slice of JAX arrays using Python's array assignment operator syntax.
[(#769)](https://github.com/PennyLaneAI/catalyst/pull/769)
[(#1143)](https://github.com/PennyLaneAI/catalyst/pull/1143)

Using operator assignment syntax in favor of `at...op` expressions is now possible for the following operations:
* `x[i] += y` in favor of `x.at[i].add(y)`
* `x[i] -= y` in favor of `x.at[i].add(-y)`
* `x[i] *= y` in favor of `x.at[i].multiply(y)`
* `x[i] /= y` in favor of `x.at[i].divide(y)`
* `x[i] **= y` in favor of `x.at[i].power(y)`

```python
@qjit(autograph=True)
def f(x):
first_dim = x.shape[0]
result = jnp.copy(x)

for i in range(first_dim):
result[i] *= 2 # This is now supported

return result
```

```pycon
>>> f(jnp.array([1, 2, 3]))
Array([2, 4, 6], dtype=int64)
```

<h3>Improvements</h3>

* Bufferization of `gradient.ForwardOp` and `gradient.ReverseOp` now requires 3 steps: `gradient-preprocessing`,
Expand Down Expand Up @@ -195,5 +223,6 @@ Erick Ochoa Lopez,
Mehrdad Malekmohammadi,
Paul Haochen Wang,
Sengthai Heng,
Spencer Comin,
Daniel Strano,
Raul Torres.
48 changes: 45 additions & 3 deletions frontend/catalyst/autograph/ag_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""
import copy
import functools
import operator
import warnings
from typing import Any, Callable, Iterator, SupportsIndex, Tuple, Union

Expand Down Expand Up @@ -47,6 +48,7 @@
"or_",
"not_",
"set_item",
"update_item_with_op",
]


Expand Down Expand Up @@ -582,9 +584,8 @@ def qnode_call_wrapper():
def set_item(target, i, x):
"""An implementation of the AutoGraph 'set_item' function. The interface is defined by
AutoGraph, here we merely provide an implementation of it in terms of Catalyst primitives.
The idea is to accept the much simpler single index assigment syntax for Jax arrays,
to subsequently transform it under the hood into the set of 'at' and 'set' calls that
Autograph supports. E.g.:
The idea is to accept a simple assigment syntax for Jax arrays, to subsequently transform
it under the hood into the set of 'at' and 'set' calls that Autograph supports. E.g.:
target[i] = x -> target = target.at[i].set(x)
.. note::
Expand All @@ -607,6 +608,47 @@ def set_item(target, i, x):
return target


def update_item_with_op(target, index, x, op):
"""An implementation of the 'update_item_with_op' function from operator_update. The interface
is defined in operator_update.SingleIndexArrayOperatorUpdateTransformer, here we provide an
implementation in terms of Catalyst primitives. The idea is to accept an operator assignment
syntax for Jax arrays, to subsequently transform it under the hood into the set of 'at' and
operator calls that Autograph supports. E.g.:
target[i] **= x -> target = target.at[i].power(x)
.. note::
For this feature to work, 'converter.Feature.LISTS' had to be added to the
TOP_LEVEL_OPTIONS and NESTED_LEVEL_OPTIONS conversion options of our own Catalyst
Autograph transformer. If you create a new transformer and want to support this feature,
make sure you enable such option there as well.
"""
# Mapping of the gast attributes to the corresponding JAX operation
gast_op_map = {"mult": "multiply", "div": "divide", "add": "add", "sub": "add", "pow": "power"}
# Mapping of the gast attributes to the corresponding in-place operation
inplace_operation_map = {
"mult": "mul",
"div": "truediv",
"add": "add",
"sub": "add",
"pow": "pow",
}
## For sub, we need to use add and negate the value of x
if op == "sub":
x = -x

# Apply the 'at...op' transformation only to Jax arrays.
# Otherwise, fallback to Python's default syntax.
if isinstance(target, DynamicJaxprTracer):
if isinstance(index, slice):
target = getattr(target.at[index.start : index.stop : index.step], gast_op_map[op])(x)
else:
target = getattr(target.at[index], gast_op_map[op])(x)
else:
# Use Python's in-place operator
target[index] = getattr(operator, f"__i{inplace_operation_map[op]}__")(target[index], x)
return target


class CRange:
"""Catalyst range object.
Expand Down
93 changes: 93 additions & 0 deletions frontend/catalyst/autograph/operator_update.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright 2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Converter for array element operator assignment."""

import gast
from malt.core import converter
from malt.pyct import templates


# TODO: The methods from this class should be moved to the SliceTransformer class in DiastaticMalt
class SingleIndexArrayOperatorUpdateTransformer(converter.Base):
"""Converts array element operator assignment statements into calls to update_item_with_{op},
where op is one of the following:
- `add` corresponding to `+=`
- `sub` to `-=`
- `mult` to `*=`
- `div` to `/=`
- `pow` to `**=`
"""

def _process_single_update(self, target, op, value):
if not isinstance(target, gast.Subscript):
return None
s = target.slice
if isinstance(s, (gast.Tuple, gast.Call)):
return None
if not isinstance(op, (gast.Mult, gast.Add, gast.Sub, gast.Div, gast.Pow)):
return None

template = f"""
target = ag__.update_item_with_op(target, index, x, "{type(op).__name__.lower()}")
"""
lower, upper, step = None, None, None

if isinstance(s, (gast.Slice)):
# Replace unused arguments in template with "None" to preserve each arguments' position.
# templates.replace ignores None and does not accept string so change is applied here.
lower_str = "lower" if s.lower is not None else "None"
upper_str = "upper" if s.upper is not None else "None"
step_str = "step" if s.step is not None else "None"
template = template.replace("index", f"slice({lower_str}, {upper_str}, {step_str})")

lower, upper, step = s.lower, s.upper, s.step

return templates.replace(
template,
target=target.value,
index=target.slice,
lower=lower,
upper=upper,
step=step,
x=value,
)

def visit_AugAssign(self, node):
"""The AugAssign node is replaced with a call to ag__.update_item_with_{op}
when its target is a single index array subscript and its op is an arithmetic
operator (i.e. Add, Sub, Mult, Div, or Pow), otherwise the node is left as is.
Example:
`x[i] += y` is replaced with `x = ag__.update_item_with(x, i, y)`
`x[i] ^= y` remains unchanged
"""
node = self.generic_visit(node)
replacement = self._process_single_update(node.target, node.op, node.value)
if replacement is not None:
return replacement
return node


def transform(node, ctx):
"""Replace an AugAssign node with a call to ag__.update_item_with_{op}
when the its target is a single index array subscript and its op is an arithmetic
operator (i.e. Add, Sub, Mult, Div, or Pow), otherwise the node is left as is.
Example:
`x[i] += y` is replaced with `x = ag__.update_item_with(x, i, y)`
`x[i] ^= y` remains unchanged
"""
return SingleIndexArrayOperatorUpdateTransformer(ctx).visit(node)
17 changes: 16 additions & 1 deletion frontend/catalyst/autograph/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from malt.impl.api import PyToPy

import catalyst
from catalyst.autograph import ag_primitives
from catalyst.autograph import ag_primitives, operator_update
from catalyst.utils.exceptions import AutoGraphError


Expand Down Expand Up @@ -116,6 +116,21 @@ def get_cached_function(self, fn):

return new_fn

def transform_ast(self, node, ctx):
"""Overload of PyToPy.transform_ast from DiastaticMalt
.. note::
Once the operator_update interface has been migrated to the
DiastaticMalt project, this overload can be deleted."""
# The operator_update transform would be more correct if placed with
# slices.transform in PyToPy.transform_ast in DiastaticMalt rather than
# at the beginning of the transformation. operator_update.transform
# should come after the unsupported features check and intial analysis,
# but it fails if it does not come before variables.transform.
node = operator_update.transform(node, ctx)
node = super().transform_ast(node, ctx)
return node


def run_autograph(fn):
"""Decorator that converts the given function into graph form."""
Expand Down
41 changes: 41 additions & 0 deletions frontend/catalyst/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,47 @@ def circuit(x: int):
be unrolled during tracing, "copy-pasting" the body 5 times into the program rather than
appearing as is.
.. details::
:title: In-place JAX array updates with Autograph
To update array values when using JAX, the JAX syntax for array modification
(which uses methods like ``at``, ``set``, ``multiply``, etc) must be used:
.. code-block:: python
@qjit(autograph=True)
def f(x):
first_dim = x.shape[0]
result = jnp.empty((first_dim,), dtype=x.dtype)
for i in range(first_dim):
result = result.at[i].set(x[i])
result = result.at[i].multiply(10)
result = result.at[i].add(5)
return result
However, if updating a single index or slice of the array, Autograph supports conversion of
Python's standard arithmatic array assignment operators to the equivalent in-place
expressions listed in the JAX documentation for ``jax.numpy.ndarray.at``:
.. code-block:: python
@qjit(autograph=True)
def f(x):
first_dim = x.shape[0]
result = jnp.empty((first_dim,), dtype=x.dtype)
for i in range(first_dim):
result[i] = x[i]
result[i] *= 10
result[i] += 5
return result
Under the hood, Catalyst converts anything coming in the latter notation into the
former one.
The list of supported operators includes: ``=``, ``+=``, ``-=``, ``*=``, ``/=``, and ``**=``.
.. details::
:title: Static arguments
Expand Down
Loading

0 comments on commit 621c027

Please sign in to comment.