diff --git a/gpjax/scan.py b/gpjax/scan.py index 4ca2eebf4..fcec3aba9 100644 --- a/gpjax/scan.py +++ b/gpjax/scan.py @@ -22,7 +22,6 @@ ) import jax from jax import lax -from jax.experimental import host_callback as hcb import jax.numpy as jnp import jax.tree_util as jtu from jaxtyping import ( @@ -54,7 +53,8 @@ def _callback(cond: ScalarBool, func: Callable, *args: Any) -> None: def _do_callback(_) -> int: """Perform the callback.""" - return hcb.id_tap(func, *args, result=_dummy_result) + jax.debug.callback(func, *args) + return _dummy_result def _not_callback(_) -> int: """Do nothing.""" @@ -113,19 +113,19 @@ def vscan( _progress_bar = trange(_length) _progress_bar.set_description("Compiling...", refresh=True) - def _set_running(args: Any, transform: Any) -> None: + def _set_running(*args: Any) -> None: """Set the tqdm progress bar to running.""" _progress_bar.set_description("Running", refresh=False) - def _update_tqdm(args: Any, transform: Any) -> None: + def _update_tqdm(*args: Any) -> None: """Update the tqdm progress bar with the latest objective value.""" _value, _iter_num = args - _progress_bar.update(_iter_num) + _progress_bar.update(_iter_num.item()) if log_value and _value is not None: _progress_bar.set_postfix({"Value": f"{_value: .2f}"}) - def _close_tqdm(args: Any, transform: Any) -> None: + def _close_tqdm(*args: Any) -> None: """Close the tqdm progress bar.""" _progress_bar.close() @@ -145,16 +145,16 @@ def _body_fun(carry: Carry, iter_num_and_x: Tuple[ScalarInt, X]) -> Tuple[Carry, _is_last: bool = iter_num == _length - 1 # Update progress bar, if first of log_rate. - _callback(_is_first, _set_running, (y, log_rate)) + _callback(_is_first, _set_running) # Update progress bar, if multiple of log_rate. - _callback(_is_multiple, _update_tqdm, (y, log_rate)) + _callback(_is_multiple, _update_tqdm, y, log_rate) # Update progress bar, if remainder. - _callback(_is_remainder, _update_tqdm, (y, _remainder)) + _callback(_is_remainder, _update_tqdm, y, _remainder) # Close progress bar, if last iteration. - _callback(_is_last, _close_tqdm, (y, None)) + _callback(_is_last, _close_tqdm) return carry, y