diff --git a/notebooks/book2/25/deblending-jax.ipynb b/notebooks/book2/25/deblending-jax.ipynb
new file mode 100644
index 00000000000..1338ea61dd4
--- /dev/null
+++ b/notebooks/book2/25/deblending-jax.ipynb
@@ -0,0 +1,551 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f9d9a8c6-8ea6-4f74-bf18-81570e3dcc46",
+ "metadata": {
+ "id": "f9d9a8c6-8ea6-4f74-bf18-81570e3dcc46"
+ },
+ "source": [
+ "# Iterative α-(de)Blending\n",
+ "\n",
+ "Jax version of: https://github.com/tchambon/IADB"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "8238a5b1-720f-4d5f-a4c0-48dfc488cae2",
+ "metadata": {
+ "id": "8238a5b1-720f-4d5f-a4c0-48dfc488cae2"
+ },
+ "outputs": [],
+ "source": [
+ "import jax\n",
+ "import optax\n",
+ "import numpy as np\n",
+ "import jax.numpy as jnp\n",
+ "import flax.linen as nn\n",
+ "import matplotlib.pyplot as plt\n",
+ "import urllib.request\n",
+ "import matplotlib.image as mpimg\n",
+ "\n",
+ "from tqdm import tqdm\n",
+ "from numba import njit\n",
+ "from typing import Callable\n",
+ "from flax.training.train_state import TrainState"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "ff3cafcc-1191-48fd-915f-aa5566d88eef",
+ "metadata": {
+ "id": "ff3cafcc-1191-48fd-915f-aa5566d88eef"
+ },
+ "outputs": [],
+ "source": [
+ "%config InlineBackend.figure_format = \"retina\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "a5d5cbe0-42d5-4297-958e-4566545f8102",
+ "metadata": {
+ "id": "a5d5cbe0-42d5-4297-958e-4566545f8102"
+ },
+ "outputs": [],
+ "source": [
+ "@njit\n",
+ "def generate_samples_from_image(image, n_data):\n",
+ " max_pdf_value = np.max(image)\n",
+ " samples = np.zeros((n_data, 2))\n",
+ " for n in range(n_data):\n",
+ " while True:\n",
+ " x, y, u = np.random.rand(3)\n",
+ " i = int(x * image.shape[1])\n",
+ " j = int(y * image.shape[2])\n",
+ "\n",
+ " if image[0, i, j, 0] / max_pdf_value >= u:\n",
+ " samples[n, 0] = x\n",
+ " samples[n, 1] = y\n",
+ " break\n",
+ " return samples"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "99626dd9-77a2-40c5-b4f5-ad43d4d65e88",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 435
+ },
+ "id": "99626dd9-77a2-40c5-b4f5-ad43d4d65e88",
+ "outputId": "88a18040-aae4-42de-883f-1442b62eaeb6"
+ },
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "