diff --git a/README.md b/README.md
index a710441..3cb9be3 100644
--- a/README.md
+++ b/README.md
@@ -150,6 +150,7 @@ methods.
 | [`CCHVAE`](https://birkhoffg.github.io/jax-relax/methods/cchvae.html#cchvae) | Semi-Parametric | Learning Model-Agnostic Counterfactual Explanations for Tabular Data. | [\[6\]](https://arxiv.org/abs/1910.09398) |
 | [`VAECF`](https://birkhoffg.github.io/jax-relax/methods/vaecf.html#vaecf) | Parametric | Preserving Causal Constraints in Counterfactual Explanations for Machine Learning Classifiers. | [\[7\]](https://arxiv.org/abs/1912.03277) |
 | [`CLUE`](https://birkhoffg.github.io/jax-relax/methods/clue.html#clue) | Semi-Parametric | Getting a CLUE: A Method for Explaining Uncertainty Estimates. | [\[8\]](https://arxiv.org/abs/2006.06848) |
+| [`L2C`](https://birkhoffg.github.io/jax-relax/methods/l2c.html#l2c) | Parametric | Feature-based Learning for Diverse and Privacy-Preserving Counterfactual Explanations. | [\[9\]](https://arxiv.org/abs/2209.13446) |

 ## Citing `ReLax`

diff --git a/nbs/index.ipynb b/nbs/index.ipynb
index 4e1dbcc..4ebdcdc 100644
--- a/nbs/index.ipynb
+++ b/nbs/index.ipynb
@@ -374,6 +374,7 @@
     "| `CCHVAE` | Semi-Parametric | Learning Model-Agnostic Counterfactual Explanations for Tabular Data. | [[6]](https://arxiv.org/abs/1910.09398) |\n",
     "| `VAECF` | Parametric | Preserving Causal Constraints in Counterfactual Explanations for Machine Learning Classifiers. | [[7]](https://arxiv.org/abs/1912.03277) |\n",
     "| `CLUE` | Semi-Parametric | Getting a CLUE: A Method for Explaining Uncertainty Estimates. | [[8]](https://arxiv.org/abs/2006.06848) |\n",
+    "| `L2C` | Parametric | Feature-based Learning for Diverse and Privacy-Preserving Counterfactual Explanations. | [[9]](https://arxiv.org/abs/2209.13446) |\n",
     "\n",
     ": {tbl-colwidths=\"[17, 13, 65, 5]\"}\n"
    ]
diff --git a/nbs/methods/09_l2c.ipynb b/nbs/methods/09_l2c.ipynb
new file mode 100644
index 0000000..e37483f
--- /dev/null
+++ b/nbs/methods/09_l2c.ipynb
@@ -0,0 +1,685 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# L2C\n",
+    "\n",
+    "An implementation of [Feature-based Learning for Diverse and Privacy-Preserving Counterfactual Explanations](https://arxiv.org/abs/2209.13446) (L2C)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#| default_exp methods.l2c"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "The autoreload extension is already loaded. 
To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], + "source": [ + "#| include: false\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "from ipynb_path import *\n", + "import warnings\n", + "warnings.simplefilter(action='ignore', category=FutureWarning)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "from __future__ import annotations\n", + "from relax.import_essentials import *\n", + "from relax.methods.base import ParametricCFModule\n", + "from relax.base import BaseConfig\n", + "from relax.utils import *\n", + "from relax.data_utils import Feature, FeaturesList\n", + "from relax.ml_model import MLP, MLPBlock\n", + "from relax.data_module import DataModule\n", + "from keras_core.random import SeedGenerator\n", + "import einops" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "import torch\n", + "import relax" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## L2C Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "def gumbel_softmax(\n", + " key: jrand.PRNGKey, # Random key\n", + " logits: Array, # Logits for each class. Shape (batch_size, num_classes)\n", + " tau: float, # Temperature for the Gumbel softmax\n", + "):\n", + " \"\"\"The Gumbel softmax function.\"\"\"\n", + "\n", + " gumbel_noise = jrand.gumbel(key, shape=logits.shape)\n", + " y = logits + gumbel_noise\n", + " return jax.nn.softmax(y / tau, axis=-1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "def sample_categorical(\n", + " key: jrand.PRNGKey, # Random key\n", + " logits: Array, # Logits for each class. 
Shape (batch_size, num_classes)\n",
+    "    tau: float, # Temperature for the Gumbel softmax\n",
+    "    training: bool = True, # Apply gumbel softmax if training\n",
+    "):\n",
+    "    \"\"\"Sample from a categorical distribution.\"\"\"\n",
+    "\n",
+    "    def sample_cat(key, logits):\n",
+    "        cat = jrand.categorical(key, logits=logits, axis=-1)\n",
+    "        return jax.nn.one_hot(cat, logits.shape[-1])\n",
+    "\n",
+    "    return lax.cond(\n",
+    "        training,\n",
+    "        lambda _: gumbel_softmax(key, logits, tau=tau),\n",
+    "        lambda _: sample_cat(key, logits),\n",
+    "        None,\n",
+    "    )"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "logits = jnp.array([[2.0, 1.0, 0.1], [1.0, 2.0, 3.0]])\n",
+    "key = jrand.PRNGKey(0)\n",
+    "output = sample_categorical(key, logits, tau=0.5, training=True)\n",
+    "assert output.shape == logits.shape\n",
+    "assert jnp.allclose(output.sum(axis=-1), 1.0)\n",
+    "# low temperature -> one-hot\n",
+    "output = sample_categorical(key, logits, tau=0.01, training=True)\n",
+    "assert jnp.array_equal(\n",
+    "    output.argmax(axis=-1), logits.argmax(axis=-1)\n",
+    ")\n",
+    "# high temperature -> uniform\n",
+    "output = sample_categorical(key, logits, tau=100, training=True)\n",
+    "assert jnp.max(output) - jnp.min(output) < 0.5\n",
+    "\n",
+    "output = sample_categorical(key, logits, tau=0.5, training=False)\n",
+    "assert output.shape == logits.shape\n",
+    "assert jnp.array_equal(\n",
+    "    output.argmax(axis=-1), logits.argmax(axis=-1)\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#| export\n",
+    "def sample_bernouli(\n",
+    "    key: jrand.PRNGKey, # Random key\n",
+    "    prob: Array, # Probability of success. Shape (batch_size, 1)\n",
+    "    tau: float, # Temperature for the Gumbel softmax\n",
+    "    training: bool = True, # Apply gumbel softmax if training\n",
+    ") -> Array:\n",
+    "    \"\"\"Sample from a Bernoulli distribution.\"\"\"\n",
+    "\n",
+    "    def sample_ber(key, prob):\n",
+    "        return jrand.bernoulli(key, p=prob).astype(prob.dtype)\n",
+    "    \n",
+    "    def gumbel_ber(key, prob, tau):\n",
+    "        key_1, key_2 = jrand.split(key)\n",
+    "        gumbel_1 = jrand.gumbel(key_1, shape=prob.shape)\n",
+    "        gumbel_2 = jrand.gumbel(key_2, shape=prob.shape)\n",
+    "        # Binary-concrete relaxation: raise to the power 1/tau\n",
+    "        # (dividing by tau cancels in the ratio below, which would\n",
+    "        # leave the temperature with no effect).\n",
+    "        no_logits = (prob * jnp.exp(gumbel_1)) ** (1. / tau)\n",
+    "        de_logits = no_logits + ((1. - prob) * jnp.exp(gumbel_2)) ** (1. / tau)\n",
+    "        return no_logits / de_logits\n",
+    "    \n",
+    "    return lax.cond(\n",
+    "        training,\n",
+    "        lambda _: gumbel_ber(key, prob, tau),\n",
+    "        lambda _: sample_ber(key, prob),\n",
+    "        None,\n",
+    "    )"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#| export\n",
+    "class L2CModel(keras.Model):\n",
+    "    def __init__(\n",
+    "        self,\n",
+    "        generator_layers: list[int],\n",
+    "        selector_layers: list[int],\n",
+    "        feature_indices: list[tuple[int, int]] = None,\n",
+    "        pred_fn: Callable = None,\n",
+    "        alpha: float = 1e-4, # Sparsity regularization\n",
+    "        tau: float = 0.7,\n",
+    "        seed: int = None,\n",
+    "        **kwargs\n",
+    "    ):\n",
+    "        super().__init__(**kwargs)\n",
+    "        self.generator_layers = generator_layers\n",
+    "        self.selector_layers = selector_layers\n",
+    "        self.feature_indices = feature_indices\n",
+    "        self.pred_fn = pred_fn\n",
+    "        self.tau = tau\n",
+    "        self.alpha = alpha\n",
+    "        seed = seed or get_config().global_seed\n",
+    "        self.seed_generator = SeedGenerator(seed)\n",
+    "\n",
+    "    def set_features_info(self, feature_indices: list[tuple[int, int]]):\n",
+    "        self.feature_indices = feature_indices\n",
+    "        # TODO: check if the feature indices are valid\n",
+    "\n",
+    "    def set_pred_fn(self, pred_fn: Callable):\n",
+    "        self.pred_fn = pred_fn\n",
+    "\n",
+    "    def build(self, input_shape):\n",
+    "        n_feats = len(self.feature_indices)\n",
+    "        self.generator = MLP(\n",
+    "            sizes=self.generator_layers,\n",
+    "            output_size=input_shape[-1],\n",
+    "            dropout_rate=0.0,\n",
+    "            last_activation=\"linear\",\n",
+    "        )\n",
+    "        self.selector = MLP(\n",
+    "            sizes=self.selector_layers,\n",
+    "            output_size=n_feats,\n",
+    "            dropout_rate=0.0,\n",
+    "            last_activation=\"sigmoid\",\n",
+    "        )\n",
+    "\n",
+    "    def compute_l2c_loss(self, inputs, cfs, probs):\n",
+    "        # Flip the prediction: target the class the model currently\n",
+    "        # assigns the lowest probability.\n",
+    "        y_target = self.pred_fn(inputs).argmin(axis=-1)\n",
+    "        y_pred = self.pred_fn(cfs)\n",
+    "        validity_loss = keras.losses.sparse_categorical_crossentropy(\n",
+    "            y_target, y_pred\n",
+    "        ).mean()\n",
+    "        # L1 penalty on the selection probabilities encourages sparse edits.\n",
+    "        sparsity = jnp.linalg.norm(probs, ord=1) * self.alpha\n",
+    "        return validity_loss, sparsity\n",
+    "    \n",
+    "    def perturb(self, inputs, cfs, probs, i, start, end):\n",
+    "        # Mix the generated block with the original block of feature i,\n",
+    "        # gated by the selector's probability for that feature.\n",
+    "        return cfs[:, start:end] * probs[:, i : i + 1] + inputs[:, start:end] * (1 - probs[:, i : i + 1])\n",
+    "    \n",
+    "    def forward(self, inputs, training=False):\n",
+    "        select_probs = self.selector(inputs, training=training)\n",
+    "        probs = sample_bernouli(\n",
+    "            self.seed_generator.next(), select_probs, \n",
+    "            tau=self.tau, training=training\n",
+    "        )\n",
+    "        cfs_logits = self.generator(inputs, training=training)\n",
+    "        cfs = sample_categorical(\n",
+    "            self.seed_generator.next(), cfs_logits, \n",
+    "            tau=self.tau, training=training\n",
+    "        )\n",
+    "        cfs = jnp.concatenate([\n",
+    "            self.perturb(inputs, cfs, probs, i, start, end)\n",
+    "            for i, (start, end) in enumerate(self.feature_indices)\n",
+    "        ], axis=-1,\n",
+    "        )\n",
+    "        return cfs, probs\n",
+    "    \n",
+    "    def call(self, inputs, training=False):\n",
+    "        cfs, probs = self.forward(inputs, training=training)\n",
+    "        # loss = self.compute_l2c_loss(inputs, cfs, probs)\n",
+    "        validity_loss, sparsity = self.compute_l2c_loss(inputs, cfs, probs)\n",
+    "        self.add_loss(validity_loss)\n",
+    "        self.add_loss(sparsity)\n",
+    "        return cfs \n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Discretizer"
+   ]
+  },
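+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "L2C searches for counterfactuals in a fully discretized space: each continuous feature is quantile-binned into `q` one-hot buckets (`qcut`), categorical features pass through unchanged, and generated bins are mapped back to continuous values via the bin mid-points (`cut_quantiles` / `qcut_inverse`)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+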
"#| export\n", + "def qcut(\n", + " x: Array, # Input array\n", + " q: int, # Number of quantiles\n", + " axis: int = 0, # Axis to quantile\n", + ") -> tuple[Array, Array]: # (digitized array, quantiles)\n", + " \"\"\"Quantile binning.\"\"\"\n", + " \n", + " # Handle edge cases: empty array or single element\n", + " if x.size <= 1:\n", + " return jnp.zeros_like(x), jnp.array([])\n", + " quantiles = jnp.quantile(x, jnp.linspace(0, 1, q + 1)[1:-1], axis=axis)\n", + " digitized = jnp.digitize(x, quantiles)\n", + " return digitized, quantiles" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "digitized, quantiles = qcut(jnp.arange(10), 4)\n", + "assert digitized.shape == (10,)\n", + "assert quantiles.shape == (3,)\n", + "assert jnp.allclose(\n", + " digitized, jnp.array([0,0,0,1,1,2,2,3,3,3])\n", + ")\n", + "\n", + "quantiles_true = jnp.array([0, 2.25, 4.5, 6.75, 9])\n", + "assert jnp.allclose(\n", + " quantiles, quantiles_true[1:-1]\n", + ")\n", + "x_empty = jnp.array([])\n", + "q = 2\n", + "digitized_empty, quantiles_empty = qcut(x_empty, q)\n", + "assert digitized_empty.size == 0 and quantiles_empty.size == 0\n", + "# Test with single element array\n", + "x_single = jnp.array([1])\n", + "digitized_single, quantiles_single = qcut(x_single, q)\n", + "assert digitized_single.size == 1 and quantiles_single.size == 0\n", + "\n", + "# Test with large q value\n", + "xs = jnp.array([1, 2, 3, 4, 5, 6])\n", + "q_large = 10\n", + "_, quantiles_large = qcut(xs, q_large)\n", + "assert len(quantiles_large) == q_large - 1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "def qcut_inverse(\n", + " digitized: Array, # Digitized One-Hot Encoding Array\n", + " quantiles: Array, # Quantiles\n", + ") -> Array:\n", + " \"\"\"Inverse of qcut.\"\"\"\n", + " \n", + " return digitized @ quantiles" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "digitized, quantiles = qcut(jnp.arange(10), 4)\n", + "ohe_digitized = jax.nn.one_hot(digitized, 4)\n", + "quantiles_inv = qcut_inverse(ohe_digitized, jnp.arange(4))\n", + "assert quantiles_inv.shape == (10,)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "def cut_quantiles(\n", + " quantiles: Array, # Quantiles\n", + " xs: Array, # Input array\n", + "):\n", + " quantiles = jnp.concatenate([\n", + " xs.min(axis=0, keepdims=True), \n", + " quantiles, \n", + " xs.max(axis=0, keepdims=True)\n", + " ])\n", + " quantiles = (quantiles[1:] + quantiles[:-1]) / 2\n", + " return quantiles" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "def discretize_xs(\n", + " xs: Array, # Input array\n", + " is_categorical_and_indices: list[tuple[bool, tuple[int, int]]], # Features list\n", + " q: int = 4, # Number of quantiles\n", + ") -> tuple[Array, list[Array], list[tuple[tuple[int, int], Array]]]: # (discretized array, indices_and_quantiles_and_mid)\n", + " \"\"\"Discretize continuous features.\"\"\"\n", + " \n", + " discretized_xs = []\n", + " indices_and_mid = []\n", + " quantiles_feats = []\n", + " discretized_start, discretized_end = 0, 0\n", + "\n", + " for is_categorical, (start, end) in is_categorical_and_indices:\n", + " if is_categorical:\n", + " discretized, quantiles, mid = xs[:, start:end], None, None\n", 
+ " discretized_end += end - start\n", + " else:\n", + " discretized, quantiles = qcut(xs[:, start:end].reshape(-1), q=q)\n", + " mid = cut_quantiles(quantiles, xs[:, start])\n", + " discretized = jax.nn.one_hot(discretized, q)\n", + " discretized_end += discretized.shape[-1]\n", + " \n", + " discretized_xs.append(discretized)\n", + " quantiles_feats.append(quantiles)\n", + " indices_and_mid.append(\n", + " ((discretized_start, discretized_end), mid)\n", + " )\n", + " \n", + " discretized_start = discretized_end\n", + " discretized_xs = jnp.concatenate(discretized_xs, axis=-1)\n", + " return discretized_xs, quantiles_feats, indices_and_mid" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dm = relax.load_data(\"dummy\")\n", + "xs, ys = dm['train']\n", + "is_categorical_and_indices = [\n", + " (feat.is_categorical, indices) for feat, indices in zip(dm.features, dm.features.feature_indices)\n", + "]\n", + "discretized_xs, quantiles_feats, indices_and_mid = discretize_xs(xs, is_categorical_and_indices)\n", + "assert discretized_xs.shape == (xs.shape[0], 4 * xs.shape[1])\n", + "assert len(quantiles_feats) == len(is_categorical_and_indices)\n", + "assert all(len(quantiles_feats[i]) == 3 for i in range(len(quantiles_feats)))\n", + "assert len(indices_and_mid) == len(is_categorical_and_indices)\n", + "assert all(len(indices_and_mid[i][1]) == 4 for i in range(len(indices_and_mid)))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "class Discretizer:\n", + " \"\"\"Discretize continuous features.\"\"\"\n", + " \n", + " def __init__(\n", + " self, \n", + " is_cat_and_indices: list[tuple[bool, tuple[int, int]]], # Features list\n", + " q: int = 4 # Number of quantiles\n", + " ):\n", + " self.is_cat_and_indices = is_cat_and_indices\n", + " self.q = q\n", + "\n", + " def fit(self, xs: Array):\n", + " _, self.quantiles, self.indices_and_mid_quantiles = discretize_xs(\n", + " xs, self.is_cat_and_indices, self.q\n", + " )\n", + "\n", + " def transform(self, xs: Array):\n", + " digitized_xs = []\n", + " for quantiles, (_, (start, end)) in zip(self.quantiles, self.is_cat_and_indices):\n", + " if quantiles is None: \n", + " digitized = xs[:, start:end]\n", + " else:\n", + " digitized = jnp.digitize(xs[:, start], quantiles)\n", + " digitized = jax.nn.one_hot(digitized, self.q)\n", + " digitized_xs.append(digitized)\n", + " return jnp.concatenate(digitized_xs, axis=-1)\n", + "\n", + " def fit_transform(self, xs: Array):\n", + " self.fit(xs)\n", + " return self.transform(xs)\n", + "\n", + " def inverse_transform(self, xs: Array):\n", + " continutized_xs = []\n", + " for (start, end), mid_quantiles in self.indices_and_mid_quantiles:\n", + " if mid_quantiles is None:\n", + " cont_feat = xs[:, start:end]\n", + " else:\n", + " cont_feat = qcut_inverse(xs[:, start:end], mid_quantiles).reshape(-1, 1)\n", + " continutized_xs.append(cont_feat)\n", + " return jnp.concatenate(continutized_xs, axis=-1)\n", + " \n", + " def get_pred_fn(self, pred_fn: Callable[[Array], Array]):\n", + " def _pred_fn(xs: Array):\n", + " return pred_fn(self.inverse_transform(xs))\n", + " return _pred_fn\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dis = Discretizer(is_categorical_and_indices)\n", + "dis.fit(xs)\n", + "digitized_xs_1 = dis.transform(xs)\n", + "assert jnp.array_equal(discretized_xs, digitized_xs_1)\n", + 
"inversed_xs = dis.inverse_transform(digitized_xs_1)\n", + "assert xs.shape == inversed_xs.shape\n", + "assert jnp.unique(inversed_xs).size == xs.shape[1] * 4\n", + "\n", + "ml_module = relax.load_ml_module(\"dummy\")\n", + "pred_fn = dis.get_pred_fn(ml_module.pred_fn)\n", + "y = pred_fn(digitized_xs_1)\n", + "assert y.shape == (xs.shape[0], 2)\n", + "\n", + "def f(x, y):\n", + " y_pred = pred_fn(x)\n", + " return jnp.mean((y_pred - y) ** 2)\n", + "\n", + "grad = jax.grad(f)(digitized_xs_1, ys)\n", + "assert grad.shape == digitized_xs_1.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## L2C Module" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "class L2CConfig(BaseConfig):\n", + " generator_layers: list[int] = Field(\n", + " [64, 64, 64], description=\"Generator MLP layers.\"\n", + " )\n", + " selector_layers: list[int] = Field(\n", + " [64], description=\"Selector MLP layers.\"\n", + " )\n", + " lr: float = Field(1e-3, description=\"Model learning rate.\")\n", + " opt_name: str = Field(\"adam\", description=\"Optimizer name of training L2C.\")\n", + " alpha: float = Field(1e-4, description=\"Sparsity regularization.\")\n", + " tau: float = Field(0.7, description=\"Temperature for the Gumbel softmax.\")\n", + " q: int = Field(4, description=\"Number of quantiles.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "class L2C(ParametricCFModule):\n", + " def __init__(\n", + " self,\n", + " config: Dict | L2CConfig = None,\n", + " l2c_model: L2CModel = None,\n", + " name: str = \"l2c\",\n", + " ):\n", + " if config is None:\n", + " config = L2CConfig()\n", + " config = validate_configs(config, L2CConfig)\n", + " name = name or \"l2c\"\n", + " self.l2c_model = l2c_model\n", + " super().__init__(config=config, name=name)\n", + "\n", + " def train(\n", + " self, \n", + " data: DataModule, \n", + " pred_fn: Callable,\n", + " batch_size: int = 128,\n", + " epochs: int = 10,\n", + " **fit_kwargs\n", + " ):\n", + " if not isinstance(data, DataModule):\n", + " raise ValueError(f\"Only support `data` to be `DataModule`, \"\n", + " f\"got type=`{type(data).__name__}` instead.\")\n", + " \n", + " xs_train, ys_train = data['train']\n", + " self.discretizer = Discretizer(\n", + " [(feat.is_categorical, indices) for feat, indices in zip(data.features, data.features.feature_indices)],\n", + " q=self.config.q\n", + " )\n", + " discretized_xs_train = self.discretizer.fit_transform(xs_train)\n", + " pred_fn = self.discretizer.get_pred_fn(pred_fn)\n", + " features_indices = [indices for indices, _ in self.discretizer.indices_and_mid_quantiles]\n", + "\n", + " self.l2c_model = L2CModel(\n", + " generator_layers=self.config.generator_layers,\n", + " selector_layers=self.config.selector_layers,\n", + " feature_indices=features_indices,\n", + " pred_fn=pred_fn,\n", + " alpha=self.config.alpha,\n", + " tau=self.config.tau,\n", + " )\n", + " self.l2c_model.compile(\n", + " optimizer=keras.optimizers.get({\n", + " 'class_name': self.config.opt_name, \n", + " 'config': {'learning_rate': self.config.lr}\n", + " }),\n", + " loss=None\n", + " )\n", + " self.l2c_model.fit(\n", + " discretized_xs_train, ys_train,\n", + " epochs=epochs,\n", + " batch_size=batch_size,\n", + " **fit_kwargs\n", + " )\n", + " self._is_trained = True\n", + " return self\n", + " \n", + " @auto_reshaping('x')\n", + " def generate_cf(\n", + " self, 
\n", + " x: Array, \n", + " **kwargs\n", + " ) -> Array:\n", + " # TODO: Does not support passing apply_constraints \n", + " discretized_x = self.discretizer.transform(x)\n", + " cfs, probs = self.l2c_model.forward(discretized_x, training=False)\n", + " return self.discretizer.inverse_transform(cfs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/10\n", + "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 9ms/step - loss: 1.1063 \n", + "Epoch 2/10\n", + "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.5560 \n", + "Epoch 3/10\n", + "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.3971 \n", + "Epoch 4/10\n", + "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.3766 \n", + "Epoch 5/10\n", + "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.3716 \n", + "Epoch 6/10\n", + "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.3701 \n", + "Epoch 7/10\n", + "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.3590 \n", + "Epoch 8/10\n", + "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.3423 \n", + "Epoch 9/10\n", + "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.3419 \n", + "Epoch 10/10\n", + "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.3363 \n" + ] + } + ], + "source": [ + "dm = relax.load_data('adult')\n", + "ml_module = relax.load_ml_module('adult')\n", + "l2c = L2C()\n", + "exp = relax.generate_cf_explanations(\n", + " l2c, dm, ml_module.pred_fn,\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "python3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/relax/_modidx.py b/relax/_modidx.py index 5a56a81..fb3be90 100644 --- a/relax/_modidx.py +++ b/relax/_modidx.py @@ -594,6 +594,42 @@ 'relax.methods.dice.DiverseCFConfig': ('methods/dice.html#diversecfconfig', 'relax/methods/dice.py'), 'relax.methods.dice._diverse_cf': ('methods/dice.html#_diverse_cf', 'relax/methods/dice.py'), 'relax.methods.dice.dpp_style_vmap': ('methods/dice.html#dpp_style_vmap', 'relax/methods/dice.py')}, + 'relax.methods.l2c': { 'relax.methods.l2c.Discretizer': ('methods/l2c.html#discretizer', 'relax/methods/l2c.py'), + 'relax.methods.l2c.Discretizer.__init__': ( 'methods/l2c.html#discretizer.__init__', + 'relax/methods/l2c.py'), + 'relax.methods.l2c.Discretizer.fit': ('methods/l2c.html#discretizer.fit', 'relax/methods/l2c.py'), + 'relax.methods.l2c.Discretizer.fit_transform': ( 'methods/l2c.html#discretizer.fit_transform', + 'relax/methods/l2c.py'), + 'relax.methods.l2c.Discretizer.get_pred_fn': ( 'methods/l2c.html#discretizer.get_pred_fn', + 'relax/methods/l2c.py'), + 'relax.methods.l2c.Discretizer.inverse_transform': ( 'methods/l2c.html#discretizer.inverse_transform', + 
'relax/methods/l2c.py'), + 'relax.methods.l2c.Discretizer.transform': ( 'methods/l2c.html#discretizer.transform', + 'relax/methods/l2c.py'), + 'relax.methods.l2c.L2C': ('methods/l2c.html#l2c', 'relax/methods/l2c.py'), + 'relax.methods.l2c.L2C.__init__': ('methods/l2c.html#l2c.__init__', 'relax/methods/l2c.py'), + 'relax.methods.l2c.L2C.generate_cf': ('methods/l2c.html#l2c.generate_cf', 'relax/methods/l2c.py'), + 'relax.methods.l2c.L2C.train': ('methods/l2c.html#l2c.train', 'relax/methods/l2c.py'), + 'relax.methods.l2c.L2CConfig': ('methods/l2c.html#l2cconfig', 'relax/methods/l2c.py'), + 'relax.methods.l2c.L2CModel': ('methods/l2c.html#l2cmodel', 'relax/methods/l2c.py'), + 'relax.methods.l2c.L2CModel.__init__': ('methods/l2c.html#l2cmodel.__init__', 'relax/methods/l2c.py'), + 'relax.methods.l2c.L2CModel.build': ('methods/l2c.html#l2cmodel.build', 'relax/methods/l2c.py'), + 'relax.methods.l2c.L2CModel.call': ('methods/l2c.html#l2cmodel.call', 'relax/methods/l2c.py'), + 'relax.methods.l2c.L2CModel.compute_l2c_loss': ( 'methods/l2c.html#l2cmodel.compute_l2c_loss', + 'relax/methods/l2c.py'), + 'relax.methods.l2c.L2CModel.forward': ('methods/l2c.html#l2cmodel.forward', 'relax/methods/l2c.py'), + 'relax.methods.l2c.L2CModel.perturb': ('methods/l2c.html#l2cmodel.perturb', 'relax/methods/l2c.py'), + 'relax.methods.l2c.L2CModel.set_features_info': ( 'methods/l2c.html#l2cmodel.set_features_info', + 'relax/methods/l2c.py'), + 'relax.methods.l2c.L2CModel.set_pred_fn': ( 'methods/l2c.html#l2cmodel.set_pred_fn', + 'relax/methods/l2c.py'), + 'relax.methods.l2c.cut_quantiles': ('methods/l2c.html#cut_quantiles', 'relax/methods/l2c.py'), + 'relax.methods.l2c.discretize_xs': ('methods/l2c.html#discretize_xs', 'relax/methods/l2c.py'), + 'relax.methods.l2c.gumbel_softmax': ('methods/l2c.html#gumbel_softmax', 'relax/methods/l2c.py'), + 'relax.methods.l2c.qcut': ('methods/l2c.html#qcut', 'relax/methods/l2c.py'), + 'relax.methods.l2c.qcut_inverse': ('methods/l2c.html#qcut_inverse', 'relax/methods/l2c.py'), + 'relax.methods.l2c.sample_bernouli': ('methods/l2c.html#sample_bernouli', 'relax/methods/l2c.py'), + 'relax.methods.l2c.sample_categorical': ('methods/l2c.html#sample_categorical', 'relax/methods/l2c.py')}, 'relax.methods.proto': { 'relax.methods.proto.ProtoCF': ('methods/proto.html#protocf', 'relax/methods/proto.py'), 'relax.methods.proto.ProtoCF.__init__': ( 'methods/proto.html#protocf.__init__', 'relax/methods/proto.py'), diff --git a/relax/methods/l2c.py b/relax/methods/l2c.py new file mode 100644 index 0000000..ab15b38 --- /dev/null +++ b/relax/methods/l2c.py @@ -0,0 +1,363 @@ +# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/methods/09_l2c.ipynb. + +# %% ../../nbs/methods/09_l2c.ipynb 3 +from __future__ import annotations +from ..import_essentials import * +from .base import ParametricCFModule +from ..base import BaseConfig +from ..utils import * +from ..data_utils import Feature, FeaturesList +from ..ml_model import MLP, MLPBlock +from ..data_module import DataModule +from keras_core.random import SeedGenerator +import einops + +# %% auto 0 +__all__ = ['gumbel_softmax', 'sample_categorical', 'sample_bernouli', 'L2CModel', 'qcut', 'qcut_inverse', 'cut_quantiles', + 'discretize_xs', 'Discretizer', 'L2CConfig', 'L2C'] + +# %% ../../nbs/methods/09_l2c.ipynb 6 +def gumbel_softmax( + key: jrand.PRNGKey, # Random key + logits: Array, # Logits for each class. 
Shape (batch_size, num_classes)
+    tau: float, # Temperature for the Gumbel softmax
+):
+    """The Gumbel softmax function."""
+
+    gumbel_noise = jrand.gumbel(key, shape=logits.shape)
+    y = logits + gumbel_noise
+    return jax.nn.softmax(y / tau, axis=-1)
+
+# %% ../../nbs/methods/09_l2c.ipynb 7
+def sample_categorical(
+    key: jrand.PRNGKey, # Random key
+    logits: Array, # Logits for each class. Shape (batch_size, num_classes)
+    tau: float, # Temperature for the Gumbel softmax
+    training: bool = True, # Apply gumbel softmax if training
+):
+    """Sample from a categorical distribution."""
+
+    def sample_cat(key, logits):
+        cat = jrand.categorical(key, logits=logits, axis=-1)
+        return jax.nn.one_hot(cat, logits.shape[-1])
+
+    return lax.cond(
+        training,
+        lambda _: gumbel_softmax(key, logits, tau=tau),
+        lambda _: sample_cat(key, logits),
+        None,
+    )
+
+# %% ../../nbs/methods/09_l2c.ipynb 9
+def sample_bernouli(
+    key: jrand.PRNGKey, # Random key
+    prob: Array, # Probability of success. Shape (batch_size, 1)
+    tau: float, # Temperature for the Gumbel softmax
+    training: bool = True, # Apply gumbel softmax if training
+) -> Array:
+    """Sample from a Bernoulli distribution."""
+
+    def sample_ber(key, prob):
+        return jrand.bernoulli(key, p=prob).astype(prob.dtype)
+
+    def gumbel_ber(key, prob, tau):
+        key_1, key_2 = jrand.split(key)
+        gumbel_1 = jrand.gumbel(key_1, shape=prob.shape)
+        gumbel_2 = jrand.gumbel(key_2, shape=prob.shape)
+        # Binary-concrete relaxation: raise to the power 1/tau
+        # (dividing by tau cancels in the ratio below, which would
+        # leave the temperature with no effect).
+        no_logits = (prob * jnp.exp(gumbel_1)) ** (1. / tau)
+        de_logits = no_logits + ((1. - prob) * jnp.exp(gumbel_2)) ** (1. / tau)
+        return no_logits / de_logits
+
+    return lax.cond(
+        training,
+        lambda _: gumbel_ber(key, prob, tau),
+        lambda _: sample_ber(key, prob),
+        None,
+    )
+
+# %% ../../nbs/methods/09_l2c.ipynb 10
+class L2CModel(keras.Model):
+    def __init__(
+        self,
+        generator_layers: list[int],
+        selector_layers: list[int],
+        feature_indices: list[tuple[int, int]] = None,
+        pred_fn: Callable = None,
+        alpha: float = 1e-4, # Sparsity regularization
+        tau: float = 0.7,
+        seed: int = None,
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.generator_layers = generator_layers
+        self.selector_layers = selector_layers
+        self.feature_indices = feature_indices
+        self.pred_fn = pred_fn
+        self.tau = tau
+        self.alpha = alpha
+        seed = seed or get_config().global_seed
+        self.seed_generator = SeedGenerator(seed)
+
+    def set_features_info(self, feature_indices: list[tuple[int, int]]):
+        self.feature_indices = feature_indices
+        # TODO: check if the feature indices are valid
+
+    def set_pred_fn(self, pred_fn: Callable):
+        self.pred_fn = pred_fn
+
+    def build(self, input_shape):
+        n_feats = len(self.feature_indices)
+        self.generator = MLP(
+            sizes=self.generator_layers,
+            output_size=input_shape[-1],
+            dropout_rate=0.0,
+            last_activation="linear",
+        )
+        self.selector = MLP(
+            sizes=self.selector_layers,
+            output_size=n_feats,
+            dropout_rate=0.0,
+            last_activation="sigmoid",
+        )
+
+    def compute_l2c_loss(self, inputs, cfs, probs):
+        # Flip the prediction: target the class the model currently
+        # assigns the lowest probability.
+        y_target = self.pred_fn(inputs).argmin(axis=-1)
+        y_pred = self.pred_fn(cfs)
+        validity_loss = keras.losses.sparse_categorical_crossentropy(
+            y_target, y_pred
+        ).mean()
+        # L1 penalty on the selection probabilities encourages sparse edits.
+        sparsity = jnp.linalg.norm(probs, ord=1) * self.alpha
+        return validity_loss, sparsity
+
+    def perturb(self, inputs, cfs, probs, i, start, end):
+        # Mix the generated block with the original block of feature i,
+        # gated by the selector's probability for that feature.
+        return cfs[:, start:end] * probs[:, i : i + 1] + inputs[:, start:end] * (1 - probs[:, i : i + 1])
+
+    def forward(self, inputs, training=False):
+        select_probs = self.selector(inputs, training=training)
+        probs = sample_bernouli(
+
self.seed_generator.next(), select_probs, + tau=self.tau, training=training + ) + cfs_logits = self.generator(inputs, training=training) + cfs = sample_categorical( + self.seed_generator.next(), cfs_logits, + tau=self.tau, training=training + ) + cfs = jnp.concatenate([ + self.perturb(inputs, cfs, probs, i, start, end) + for i, (start, end) in enumerate(self.feature_indices) + ], axis=-1, + ) + return cfs, probs + + def call(self, inputs, training=False): + cfs, probs = self.forward(inputs, training=training) + # loss = self.compute_l2c_loss(inputs, cfs, probs) + validity_loss, sparsity = self.compute_l2c_loss(inputs, cfs, probs) + self.add_loss(validity_loss) + self.add_loss(sparsity) + return cfs + + +# %% ../../nbs/methods/09_l2c.ipynb 12 +def qcut( + x: Array, # Input array + q: int, # Number of quantiles + axis: int = 0, # Axis to quantile +) -> tuple[Array, Array]: # (digitized array, quantiles) + """Quantile binning.""" + + # Handle edge cases: empty array or single element + if x.size <= 1: + return jnp.zeros_like(x), jnp.array([]) + quantiles = jnp.quantile(x, jnp.linspace(0, 1, q + 1)[1:-1], axis=axis) + digitized = jnp.digitize(x, quantiles) + return digitized, quantiles + +# %% ../../nbs/methods/09_l2c.ipynb 14 +def qcut_inverse( + digitized: Array, # Digitized One-Hot Encoding Array + quantiles: Array, # Quantiles +) -> Array: + """Inverse of qcut.""" + + return digitized @ quantiles + +# %% ../../nbs/methods/09_l2c.ipynb 16 +def cut_quantiles( + quantiles: Array, # Quantiles + xs: Array, # Input array +): + quantiles = jnp.concatenate([ + xs.min(axis=0, keepdims=True), + quantiles, + xs.max(axis=0, keepdims=True) + ]) + quantiles = (quantiles[1:] + quantiles[:-1]) / 2 + return quantiles + +# %% ../../nbs/methods/09_l2c.ipynb 17 +def discretize_xs( + xs: Array, # Input array + is_categorical_and_indices: list[tuple[bool, tuple[int, int]]], # Features list + q: int = 4, # Number of quantiles +) -> tuple[Array, list[Array], list[tuple[tuple[int, int], Array]]]: # (discretized array, indices_and_quantiles_and_mid) + """Discretize continuous features.""" + + discretized_xs = [] + indices_and_mid = [] + quantiles_feats = [] + discretized_start, discretized_end = 0, 0 + + for is_categorical, (start, end) in is_categorical_and_indices: + if is_categorical: + discretized, quantiles, mid = xs[:, start:end], None, None + discretized_end += end - start + else: + discretized, quantiles = qcut(xs[:, start:end].reshape(-1), q=q) + mid = cut_quantiles(quantiles, xs[:, start]) + discretized = jax.nn.one_hot(discretized, q) + discretized_end += discretized.shape[-1] + + discretized_xs.append(discretized) + quantiles_feats.append(quantiles) + indices_and_mid.append( + ((discretized_start, discretized_end), mid) + ) + + discretized_start = discretized_end + discretized_xs = jnp.concatenate(discretized_xs, axis=-1) + return discretized_xs, quantiles_feats, indices_and_mid + +# %% ../../nbs/methods/09_l2c.ipynb 19 +class Discretizer: + """Discretize continuous features.""" + + def __init__( + self, + is_cat_and_indices: list[tuple[bool, tuple[int, int]]], # Features list + q: int = 4 # Number of quantiles + ): + self.is_cat_and_indices = is_cat_and_indices + self.q = q + + def fit(self, xs: Array): + _, self.quantiles, self.indices_and_mid_quantiles = discretize_xs( + xs, self.is_cat_and_indices, self.q + ) + + def transform(self, xs: Array): + digitized_xs = [] + for quantiles, (_, (start, end)) in zip(self.quantiles, self.is_cat_and_indices): + if quantiles is None: + digitized = xs[:, 
start:end] + else: + digitized = jnp.digitize(xs[:, start], quantiles) + digitized = jax.nn.one_hot(digitized, self.q) + digitized_xs.append(digitized) + return jnp.concatenate(digitized_xs, axis=-1) + + def fit_transform(self, xs: Array): + self.fit(xs) + return self.transform(xs) + + def inverse_transform(self, xs: Array): + continutized_xs = [] + for (start, end), mid_quantiles in self.indices_and_mid_quantiles: + if mid_quantiles is None: + cont_feat = xs[:, start:end] + else: + cont_feat = qcut_inverse(xs[:, start:end], mid_quantiles).reshape(-1, 1) + continutized_xs.append(cont_feat) + return jnp.concatenate(continutized_xs, axis=-1) + + def get_pred_fn(self, pred_fn: Callable[[Array], Array]): + def _pred_fn(xs: Array): + return pred_fn(self.inverse_transform(xs)) + return _pred_fn + + +# %% ../../nbs/methods/09_l2c.ipynb 22 +class L2CConfig(BaseConfig): + generator_layers: list[int] = Field( + [64, 64, 64], description="Generator MLP layers." + ) + selector_layers: list[int] = Field( + [64], description="Selector MLP layers." + ) + lr: float = Field(1e-3, description="Model learning rate.") + opt_name: str = Field("adam", description="Optimizer name of training L2C.") + alpha: float = Field(1e-4, description="Sparsity regularization.") + tau: float = Field(0.7, description="Temperature for the Gumbel softmax.") + q: int = Field(4, description="Number of quantiles.") + +# %% ../../nbs/methods/09_l2c.ipynb 23 +class L2C(ParametricCFModule): + def __init__( + self, + config: Dict | L2CConfig = None, + l2c_model: L2CModel = None, + name: str = "l2c", + ): + if config is None: + config = L2CConfig() + config = validate_configs(config, L2CConfig) + name = name or "l2c" + self.l2c_model = l2c_model + super().__init__(config=config, name=name) + + def train( + self, + data: DataModule, + pred_fn: Callable, + batch_size: int = 128, + epochs: int = 10, + **fit_kwargs + ): + if not isinstance(data, DataModule): + raise ValueError(f"Only support `data` to be `DataModule`, " + f"got type=`{type(data).__name__}` instead.") + + xs_train, ys_train = data['train'] + self.discretizer = Discretizer( + [(feat.is_categorical, indices) for feat, indices in zip(data.features, data.features.feature_indices)], + q=self.config.q + ) + discretized_xs_train = self.discretizer.fit_transform(xs_train) + pred_fn = self.discretizer.get_pred_fn(pred_fn) + features_indices = [indices for indices, _ in self.discretizer.indices_and_mid_quantiles] + + self.l2c_model = L2CModel( + generator_layers=self.config.generator_layers, + selector_layers=self.config.selector_layers, + feature_indices=features_indices, + pred_fn=pred_fn, + alpha=self.config.alpha, + tau=self.config.tau, + ) + self.l2c_model.compile( + optimizer=keras.optimizers.get({ + 'class_name': self.config.opt_name, + 'config': {'learning_rate': self.config.lr} + }), + loss=None + ) + self.l2c_model.fit( + discretized_xs_train, ys_train, + epochs=epochs, + batch_size=batch_size, + **fit_kwargs + ) + self._is_trained = True + return self + + @auto_reshaping('x') + def generate_cf( + self, + x: Array, + **kwargs + ) -> Array: + # TODO: Does not support passing apply_constraints + discretized_x = self.discretizer.transform(x) + cfs, probs = self.l2c_model.forward(discretized_x, training=False) + return self.discretizer.inverse_transform(cfs)
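
A minimal end-to-end sketch of the new module, mirroring the notebook's final cell; the single-instance `generate_cf` call below relies on the `@auto_reshaping('x')` decorator shown above, and the defaults come from `L2CConfig` (`tau=0.7`, `q=4`, `lr=1e-3`):

```python
import relax
from relax.methods.l2c import L2C

dm = relax.load_data("adult")              # DataModule; dm['train'] yields (xs, ys)
ml_module = relax.load_ml_module("adult")  # pretrained classifier to explain

l2c = L2C()                                # default L2CConfig
l2c.train(dm, ml_module.pred_fn, batch_size=128, epochs=10)

xs, _ = dm["train"]
cf = l2c.generate_cf(xs[0])                # counterfactual mapped back to the original feature space
```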