Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tutorials on ReLax as a Recourse Library #19

Merged
merged 4 commits into from
Nov 9, 2023
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 142 additions & 6 deletions nbs/tutorials/methods.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,43 @@
"cell_type": "markdown",
Praneyg marked this conversation as resolved.
Show resolved Hide resolved
Praneyg marked this conversation as resolved.
Show resolved Hide resolved
Praneyg marked this conversation as resolved.
Show resolved Hide resolved
Praneyg marked this conversation as resolved.
Show resolved Hide resolved
"metadata": {},
"source": [
"`ReLax` is a recourse explanation library which provides implementations of various recourse methods.\n",
"In other words, you can use implemented methods in `ReLax` without relying on the entire pipeline of `ReLax`.\n",
"`ReLax` contains implementations of various recourse methods, which are decoupled from the rest of `ReLax` library.\n",
"We give users flexibility on how to use `ReLax`: \n",
"\n",
"At a high level, you can use the implemented methods in `ReLax` to generate a recourse explanation via three lines of code:\n",
"* You can use the recourse pipeline in `ReLax` (\"one-liner\" for easy benchmarking recourse methods; see this [tutorial](getting_started.ipynb)).\n",
"* You can use all of the recourse methods in `ReLax` without relying on the entire pipeline of `ReLax`.\n",
"\n",
"In this tutorial, we uncover the possibility of the second option by using recourse methods under `relax.methods` \n",
"for debugging, diagnosing, interpreting your JAX models.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Types of Recourse Methods\n",
"\n",
"1. Non-parametric methods: These methods do not rely on any learned parameters. They generate counterfactuals solely based on the model's predictions and gradients. Examples in ReLax include VanillaCF and GrowingSpheres. These methods inherit from NonParametricCFModule.\n",
"\n",
"2. Semi-parametric methods: These methods learn some parameters to aid in counterfactual generation, but do not learn a full counterfactual generation model. Examples in ReLax include DiverseCF and ProtoCF. These methods inherit from SemiParametricCFModule.\n",
"\n",
"3. Parametric methods: These methods learn a full parametric model for counterfactual generation. The model is trained to generate counterfactuals that fool the model. Examples in ReLax include CounterNet, CCHVAE, VAECF and CLUE. These methods inherit from ParametricCFModule.\n",
"\n",
"\n",
"|Method Type | Learned Parameters | Training Required | Example Methods | \n",
"|-----|:-----|:---:|:-----:|\n",
"|Non-parametric |None |No |VanillaCF, GrowingSpheres |\n",
"|Semi-parametric|Some (θ)|Modest amount |DiverseCF, ProtoCF |\n",
"|Parametric|Full generator model (φ)|Substantial amount|CounterNet, CCHVAE, VAECF, CLUE|"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Basis \n",
"\n",
"At a high level, you can use the implemented methods in `ReLax` to generate *one* recourse explanation via three lines of code:\n",
"\n",
"```python\n",
"from relax.methods import VanillaCF\n",
Expand All @@ -30,11 +63,114 @@
"...\n",
"import functools as ft\n",
"\n",
"generate_fn = ft.partial(vcf.generate_cf, pred_fn=pred_fn)\n",
"vcf_gen_fn = ft.partial(vcf.generate_cf, pred_fn=pred_fn)\n",
"# xs is a batched data. Shape: `(N, K)`\n",
"cfs = jax.vmap(generate_fn)(xs)\n",
"cfs = jax.vmap(vcf_gen_fn)(xs)\n",
"```\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Example of using ReLax for parametric methods (using CCHVAE)\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```python\n",
"from relax.methods import CCHVAE\n",
"\n",
"cchvae = CCHVAE()\n",
"# x is one data point. Shape: `(K)` or `(1, K)`\n",
"cf = vcf.generate_cf(x, pred_fn=pred_fn)\n",
"```\n",
"\n",
"Or generate a batch of recourse explanation via the `jax.vmap` primitive:\n",
"\n",
"```python\n",
"...\n",
"import functools as ft\n",
"\n",
"cchvae_gen_fn = ft.partial(cchvae.generate_cf, pred_fn=pred_fn)\n",
"cfs = jax.vmap(cchvae_gen_fn)(xs) # Generate counterfactuals\n",
"\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Config Recourse Methods\n",
"\n",
"Each recourse method in ReLax has an associated Config class that defines the set of supported configuration parameters. To configure a method, import and instantiate its Config class and pass it as the config parameter.\n",
"\n",
"For example, to configure VanillaCF:\n",
"\n",
"```Python\n",
"from relax.methods import VanillaCF \n",
"from relax.methods.vanilla import VanillaCFConfig\n",
"\n",
"config = VanillaCFConfig(\n",
" n_steps=100,\n",
" lr=0.1,\n",
" lambda_=0.1\n",
")\n",
"\n",
"vcf = VanillaCF(config)\n",
"\n",
"```\n",
"Each Config class inherits from a BaseConfig that defines common options like num_cfs. Method-specific parameters are defined on the individual Config classes.\n",
"\n",
"See the documentation for each recourse method for details on its supported configuration parameters. The Config class for a method can be imported from relax.methods.[method_name]."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Alternatively, we can also specify this config via a dictionary.\n",
"\n",
"```Python\n",
"from relax.methods import VanillaCF\n",
"\n",
"config = {\n",
" \"num_cfs\": 10, \n",
" \"epsilon\": 0.01,\n",
" \"lr\": 0.1 \n",
"}\n",
"\n",
"vcf = VanillaCF(config)\n",
"```\n",
"\n",
"This config dictionary is passed to VanillaCF's __init__ method, which will set the specified parameters. Now our VanillaCF instance is configured to:\n",
"\n",
" * Generate 10 counterfactuals per input (num_cfs=10)\n",
" * Use a maximum perturbation of 0.01 (epsilon=0.01)\n",
" * Use a learning rate of 0.1 for optimization (lr=0.1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Modifying at Runtime:\n",
"\n",
"The configuration can also be updated after constructing the recourse method:\n",
"\n",
"```Python\n",
"vcf = VanillaCF() \n",
"\n",
"# Later, modify config\n",
"vcf.config[\"lr\"] = 0.5\n",
"\n",
"```\n",
"This allows dynamically adjusting the configuration as needed."
]
}
],
"metadata": {
Expand All @@ -45,5 +181,5 @@
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}