Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tutorials on ReLax as a Recourse Library #19

Merged
merged 4 commits into from
Nov 9, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 125 additions & 6 deletions nbs/tutorials/methods.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,43 @@
"cell_type": "markdown",
Praneyg marked this conversation as resolved.
Show resolved Hide resolved
Praneyg marked this conversation as resolved.
Show resolved Hide resolved
Praneyg marked this conversation as resolved.
Show resolved Hide resolved
Praneyg marked this conversation as resolved.
Show resolved Hide resolved
"metadata": {},
"source": [
"`ReLax` is a recourse explanation library which provides implementations of various recourse methods.\n",
"In other words, you can use implemented methods in `ReLax` without relying on the entire pipeline of `ReLax`.\n",
"`ReLax` contains implementations of various recourse methods, which are decoupled from the rest of `ReLax` library.\n",
"We give users flexibility on how to use `ReLax`: \n",
"\n",
"At a high level, you can use the implemented methods in `ReLax` to generate a recourse explanation via three lines of code:\n",
"* You can use the recourse pipeline in `ReLax` (\"one-liner\" for easy benchmarking recourse methods; see this [tutorial](getting_started.ipynb)).\n",
"* You can use all of the recourse methods in `ReLax` without relying on the entire pipeline of `ReLax`.\n",
"\n",
"In this tutorial, we uncover the possibility of the second option by using recourse methods under `relax.methods` \n",
"for debugging, diagnosing, interpreting your JAX models.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Types of Recourse Methods\n",
"\n",
"1. Non-parametric methods: These methods do not rely on any learned parameters. They generate counterfactuals solely based on the model's predictions and gradients. Examples in ReLax include `VanillaCF`, `DiverseCF` and `GrowingSphere` . These methods inherit from `CFModule`.\n",
"\n",
"2. Semi-parametric methods: These methods learn some parameters to aid in counterfactual generation, but do not learn a full counterfactual generation model. Examples in ReLax include `ProtoCF`, `CCHVAE` and `CLUE`. These methods inherit from `ParametricCFModule `.\n",
"\n",
"3. Parametric methods: These methods learn a full parametric model for counterfactual generation. The model is trained to generate counterfactuals that fool the model. Examples in ReLax include `CounterNet` and `VAECF`. These methods inherit from `ParametricCFModule`.\n",
"\n",
"\n",
"|Method Type | Learned Parameters | Training Required | Example Methods | \n",
"|-----|:-----|:---:|:-----:|\n",
"|Non-parametric | None |No |`VanillaCF`, `DiverseCF`, `GrowingSphere` |\n",
"|Semi-parametric| Some (θ) |Modest amount |`ProtoCF`, `CCHVAE`, `CLUE` |\n",
"|Parametric|Full generator model (φ)|Substantial amount|`CounterNet`, `VAECF` |"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Basic Usages\n",
"\n",
"At a high level, you can use the implemented methods in `ReLax` to generate *one* recourse explanation via three lines of code:\n",
"\n",
"```python\n",
"from relax.methods import VanillaCF\n",
Expand All @@ -30,11 +63,97 @@
"...\n",
"import functools as ft\n",
"\n",
"generate_fn = ft.partial(vcf.generate_cf, pred_fn=pred_fn)\n",
"vcf_gen_fn = ft.partial(vcf.generate_cf, pred_fn=pred_fn)\n",
"# xs is a batched data. Shape: `(N, K)`\n",
"cfs = jax.vmap(generate_fn)(xs)\n",
"cfs = jax.vmap(vcf_gen_fn)(xs)\n",
"```\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To use parametric and semi-parametric methods, you can first train the model\n",
"by calling `ParametricCF.train`, and then generate recourse explanations.\n",
"Here is an example of using `ReLax` for `CCHVAE`.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```python\n",
"from relax.methods import CCHVAE\n",
"\n",
"cchvae = CCHVAE()\n",
"cchvae.train(train_data) # Train CVAE before generation\n",
"cf = cchvae.generate_cf(x, pred_fn=pred_fn) \n",
"```\n",
"\n",
"Or generate a batch of recourse explanation via the `jax.vmap` primitive:\n",
"\n",
"```python\n",
"...\n",
"import functools as ft\n",
"\n",
"cchvae_gen_fn = ft.partial(cchvae.generate_cf, pred_fn=pred_fn)\n",
"cfs = jax.vmap(cchvae_gen_fn)(xs) # Generate counterfactuals\n",
"\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Config Recourse Methods\n",
"\n",
"Each recourse method in `ReLax` has an associated Config class that defines the set of supported configuration parameters. To configure a method, import and instantiate its Config class and pass it as the config parameter.\n",
"\n",
"For example, to configure `VanillaCF`:\n",
"\n",
"```Python\n",
"from relax.methods import VanillaCF \n",
"from relax.methods.vanilla import VanillaCFConfig\n",
"\n",
"config = VanillaCFConfig(\n",
" n_steps=100,\n",
" lr=0.1,\n",
" lambda_=0.1\n",
")\n",
"\n",
"vcf = VanillaCF(config)\n",
"\n",
"```\n",
"Each Config class inherits from a `BaseConfig` that defines common options like n_steps. Method-specific parameters are defined on the individual Config classes.\n",
"\n",
"See the documentation for each recourse method for details on its supported configuration parameters. The Config class for a method can be imported from `relax.methods.[method_name]`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Alternatively, we can also specify this config via a dictionary.\n",
"\n",
"```Python\n",
"from relax.methods import VanillaCF\n",
"\n",
"config = {\n",
" \"n_steps\": 10, \n",
" \"lambda_\": 0.1,\n",
" \"lr\": 0.1 \n",
"}\n",
"\n",
"vcf = VanillaCF(config)\n",
"```\n",
"\n",
"This config dictionary is passed to VanillaCF's __init__ method, which will set the specified parameters. Now our `VanillaCF` instance is configured to:\n",
"\n",
" * Number 10 optimization steps (n_steps=100)\n",
" * Use 0.1 validity regularization for counterfactuals (lambda_=0.1)\n",
" * Use a learning rate of 0.1 for optimization (lr=0.1)"
]
}
],
"metadata": {
Expand All @@ -45,5 +164,5 @@
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}