Skip to content

Commit

Permalink
Merge pull request #19 from BirkhoffG/tutorial
Browse files Browse the repository at this point in the history
Tutorials on `ReLax` as a Recourse Library
  • Loading branch information
BirkhoffG authored Nov 9, 2023
2 parents 5134914 + fa8a532 commit 51d55df
Showing 1 changed file with 125 additions and 6 deletions.
131 changes: 125 additions & 6 deletions nbs/tutorials/methods.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,43 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"`ReLax` is a recourse explanation library which provides implementations of various recourse methods.\n",
"In other words, you can use implemented methods in `ReLax` without relying on the entire pipeline of `ReLax`.\n",
"`ReLax` contains implementations of various recourse methods, which are decoupled from the rest of `ReLax` library.\n",
"We give users flexibility on how to use `ReLax`: \n",
"\n",
"At a high level, you can use the implemented methods in `ReLax` to generate a recourse explanation via three lines of code:\n",
"* You can use the recourse pipeline in `ReLax` (\"one-liner\" for easy benchmarking recourse methods; see this [tutorial](getting_started.ipynb)).\n",
"* You can use all of the recourse methods in `ReLax` without relying on the entire pipeline of `ReLax`.\n",
"\n",
"In this tutorial, we uncover the possibility of the second option by using recourse methods under `relax.methods` \n",
"for debugging, diagnosing, interpreting your JAX models.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Types of Recourse Methods\n",
"\n",
"1. Non-parametric methods: These methods do not rely on any learned parameters. They generate counterfactuals solely based on the model's predictions and gradients. Examples in ReLax include `VanillaCF`, `DiverseCF` and `GrowingSphere` . These methods inherit from `CFModule`.\n",
"\n",
"2. Semi-parametric methods: These methods learn some parameters to aid in counterfactual generation, but do not learn a full counterfactual generation model. Examples in ReLax include `ProtoCF`, `CCHVAE` and `CLUE`. These methods inherit from `ParametricCFModule `.\n",
"\n",
"3. Parametric methods: These methods learn a full parametric model for counterfactual generation. The model is trained to generate counterfactuals that fool the model. Examples in ReLax include `CounterNet` and `VAECF`. These methods inherit from `ParametricCFModule`.\n",
"\n",
"\n",
"|Method Type | Learned Parameters | Training Required | Example Methods | \n",
"|-----|:-----|:---:|:-----:|\n",
"|Non-parametric | None |No |`VanillaCF`, `DiverseCF`, `GrowingSphere` |\n",
"|Semi-parametric| Some (θ) |Modest amount |`ProtoCF`, `CCHVAE`, `CLUE` |\n",
"|Parametric|Full generator model (φ)|Substantial amount|`CounterNet`, `VAECF` |"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Basic Usages\n",
"\n",
"At a high level, you can use the implemented methods in `ReLax` to generate *one* recourse explanation via three lines of code:\n",
"\n",
"```python\n",
"from relax.methods import VanillaCF\n",
Expand All @@ -30,11 +63,97 @@
"...\n",
"import functools as ft\n",
"\n",
"generate_fn = ft.partial(vcf.generate_cf, pred_fn=pred_fn)\n",
"vcf_gen_fn = ft.partial(vcf.generate_cf, pred_fn=pred_fn)\n",
"# xs is a batched data. Shape: `(N, K)`\n",
"cfs = jax.vmap(generate_fn)(xs)\n",
"cfs = jax.vmap(vcf_gen_fn)(xs)\n",
"```\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To use parametric and semi-parametric methods, you can first train the model\n",
"by calling `ParametricCF.train`, and then generate recourse explanations.\n",
"Here is an example of using `ReLax` for `CCHVAE`.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```python\n",
"from relax.methods import CCHVAE\n",
"\n",
"cchvae = CCHVAE()\n",
"cchvae.train(train_data) # Train CVAE before generation\n",
"cf = cchvae.generate_cf(x, pred_fn=pred_fn) \n",
"```\n",
"\n",
"Or generate a batch of recourse explanation via the `jax.vmap` primitive:\n",
"\n",
"```python\n",
"...\n",
"import functools as ft\n",
"\n",
"cchvae_gen_fn = ft.partial(cchvae.generate_cf, pred_fn=pred_fn)\n",
"cfs = jax.vmap(cchvae_gen_fn)(xs) # Generate counterfactuals\n",
"\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Config Recourse Methods\n",
"\n",
"Each recourse method in `ReLax` has an associated Config class that defines the set of supported configuration parameters. To configure a method, import and instantiate its Config class and pass it as the config parameter.\n",
"\n",
"For example, to configure `VanillaCF`:\n",
"\n",
"```Python\n",
"from relax.methods import VanillaCF \n",
"from relax.methods.vanilla import VanillaCFConfig\n",
"\n",
"config = VanillaCFConfig(\n",
" n_steps=100,\n",
" lr=0.1,\n",
" lambda_=0.1\n",
")\n",
"\n",
"vcf = VanillaCF(config)\n",
"\n",
"```\n",
"Each Config class inherits from a `BaseConfig` that defines common options like n_steps. Method-specific parameters are defined on the individual Config classes.\n",
"\n",
"See the documentation for each recourse method for details on its supported configuration parameters. The Config class for a method can be imported from `relax.methods.[method_name]`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Alternatively, we can also specify this config via a dictionary.\n",
"\n",
"```Python\n",
"from relax.methods import VanillaCF\n",
"\n",
"config = {\n",
" \"n_steps\": 10, \n",
" \"lambda_\": 0.1,\n",
" \"lr\": 0.1 \n",
"}\n",
"\n",
"vcf = VanillaCF(config)\n",
"```\n",
"\n",
"This config dictionary is passed to VanillaCF's __init__ method, which will set the specified parameters. Now our `VanillaCF` instance is configured to:\n",
"\n",
" * Number 10 optimization steps (n_steps=100)\n",
" * Use 0.1 validity regularization for counterfactuals (lambda_=0.1)\n",
" * Use a learning rate of 0.1 for optimization (lr=0.1)"
]
}
],
"metadata": {
Expand All @@ -45,5 +164,5 @@
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}

0 comments on commit 51d55df

Please sign in to comment.