Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Autograph automatic conversion of in-place operator-based array updates #1143

Merged
merged 19 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 40 additions & 2 deletions doc/dev/autograph.rst
Original file line number Diff line number Diff line change
Expand Up @@ -948,8 +948,8 @@ Notice that ``autograph=True`` must be set in order to process the
``autograph_include`` list, otherwise an error will be reported.


In-place JAX array assignments
------------------------------
In-place JAX array updates
--------------------------

To update array values when using JAX, the `JAX syntax for array assignment
<https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#array-updates-x-at-idx-set-y>`__
Expand Down Expand Up @@ -987,3 +987,41 @@ of standard Python array assignment syntax:
Array([25., 2., 0., 2.75, 0., 3.5, 0., 4.25, 0., 5., 0.], dtype=float64)

Under the hood, Catalyst converts anything coming in the latter notation into the former one.

Similarly, to update array values with an operation when using JAX, the JAX syntax for array
update (which uses the array `at` and the `add`, `multiply`, etc. methods) must be used:

>>> @qjit(autograph=True)
... def f(x):
... first_dim = x.shape[0]
... result = jnp.copy(x)
...
... for i in range(first_dim):
... result = result.at[i].multiply(2)
...
... return result

Again, if updating a single index or slice of the array, then Autograph supports conversion of
standard Python array operator assignment syntax for the equivalent in-place expressions
listed in the `JAX documentation for jax.numpy.ndarray.at
<https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at>`__:

>>> @qjit(autograph=True)
... def f(x):
... first_dim = x.shape[0]
... result = jnp.copy(x)
...
... for i in range(first_dim):
... result[i] *= 2
...
... return result

Under the hood, Catalyst converts anything coming in the latter notation into the former one.

The list of supported operators includes:
- ``=`` (set)
- ``+=`` (add)
- ``-=`` (add with negation)
- ``*=`` (multiply)
- ``/=`` (divide)
- ``**=`` (power)
28 changes: 28 additions & 0 deletions doc/releases/changelog-dev.md
mehrdad2m marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,34 @@
Available MLIR passes are now documented and available within the
[catalyst.passes module documentation](https://docs.pennylane.ai/projects/catalyst/en/stable/code/__init__.html#module-catalyst.passes).

* Catalyst Autograph now supports updating a single index or a slice of JAX arrays using Python's array assignment operator syntax.
[(#769)](https://github.com/PennyLaneAI/catalyst/pull/769)
[(#1143)](https://github.com/PennyLaneAI/catalyst/pull/1143)

Using operator assignment syntax in favor of `at...op` expressions is now possible for the following operations:
* `x[i] += y` in favor of `x.at[i].add(y)`
* `x[i] -= y` in favor of `x.at[i].add(-y)`
* `x[i] *= y` in favor of `x.at[i].multiply(y)`
* `x[i] /= y` in favor of `x.at[i].divide(y)`
* `x[i] **= y` in favor of `x.at[i].power(y)`

```python
@qjit(autograph=True)
def f(x):
first_dim = x.shape[0]
result = jnp.copy(x)

for i in range(first_dim):
result[i] *= 2 # This is now supported

return result
```

```pycon
>>> f(jnp.array([1, 2, 3]))
Array([2, 4, 6], dtype=int64)
```

<h3>Improvements</h3>

* Bufferization of `gradient.ForwardOp` and `gradient.ReverseOp` now requires 3 steps: `gradient-preprocessing`,
Expand Down
48 changes: 45 additions & 3 deletions frontend/catalyst/autograph/ag_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""
import copy
import functools
import operator
import warnings
from typing import Any, Callable, Iterator, SupportsIndex, Tuple, Union

Expand Down Expand Up @@ -47,6 +48,7 @@
"or_",
"not_",
"set_item",
"update_item_with_op",
]


Expand Down Expand Up @@ -582,9 +584,8 @@ def qnode_call_wrapper():
def set_item(target, i, x):
"""An implementation of the AutoGraph 'set_item' function. The interface is defined by
AutoGraph, here we merely provide an implementation of it in terms of Catalyst primitives.
The idea is to accept the much simpler single index assigment syntax for Jax arrays,
to subsequently transform it under the hood into the set of 'at' and 'set' calls that
Autograph supports. E.g.:
The idea is to accept a simple assigment syntax for Jax arrays, to subsequently transform
it under the hood into the set of 'at' and 'set' calls that Autograph supports. E.g.:
target[i] = x -> target = target.at[i].set(x)

.. note::
Expand All @@ -607,6 +608,47 @@ def set_item(target, i, x):
return target


def update_item_with_op(target, index, x, op):
"""An implementation of the 'update_item_with_op' function from operator_update. The interface
is defined in operator_update.SingleIndexArrayOperatorUpdateTransformer, here we provide an
implementation in terms of Catalyst primitives. The idea is to accept an operator assignment
syntax for Jax arrays, to subsequently transform it under the hood into the set of 'at' and
operator calls that Autograph supports. E.g.:
target[i] **= x -> target = target.at[i].power(x)

.. note::
For this feature to work, 'converter.Feature.LISTS' had to be added to the
TOP_LEVEL_OPTIONS and NESTED_LEVEL_OPTIONS conversion options of our own Catalyst
Autograph transformer. If you create a new transformer and want to support this feature,
make sure you enable such option there as well.
"""
# Mapping of the gast attributes to the corresponding JAX operation
gast_op_map = {"mult": "multiply", "div": "divide", "add": "add", "sub": "add", "pow": "power"}
# Mapping of the gast attributes to the corresponding in-place operation
inplace_operation_map = {
"mult": "mul",
"div": "truediv",
"add": "add",
"sub": "add",
"pow": "pow",
}
## For sub, we need to use add and negate the value of x
if op == "sub":
x = -x

# Apply the 'at...op' transformation only to Jax arrays.
# Otherwise, fallback to Python's default syntax.
if isinstance(target, DynamicJaxprTracer):
if isinstance(index, slice):
target = getattr(target.at[index.start : index.stop : index.step], gast_op_map[op])(x)
else:
target = getattr(target.at[index], gast_op_map[op])(x)
else:
# Use Python's in-place operator
target[index] = getattr(operator, f"__i{inplace_operation_map[op]}__")(target[index], x)
return target


class CRange:
"""Catalyst range object.

Expand Down
93 changes: 93 additions & 0 deletions frontend/catalyst/autograph/operator_update.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright 2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Converter for array element operator assignment."""

import gast
from malt.core import converter
from malt.pyct import templates


# TODO: The methods from this class should be moved to the SliceTransformer class in DiastaticMalt
class SingleIndexArrayOperatorUpdateTransformer(converter.Base):
"""Converts array element operator assignment statements into calls to update_item_with_{op},
where op is one of the following:

- `add` corresponding to `+=`
- `sub` to `-=`
- `mult` to `*=`
- `div` to `/=`
- `pow` to `**=`
"""

def _process_single_update(self, target, op, value):
if not isinstance(target, gast.Subscript):
return None
s = target.slice
if isinstance(s, (gast.Tuple, gast.Call)):
return None
if not isinstance(op, (gast.Mult, gast.Add, gast.Sub, gast.Div, gast.Pow)):
return None

template = f"""
target = ag__.update_item_with_op(target, index, x, "{type(op).__name__.lower()}")
"""
lower, upper, step = None, None, None

if isinstance(s, (gast.Slice)):
# Replace unused arguments in template with "None" to preserve each arguments' position.
# templates.replace ignores None and does not accept string so change is applied here.
lower_str = "lower" if s.lower is not None else "None"
upper_str = "upper" if s.upper is not None else "None"
step_str = "step" if s.step is not None else "None"
template = template.replace("index", f"slice({lower_str}, {upper_str}, {step_str})")

lower, upper, step = s.lower, s.upper, s.step

return templates.replace(
template,
target=target.value,
index=target.slice,
lower=lower,
upper=upper,
step=step,
x=value,
)

def visit_AugAssign(self, node):
"""The AugAssign node is replaced with a call to ag__.update_item_with_{op}
when its target is a single index array subscript and its op is an arithmetic
operator (i.e. Add, Sub, Mult, Div, or Pow), otherwise the node is left as is.

Example:
`x[i] += y` is replaced with `x = ag__.update_item_with(x, i, y)`
`x[i] ^= y` remains unchanged
"""
node = self.generic_visit(node)
replacement = self._process_single_update(node.target, node.op, node.value)
if replacement is not None:
return replacement
return node


def transform(node, ctx):
"""Replace an AugAssign node with a call to ag__.update_item_with_{op}
when the its target is a single index array subscript and its op is an arithmetic
operator (i.e. Add, Sub, Mult, Div, or Pow), otherwise the node is left as is.

Example:
`x[i] += y` is replaced with `x = ag__.update_item_with(x, i, y)`
`x[i] ^= y` remains unchanged
"""
return SingleIndexArrayOperatorUpdateTransformer(ctx).visit(node)
17 changes: 16 additions & 1 deletion frontend/catalyst/autograph/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from malt.impl.api import PyToPy

import catalyst
from catalyst.autograph import ag_primitives
from catalyst.autograph import ag_primitives, operator_update
from catalyst.utils.exceptions import AutoGraphError


Expand Down Expand Up @@ -116,6 +116,21 @@ def get_cached_function(self, fn):

return new_fn

def transform_ast(self, node, ctx):
"""Overload of PyToPy.transform_ast from DiastaticMalt

.. note::
Once the operator_update interface has been migrated to the
DiastaticMalt project, this overload can be deleted."""
# The operator_update transform would be more correct if placed with
# slices.transform in PyToPy.transform_ast in DiastaticMalt rather than
# at the beginning of the transformation. operator_update.transform
# should come after the unsupported features check and intial analysis,
# but it fails if it does not come before variables.transform.
node = operator_update.transform(node, ctx)
node = super().transform_ast(node, ctx)
return node


def run_autograph(fn):
"""Decorator that converts the given function into graph form."""
Expand Down
41 changes: 41 additions & 0 deletions frontend/catalyst/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,47 @@
be unrolled during tracing, "copy-pasting" the body 5 times into the program rather than
appearing as is.

.. details::
paul0403 marked this conversation as resolved.
Show resolved Hide resolved
:title: In-place JAX array updates with Autograph
paul0403 marked this conversation as resolved.
Show resolved Hide resolved

To update array values when using JAX, the JAX syntax for array modification
(which uses methods like ``at``, ``set``, ``multiply``, etc) must be used:

.. code-block:: python

@qjit(autograph=True)
def f(x):
first_dim = x.shape[0]
result = jnp.empty((first_dim,), dtype=x.dtype)
for i in range(first_dim):
result = result.at[i].set(x[i])
result = result.at[i].multiply(10)
result = result.at[i].add(5)

return result

However, if updating a single index or slice of the array, Autograph supports conversion of
Python's standard arithmatic array assignment operators to the equivalent in-place
expressions listed in the JAX documentation for ``jax.numpy.ndarray.at``:

.. code-block:: python

@qjit(autograph=True)
def f(x):
first_dim = x.shape[0]
result = jnp.empty((first_dim,), dtype=x.dtype)
for i in range(first_dim):
result[i] = x[i]
result[i] *= 10
result[i] += 5

return result

Under the hood, Catalyst converts anything coming in the latter notation into the
former one.

The list of supported operators includes: ``=``, ``+=``, ``-=``, ``*=``, ``/=``, and ``**=``.

Check notice on line 281 in frontend/catalyst/jit.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jit.py#L281

Line too long (101/100) (line-too-long)

.. details::
:title: Static arguments

Expand Down
Loading
Loading