From ab530912cee191565c06df14699795ca18d690b6 Mon Sep 17 00:00:00 2001 From: David Ittah Date: Wed, 16 Oct 2024 18:57:41 -0400 Subject: [PATCH] Allow qml functions through autograph --- frontend/catalyst/autograph/ag_primitives.py | 2 ++ frontend/test/pytest/test_autograph.py | 30 +++++--------------- 2 files changed, 9 insertions(+), 23 deletions(-) diff --git a/frontend/catalyst/autograph/ag_primitives.py b/frontend/catalyst/autograph/ag_primitives.py index 19e53a5908..65913d4093 100644 --- a/frontend/catalyst/autograph/ag_primitives.py +++ b/frontend/catalyst/autograph/ag_primitives.py @@ -537,7 +537,9 @@ def converted_call(fn, args, kwargs, caller_fn_scope=None, options=None): # HOTFIX: pass through calls of known Catalyst wrapper functions if fn in ( catalyst.adjoint, + qml.adjoint, catalyst.ctrl, + qml.ctrl, catalyst.grad, catalyst.value_and_grad, catalyst.jacobian, diff --git a/frontend/test/pytest/test_autograph.py b/frontend/test/pytest/test_autograph.py index 9e52bf53c8..85c2542949 100644 --- a/frontend/test/pytest/test_autograph.py +++ b/frontend/test/pytest/test_autograph.py @@ -25,25 +25,7 @@ from jax.errors import TracerBoolConversionError from numpy.testing import assert_allclose -from catalyst import ( - AutoGraphError, - adjoint, - autograph_source, - cond, - ctrl, - debug, - disable_autograph, - for_loop, - grad, - jacobian, - jvp, - measure, - qjit, - run_autograph, - vjp, - vmap, - while_loop, -) +from catalyst import * from catalyst.autograph.transformer import TRANSFORMER from catalyst.utils.dummy import dummy_func from catalyst.utils.exceptions import CompileError @@ -295,7 +277,8 @@ def fn(x: float): assert check_cache(inner.user_function.func) assert fn(np.pi) == -1 - def test_adjoint_wrapper(self): + @pytest.mark.parametrize("adjoint_fn", [adjoint, qml.adjoint]) + def test_adjoint_wrapper(self, adjoint_fn): """Test conversion is happening succesfully on functions wrapped with 'adjoint'.""" def inner(x): @@ -304,14 +287,15 @@ def inner(x): @qjit(autograph=True) @qml.qnode(qml.device("lightning.qubit", wires=1)) def fn(x: float): - adjoint(inner)(x) + adjoint_fn(inner)(x) return qml.probs() assert hasattr(fn.user_function, "ag_unconverted") assert check_cache(inner) assert np.allclose(fn(np.pi), [0.0, 1.0]) - def test_ctrl_wrapper(self): + @pytest.mark.parametrize("ctrl_fn", [ctrl, qml.ctrl]) + def test_ctrl_wrapper(self, ctrl_fn): """Test conversion is happening succesfully on functions wrapped with 'ctrl'.""" def inner(x): @@ -320,7 +304,7 @@ def inner(x): @qjit(autograph=True) @qml.qnode(qml.device("lightning.qubit", wires=2)) def fn(x: float): - ctrl(inner, control=1)(x) + ctrl_fn(inner, control=1)(x) return qml.probs() assert hasattr(fn.user_function, "ag_unconverted")