Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend] Use custom lowering rules #1152

Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,9 @@
* Catalyst now supports numpy 2.0
[(#1119)](https://github.com/PennyLaneAI/catalyst/pull/1119)

* Importing Catalyst will now pollute less of JAX's global variables by using `LoweringParameters`.
[(#1152)](https://github.com/PennyLaneAI/catalyst/pull/1152)

<h3>Breaking changes</h3>

* Remove `static_size` field from `AbstractQreg` class.
Expand Down
4 changes: 3 additions & 1 deletion frontend/catalyst/jax_extras/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@
Copyright 2021 The JAX Authors.
"""

from catalyst.jax_primitives import CUSTOM_LOWERING_RULES

Check notice on line 111 in frontend/catalyst/jax_extras/lowering.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_extras/lowering.py#L111

Import outside toplevel (catalyst.jax_primitives.CUSTOM_LOWERING_RULES) (import-outside-toplevel)
erick-xanadu marked this conversation as resolved.
Show resolved Hide resolved

if any(lowerable_effects.filter_not_in(jaxpr.effects)): # pragma: no cover
raise ValueError(f"Cannot lower jaxpr with effects: {jaxpr.effects}")

Expand All @@ -120,7 +122,7 @@
# Create a keepalives list that will be mutated during the lowering.
keepalives = []
host_callbacks = []
lowering_params = LoweringParameters()
lowering_params = LoweringParameters(override_lowering_rules=CUSTOM_LOWERING_RULES)
ctx = ModuleContext(
backend_or_name=None,
platforms=[platform],
Expand Down
14 changes: 0 additions & 14 deletions frontend/catalyst/jax_extras/patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from __future__ import annotations

import jax
from jax._src.lax.lax import _nary_lower_hlo
from jax._src.lax.slicing import (
_argnum_weak_type,
_gather_dtype_rule,
Expand All @@ -29,15 +28,12 @@
_sorted_dims_in_range,
standard_primitive,
)
from jax._src.lib.mlir.dialects import hlo
from jax.core import AbstractValue, Tracer, concrete_aval

__all__ = (
"get_aval2",
"_no_clean_up_dead_vars",
"_gather_shape_rule_dynamic",
"_sin_lowering2",
"_cos_lowering2",
"gather2_p",
)

Expand Down Expand Up @@ -198,13 +194,3 @@ def _gather_shape_rule_dynamic(
"gather",
weak_type_rule=_argnum_weak_type(0),
)


def _sin_lowering2(ctx, x):
"""Use hlo.sine lowering instead of the new sin lowering from jax 0.4.28"""
return _nary_lower_hlo(hlo.sine, ctx, x)


def _cos_lowering2(ctx, x):
"""Use hlo.cosine lowering instead of the new cosine lowering from jax 0.4.28"""
return _nary_lower_hlo(hlo.cosine, ctx, x)
26 changes: 6 additions & 20 deletions frontend/catalyst/jax_extras/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@
trace_to_jaxpr_dynamic2,
)
from jax._src.lax.control_flow import _initial_style_jaxpr
from jax._src.lax.lax import _abstractify, cos_p, sin_p
from jax._src.lax.slicing import _gather_lower
from jax._src.lax.lax import _abstractify
from jax._src.lax.slicing import _gather_lower, gather_p
from jax._src.linear_util import annotate
from jax._src.pjit import _extract_implicit_args, _flat_axes_specs
from jax._src.source_info_util import current as jax_current
Expand Down Expand Up @@ -92,12 +92,7 @@
)
from jaxlib.xla_extension import PyTreeRegistry

from catalyst.jax_extras.patches import (
_cos_lowering2,
_sin_lowering2,
gather2_p,
get_aval2,
)
from catalyst.jax_extras.patches import gather2_p, get_aval2
from catalyst.logging import debug_logger
from catalyst.tracing.type_signatures import verify_static_argnums_type
from catalyst.utils.patching import Patcher
Expand Down Expand Up @@ -511,25 +506,16 @@ def abstractify(args, kwargs):

register_lowering(gather2_p, _gather_lower)

# TBD
register_lowering(sin_p, _sin_lowering2)
register_lowering(cos_p, _cos_lowering2)

primitive_batchers2 = jax._src.interpreters.batching.primitive_batchers.copy()
for primitive in jax._src.interpreters.batching.primitive_batchers.keys():
if primitive.name == "gather":
gather_batching_rule = jax._src.interpreters.batching.primitive_batchers[primitive]
primitive_batchers2[gather2_p] = gather_batching_rule
jax._src.interpreters.batching.primitive_batchers[gather2_p] = (
jax._src.interpreters.batching.primitive_batchers[gather_p]
)

@wraps(fun)
def make_jaxpr_f(*args, **kwargs):
# TODO: re-use `deduce_avals` here.
with Patcher(
(jax._src.interpreters.partial_eval, "get_aval", get_aval2),
(jax._src.lax.slicing, "gather_p", gather2_p),
(jax._src.interpreters.batching, "primitive_batchers", primitive_batchers2),
(jax._src.lax.lax, "_sin_lowering", _sin_lowering2),
(jax._src.lax.lax, "_cos_lowering", _cos_lowering2),
), ExitStack():
f = wrap_init(fun)
if static_argnums:
Expand Down
94 changes: 53 additions & 41 deletions frontend/catalyst/jax_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
import numpy as np
import pennylane as qml
from jax._src import api_util, core, source_info_util, util
from jax._src.lax.lax import _nary_lower_hlo, cos_p, sin_p
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax.core import AbstractValue
from jax.interpreters import mlir
from jax.tree_util import PyTreeDef, tree_unflatten
Expand Down Expand Up @@ -2190,47 +2192,57 @@ def extract_scalar(value, op, kind="parameter"):
return value


#
# registration
#

mlir.register_lowering(zne_p, _zne_lowering)
mlir.register_lowering(qdevice_p, _qdevice_lowering)
mlir.register_lowering(qalloc_p, _qalloc_lowering)
mlir.register_lowering(qdealloc_p, _qdealloc_lowering)
mlir.register_lowering(qextract_p, _qextract_lowering)
mlir.register_lowering(qinsert_p, _qinsert_lowering)
mlir.register_lowering(qinst_p, _qinst_lowering)
mlir.register_lowering(gphase_p, _gphase_lowering)
mlir.register_lowering(qunitary_p, _qunitary_lowering)
mlir.register_lowering(qmeasure_p, _qmeasure_lowering)
mlir.register_lowering(compbasis_p, _compbasis_lowering)
mlir.register_lowering(namedobs_p, _named_obs_lowering)
mlir.register_lowering(hermitian_p, _hermitian_lowering)
mlir.register_lowering(tensorobs_p, _tensor__obs_lowering)
mlir.register_lowering(hamiltonian_p, _hamiltonian_lowering)
mlir.register_lowering(sample_p, _sample_lowering)
mlir.register_lowering(counts_p, _counts_lowering)
mlir.register_lowering(expval_p, _expval_lowering)
mlir.register_lowering(var_p, _var_lowering)
mlir.register_lowering(probs_p, _probs_lowering)
mlir.register_lowering(state_p, _state_lowering)
mlir.register_lowering(cond_p, _cond_lowering)
mlir.register_lowering(while_p, _while_loop_lowering)
mlir.register_lowering(for_p, _for_loop_lowering)
mlir.register_lowering(grad_p, _grad_lowering)
mlir.register_lowering(func_p, _func_lowering)
mlir.register_lowering(jvp_p, _jvp_lowering)
mlir.register_lowering(vjp_p, _vjp_lowering)
mlir.register_lowering(adjoint_p, _adjoint_lowering)
mlir.register_lowering(print_p, _print_lowering)
mlir.register_lowering(assert_p, _assert_lowering)
mlir.register_lowering(python_callback_p, _python_callback_lowering)
mlir.register_lowering(value_and_grad_p, _value_and_grad_lowering)
mlir.register_lowering(apply_registered_pass_p, _apply_registered_pass_lowering)
mlir.register_lowering(transform_named_sequence_p, _transform_named_sequence_lowering)
mlir.register_lowering(set_state_p, _set_state_lowering)
mlir.register_lowering(set_basis_state_p, _set_basis_state_lowering)
def _sin_lowering2(ctx, x):
erick-xanadu marked this conversation as resolved.
Show resolved Hide resolved
"""Use hlo.sine lowering instead of the new sin lowering from jax 0.4.28"""
return _nary_lower_hlo(hlo.sine, ctx, x)


def _cos_lowering2(ctx, x):
"""Use hlo.cosine lowering instead of the new cosine lowering from jax 0.4.28"""
return _nary_lower_hlo(hlo.cosine, ctx, x)
erick-xanadu marked this conversation as resolved.
Show resolved Hide resolved


CUSTOM_LOWERING_RULES = (
(zne_p, _zne_lowering),
(qdevice_p, _qdevice_lowering),
(qalloc_p, _qalloc_lowering),
(qdealloc_p, _qdealloc_lowering),
(qextract_p, _qextract_lowering),
(qinsert_p, _qinsert_lowering),
(qinst_p, _qinst_lowering),
(gphase_p, _gphase_lowering),
(qunitary_p, _qunitary_lowering),
(qmeasure_p, _qmeasure_lowering),
(compbasis_p, _compbasis_lowering),
(namedobs_p, _named_obs_lowering),
(hermitian_p, _hermitian_lowering),
(tensorobs_p, _tensor__obs_lowering),
(hamiltonian_p, _hamiltonian_lowering),
(sample_p, _sample_lowering),
(counts_p, _counts_lowering),
(expval_p, _expval_lowering),
(var_p, _var_lowering),
(probs_p, _probs_lowering),
(state_p, _state_lowering),
(cond_p, _cond_lowering),
(while_p, _while_loop_lowering),
(for_p, _for_loop_lowering),
(grad_p, _grad_lowering),
(func_p, _func_lowering),
(jvp_p, _jvp_lowering),
(vjp_p, _vjp_lowering),
(adjoint_p, _adjoint_lowering),
(print_p, _print_lowering),
(assert_p, _assert_lowering),
(python_callback_p, _python_callback_lowering),
(value_and_grad_p, _value_and_grad_lowering),
(apply_registered_pass_p, _apply_registered_pass_lowering),
(transform_named_sequence_p, _transform_named_sequence_lowering),
(set_state_p, _set_state_lowering),
(set_basis_state_p, _set_basis_state_lowering),
(sin_p, _sin_lowering2),
(cos_p, _cos_lowering2),
)


def _scalar_abstractify(t):
Expand Down
Loading