Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support gumbel softmax transformation #40

Merged
merged 2 commits into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions nbs/00_utils.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,33 @@
" return upt_params, opt_state"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Functional Utils"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"def gumbel_softmax(\n",
" key: jrand.PRNGKey, # Random key\n",
" logits: Array, # Logits for each class. Shape (batch_size, num_classes)\n",
" tau: float, # Temperature for the Gumbel softmax\n",
" axis: int | tuple[int, ...] = -1, # The axis or axes along which the gumbel softmax should be computed\n",
"):\n",
" \"\"\"The Gumbel softmax function.\"\"\"\n",
"\n",
" gumbel_noise = jrand.gumbel(key, shape=logits.shape)\n",
" y = logits + gumbel_noise\n",
" return jax.nn.softmax(y / tau, axis=axis)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
153 changes: 107 additions & 46 deletions nbs/01_data.utils.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,22 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
]
}
],
"outputs": [],
"source": [
"#| hide\n",
"%load_ext autoreload\n",
"%autoreload 2\n",
"from ipynb_path import *\n",
"import warnings\n",
"\n",
"warnings.simplefilter(action='ignore', category=FutureWarning)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"from __future__ import annotations\n",
Expand All @@ -40,8 +47,10 @@
"import einops\n",
"import os, sys, json, pickle\n",
"import shutil\n",
"from relax.utils import *\n",
"import chex"
"from relax.utils import gumbel_softmax, load_pytree, save_pytree, get_config\n",
"import chex\n",
"import functools as ft\n",
"import warnings"
]
},
{
Expand Down Expand Up @@ -467,8 +476,8 @@
"transformed_xs = scaler.fit_transform(xs)\n",
"assert scaler.is_categorical is False\n",
"\n",
"cfs = np.random.randn(100, 1)\n",
"cf_constrained = scaler.apply_constraints(xs, cfs)\n",
"x = np.random.randn(100, 1)\n",
"cf_constrained = scaler.apply_constraints(xs, x)\n",
"assert np.all(cf_constrained >= 0) and np.all(cf_constrained <= 1)\n",
"\n",
"# Test from_dict and to_dict\n",
Expand All @@ -483,20 +492,27 @@
"outputs": [],
"source": [
"#| export\n",
"class OneHotTransformation(Transformation):\n",
"class _OneHotTransformation(Transformation):\n",
" def __init__(self):\n",
" super().__init__(\"ohe\", OneHotEncoder())\n",
"\n",
" @property\n",
" def num_categories(self) -> int:\n",
" return len(self.transformer.categories_)\n",
" \n",
" def hard_constraints(self, operand: tuple[jax.Array, jax.random.PRNGKey, dict]): \n",
" x, rng_key, kwargs = operand\n",
" return jax.nn.one_hot(jnp.argmax(x, axis=-1), self.num_categories)\n",
" \n",
" def soft_constraints(self, operand: tuple[jax.Array, jax.random.PRNGKey, dict]):\n",
" raise NotImplementedError\n",
"\n",
" def apply_constraints(self, xs, cfs, hard: bool = False, **kwargs):\n",
" def apply_constraints(self, xs, cfs, hard: bool = False, rng_key=None, **kwargs):\n",
" return jax.lax.cond(\n",
" hard,\n",
" true_fun=lambda x: jax.nn.one_hot(jnp.argmax(x, axis=-1), self.num_categories),\n",
" false_fun=lambda x: jax.nn.softmax(x, axis=-1),\n",
" operand=cfs,\n",
" true_fun=self.hard_constraints,\n",
" false_fun=self.soft_constraints,\n",
" operand=(cfs, rng_key, kwargs),\n",
" )\n",
" \n",
" def compute_reg_loss(self, xs, cfs, hard: bool = False):\n",
Expand All @@ -510,31 +526,74 @@
"metadata": {},
"outputs": [],
"source": [
"xs = np.random.choice(['a', 'b', 'c'], size=(100, 1))\n",
"ohe_t = OneHotTransformation().fit(xs)\n",
"transformed_xs = ohe_t.transform(xs)\n",
"#| export\n",
"class SoftmaxTransformation(_OneHotTransformation):\n",
" def soft_constraints(self, operand: tuple[jax.Array, jax.random.PRNGKey, dict]):\n",
" x, rng_key, kwargs = operand\n",
" return jax.nn.softmax(x, axis=-1)\n",
" \n",
"class GumbelSoftmaxTransformation(_OneHotTransformation):\n",
" \"\"\"Apply the Gumbel-softmax trick as the soft constraint for categorical features.\"\"\"\n",
"\n",
"cfs = jax.random.uniform(jax.random.PRNGKey(0), shape=(100, 3))\n",
"# Test hard=True which applies softmax function.\n",
"soft = ohe_t.apply_constraints(transformed_xs, cfs, hard=False)\n",
"assert jnp.allclose(soft.sum(axis=-1), 1)\n",
"assert jnp.all(soft >= 0)\n",
"assert jnp.all(soft <= 1)\n",
"assert jnp.allclose(jnp.zeros((len(cfs), 1)), ohe_t.compute_reg_loss(xs, soft, hard=False))\n",
" def __init__(self, tau: float = 1.):\n",
" super().__init__()\n",
" self.tau = tau\n",
" \n",
" def soft_constraints(self, operand: tuple[jax.Array, jax.random.PRNGKey, dict]):\n",
" x, rng_key, _ = operand\n",
" if rng_key is None: # No randomness\n",
" rng_key = jax.random.PRNGKey(get_config().global_seed)\n",
" return gumbel_softmax(rng_key, x, self.tau)\n",
" \n",
" def apply_constraints(self, xs, cfs, hard: bool = False, rng_key=None, **kwargs):\n",
" \"\"\"Apply constraints to the counterfactuals. If `rng_key` is None, the soft branch falls back to a fixed key built from the configured global seed, so the Gumbel noise is deterministic.\"\"\"\n",
" return super().apply_constraints(xs, cfs, hard, rng_key, **kwargs)\n",
" \n",
"def OneHotTransformation():\n",
" warnings.warn(\"OneHotTransformation is deprecated since v0.2.5. \"\n",
" \"Use `SoftmaxTransformation` (same functionality) \"\n",
" \"or GumbelSoftmaxTransformation instead.\", DeprecationWarning)\n",
" return SoftmaxTransformation()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def test_ohe_t(ohe_cls):\n",
" xs = np.random.choice(['a', 'b', 'c'], size=(100, 1))\n",
" ohe_t = ohe_cls().fit(xs)\n",
" transformed_xs = ohe_t.transform(xs)\n",
" rng_key = jax.random.PRNGKey(get_config().global_seed)\n",
"\n",
"# Test hard=True which enforce one-hot constraint.\n",
"hard = ohe_t.apply_constraints(transformed_xs, cfs, hard=True)\n",
"assert np.all([1 in x for x in hard])\n",
"assert np.all([0 in x for x in hard])\n",
"assert jnp.allclose(hard.sum(axis=-1), 1)\n",
"assert jnp.allclose(jnp.zeros((len(cfs), 1)), ohe_t.compute_reg_loss(xs, hard, hard=False))\n",
" x = jax.random.uniform(rng_key, shape=(100, 3))\n",
" # Test hard=False, which applies the soft constraint (rows are non-negative and sum to 1).\n",
" soft = ohe_t.apply_constraints(transformed_xs, x, hard=False, rng_key=rng_key)\n",
" assert jnp.allclose(soft.sum(axis=-1), 1)\n",
" assert jnp.all(soft >= 0)\n",
" assert jnp.all(soft <= 1)\n",
" assert jnp.allclose(jnp.zeros((len(x), 1)), ohe_t.compute_reg_loss(xs, soft, hard=False))\n",
" assert jnp.allclose(soft, ohe_t.apply_constraints(transformed_xs, x, hard=False))\n",
"\n",
"# Test compute_reg_loss\n",
"assert jnp.ndim(ohe_t.compute_reg_loss(xs, soft, hard=False)) == 0\n",
" # Test hard=True, which enforces the one-hot constraint.\n",
" hard = ohe_t.apply_constraints(transformed_xs, x, hard=True, rng_key=rng_key)\n",
" assert np.all([1 in x for x in hard])\n",
" assert np.all([0 in x for x in hard])\n",
" assert jnp.allclose(hard.sum(axis=-1), 1)\n",
" assert jnp.allclose(jnp.zeros((len(x), 1)), ohe_t.compute_reg_loss(xs, hard, hard=False))\n",
"\n",
"# Test from_dict and to_dict\n",
"ohe_t_1 = OneHotTransformation().from_dict(ohe_t.to_dict())\n",
"assert np.allclose(ohe_t.transform(xs), ohe_t_1.transform(xs))"
" # Test compute_reg_loss\n",
" assert jnp.ndim(ohe_t.compute_reg_loss(xs, soft, hard=False)) == 0\n",
"\n",
" # Test from_dict and to_dict\n",
" ohe_t_1 = ohe_cls().from_dict(ohe_t.to_dict())\n",
" assert np.allclose(ohe_t.transform(xs), ohe_t_1.transform(xs))\n",
"\n",
"\n",
"test_ohe_t(SoftmaxTransformation)\n",
"test_ohe_t(GumbelSoftmaxTransformation)"
]
},
{
Expand Down Expand Up @@ -616,7 +675,9 @@
"source": [
"#| export\n",
"PREPROCESSING_TRANSFORMATIONS = {\n",
" 'ohe': OneHotTransformation,\n",
" 'ohe': SoftmaxTransformation,\n",
" 'softmax': SoftmaxTransformation,\n",
" 'gumbel': GumbelSoftmaxTransformation,\n",
" 'minmax': MinMaxTransformation,\n",
" 'ordinal': OrdinalTransformation,\n",
" 'identity': IdentityTransformation,\n",
Expand Down Expand Up @@ -1157,8 +1218,8 @@
" assert feat.is_categorical is False\n",
" assert feat.is_immutable is False\n",
"\n",
"cfs = jax.random.uniform(jax.random.PRNGKey(0), shape=(100, 29))\n",
"feats_list_2.apply_constraints(feats_list_2.transformed_data[:100], cfs, hard=True)"
"x = jax.random.uniform(jax.random.PRNGKey(0), shape=(100, 29))\n",
"feats_list_2.apply_constraints(feats_list_2.transformed_data[:100], x, hard=True)"
]
},
{
Expand Down Expand Up @@ -1189,17 +1250,17 @@
"outputs": [],
"source": [
"# Test apply_constraints and compute_reg_loss\n",
"cfs = np.random.randn(10, 29)\n",
"constraint_cfs = feats_list.apply_constraints(feats_list.transformed_data[:10, :], cfs, hard=False)\n",
"x = np.random.randn(10, 29)\n",
"constraint_cfs = feats_list.apply_constraints(feats_list.transformed_data[:10, :], x, hard=False)\n",
"assert constraint_cfs.shape == (10, 29)\n",
"assert np.allclose(\n",
" constraint_cfs[:, 2:].sum(axis=-1),\n",
" np.ones((10,)) * 6\n",
")\n",
"assert constraint_cfs[: :2].min() >= 0 and constraint_cfs[: :2].max() <= 1\n",
"assert feats_list.apply_constraints(feats_list.transformed_data[:10, :], cfs, hard=True).shape == (10, 29)\n",
"assert feats_list.apply_constraints(feats_list.transformed_data[:10, :], x, hard=True).shape == (10, 29)\n",
"\n",
"reg_loss = feats_list.compute_reg_loss(feats_list.transformed_data, cfs)\n",
"reg_loss = feats_list.compute_reg_loss(feats_list.transformed_data, x)\n",
"assert jnp.ndim(reg_loss) == 0\n",
"assert np.all(reg_loss > 0)\n",
"assert np.allclose(feats_list.compute_reg_loss(xs, constraint_cfs), 0)"
Expand Down
35 changes: 27 additions & 8 deletions relax/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,14 @@
'relax/data_utils.py'),
'relax.data_utils.FeaturesList.with_transformed_data': ( 'data.utils.html#featureslist.with_transformed_data',
'relax/data_utils.py'),
'relax.data_utils.GumbelSoftmaxTransformation': ( 'data.utils.html#gumbelsoftmaxtransformation',
'relax/data_utils.py'),
'relax.data_utils.GumbelSoftmaxTransformation.__init__': ( 'data.utils.html#gumbelsoftmaxtransformation.__init__',
'relax/data_utils.py'),
'relax.data_utils.GumbelSoftmaxTransformation.apply_constraints': ( 'data.utils.html#gumbelsoftmaxtransformation.apply_constraints',
'relax/data_utils.py'),
'relax.data_utils.GumbelSoftmaxTransformation.soft_constraints': ( 'data.utils.html#gumbelsoftmaxtransformation.soft_constraints',
'relax/data_utils.py'),
'relax.data_utils.IdentityTransformation': ( 'data.utils.html#identitytransformation',
'relax/data_utils.py'),
'relax.data_utils.IdentityTransformation.__init__': ( 'data.utils.html#identitytransformation.__init__',
Expand Down Expand Up @@ -221,14 +229,6 @@
'relax.data_utils.OneHotEncoder.transform': ( 'data.utils.html#onehotencoder.transform',
'relax/data_utils.py'),
'relax.data_utils.OneHotTransformation': ('data.utils.html#onehottransformation', 'relax/data_utils.py'),
'relax.data_utils.OneHotTransformation.__init__': ( 'data.utils.html#onehottransformation.__init__',
'relax/data_utils.py'),
'relax.data_utils.OneHotTransformation.apply_constraints': ( 'data.utils.html#onehottransformation.apply_constraints',
'relax/data_utils.py'),
'relax.data_utils.OneHotTransformation.compute_reg_loss': ( 'data.utils.html#onehottransformation.compute_reg_loss',
'relax/data_utils.py'),
'relax.data_utils.OneHotTransformation.num_categories': ( 'data.utils.html#onehottransformation.num_categories',
'relax/data_utils.py'),
'relax.data_utils.OrdinalPreprocessor': ('data.utils.html#ordinalpreprocessor', 'relax/data_utils.py'),
'relax.data_utils.OrdinalPreprocessor.fit': ( 'data.utils.html#ordinalpreprocessor.fit',
'relax/data_utils.py'),
Expand All @@ -242,6 +242,10 @@
'relax/data_utils.py'),
'relax.data_utils.OrdinalTransformation.num_categories': ( 'data.utils.html#ordinaltransformation.num_categories',
'relax/data_utils.py'),
'relax.data_utils.SoftmaxTransformation': ( 'data.utils.html#softmaxtransformation',
'relax/data_utils.py'),
'relax.data_utils.SoftmaxTransformation.soft_constraints': ( 'data.utils.html#softmaxtransformation.soft_constraints',
'relax/data_utils.py'),
'relax.data_utils.Transformation': ('data.utils.html#transformation', 'relax/data_utils.py'),
'relax.data_utils.Transformation.__init__': ( 'data.utils.html#transformation.__init__',
'relax/data_utils.py'),
Expand All @@ -262,6 +266,20 @@
'relax/data_utils.py'),
'relax.data_utils.Transformation.transform': ( 'data.utils.html#transformation.transform',
'relax/data_utils.py'),
'relax.data_utils._OneHotTransformation': ( 'data.utils.html#_onehottransformation',
'relax/data_utils.py'),
'relax.data_utils._OneHotTransformation.__init__': ( 'data.utils.html#_onehottransformation.__init__',
'relax/data_utils.py'),
'relax.data_utils._OneHotTransformation.apply_constraints': ( 'data.utils.html#_onehottransformation.apply_constraints',
'relax/data_utils.py'),
'relax.data_utils._OneHotTransformation.compute_reg_loss': ( 'data.utils.html#_onehottransformation.compute_reg_loss',
'relax/data_utils.py'),
'relax.data_utils._OneHotTransformation.hard_constraints': ( 'data.utils.html#_onehottransformation.hard_constraints',
'relax/data_utils.py'),
'relax.data_utils._OneHotTransformation.num_categories': ( 'data.utils.html#_onehottransformation.num_categories',
'relax/data_utils.py'),
'relax.data_utils._OneHotTransformation.soft_constraints': ( 'data.utils.html#_onehottransformation.soft_constraints',
'relax/data_utils.py'),
'relax.data_utils._check_xs': ('data.utils.html#_check_xs', 'relax/data_utils.py'),
'relax.data_utils._unique': ('data.utils.html#_unique', 'relax/data_utils.py')},
'relax.docs': { 'relax.docs.CustomizedMarkdownRenderer': ('docs.html#customizedmarkdownrenderer', 'relax/docs.py'),
Expand Down Expand Up @@ -797,6 +815,7 @@
'relax.utils.auto_reshaping': ('utils.html#auto_reshaping', 'relax/utils.py'),
'relax.utils.get_config': ('utils.html#get_config', 'relax/utils.py'),
'relax.utils.grad_update': ('utils.html#grad_update', 'relax/utils.py'),
'relax.utils.gumbel_softmax': ('utils.html#gumbel_softmax', 'relax/utils.py'),
'relax.utils.load_json': ('utils.html#load_json', 'relax/utils.py'),
'relax.utils.load_pytree': ('utils.html#load_pytree', 'relax/utils.py'),
'relax.utils.save_pytree': ('utils.html#save_pytree', 'relax/utils.py'),
Expand Down
Loading
Loading