Skip to content

Commit

Permalink
Merge pull request #27 from BirkhoffG/fix-gs
Browse files Browse the repository at this point in the history
Fix performance issue in GrowingSphere
  • Loading branch information
BirkhoffG authored Nov 20, 2023
2 parents 02a72a7 + 9458dd2 commit d4b70e4
Show file tree
Hide file tree
Showing 5 changed files with 151 additions and 99 deletions.
5 changes: 5 additions & 0 deletions benchmarks/run_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def main(args):

# load data and data configs
dm = relax.load_data(data_name)

keras.mixed_precision.set_global_policy("mixed_float16")

# load predict function
ml_model = relax.load_ml_module(data_name)
Expand Down Expand Up @@ -130,5 +132,8 @@ def main(args):

if args.disable_jit:
jax.config.update("jax_disable_jit", True)

jax.profiler.start_trace("/tmp/tensorboard")
main(args)
jax.profiler.stop_trace()

3 changes: 0 additions & 3 deletions nbs/03_explain.strategy.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@
" **kwargs\n",
" ) -> Array: # Generated counterfactual explanations\n",
" \n",
" @jit\n",
" def partial_fn(x, y_target, rng_key):\n",
" return fn(x, pred_fn=pred_fn, y_target=y_target, rng_key=rng_key, **kwargs)\n",
" \n",
Expand Down Expand Up @@ -190,7 +189,6 @@
" **kwargs\n",
" ) -> Array: # Generated counterfactual explanations\n",
" \n",
" @jit\n",
" def partial_fn(x, y_target, rng_key, **kwargs):\n",
" return fn(x, pred_fn=pred_fn, y_target=y_target, rng_key=rng_key, **kwargs)\n",
"\n",
Expand Down Expand Up @@ -234,7 +232,6 @@
") -> Array: # Generated counterfactual explanations\n",
" \"\"\"Batched of counterfactuals.\"\"\"\n",
"\n",
" @jit\n",
" def gs_fn_partial(state):\n",
" x, y_target, rng_key = state\n",
" return gs_fn(cf_fn, x, pred_fn, y_target, rng_key, **kwargs)\n",
Expand Down
138 changes: 81 additions & 57 deletions nbs/methods/05_sphere.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,16 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"#| include: false\n",
"%load_ext autoreload\n",
Expand All @@ -37,27 +46,12 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using JAX backend.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
]
}
],
"outputs": [],
"source": [
"#| export\n",
"from __future__ import annotations\n",
"from relax.import_essentials import *\n",
"from relax.methods.base import CFModule, BaseConfig\n",
"from relax.methods.base import CFModule, BaseConfig, default_apply_constraints_fn\n",
"from relax.utils import auto_reshaping, grad_update, validate_configs\n",
"from relax.data_utils import Feature, FeaturesList\n",
"from relax.data_module import DataModule"
Expand Down Expand Up @@ -110,7 +104,6 @@
"outputs": [],
"source": [
"#| export\n",
"@partial(jit, static_argnums=(1, 2))\n",
"def sample_categorical(rng_key: jrand.PRNGKey, col_size: int, n_samples: int):\n",
" rng_key, _ = jrand.split(rng_key)\n",
" prob = jnp.ones(col_size) / col_size\n",
Expand Down Expand Up @@ -180,32 +173,42 @@
"outputs": [],
"source": [
"#| export\n",
"@partial(jit, static_argnums=(2, 3, 4, 5, 6, 7))\n",
"@partial(jit, static_argnums=(2, 5, 8, 9))\n",
"def perturb_function_with_features(\n",
" rng_key: jrand.PRNGKey,\n",
" x: np.ndarray, # Shape: (1, k)\n",
" n_samples: int,\n",
" high, \n",
" low,\n",
" p_norm,\n",
" feats_info: List[Tuple[int, int, int]], # [(start, end, num_categories)]\n",
" high: float, \n",
" low: float,\n",
" p_norm: int,\n",
" cont_masks: Array,\n",
" immut_masks: Array,\n",
" num_categories: list[int],\n",
" cat_perturb_fn: Callable\n",
"):\n",
" def perturb_feature(rng_key, x_sliced, num_categories):\n",
" if num_categories > 0:\n",
" return cat_perturb_fn(rng_key, num_categories, n_samples)\n",
" else: \n",
" return hyper_sphere_coordindates(\n",
" rng_key, x_sliced, n_samples, high, low, p_norm\n",
" )\n",
" \n",
" rng_keys = jrand.split(rng_key, len(feats_info))\n",
" perturbed = jnp.repeat(x, n_samples, axis=0)\n",
" for rng_key, (start, end, num_categories) in zip(rng_keys, feats_info):\n",
" x_sliced = lax.dynamic_slice(x, (0, start), (1, end - start))\n",
" _perturbed_feat = perturb_feature(rng_key, x_sliced, num_categories)\n",
" perturbed = perturbed.at[:, start: end].set(_perturbed_feat)\n",
" return perturbed\n"
" def perturb_cat_feat(rng_key, num_categories):\n",
" rng_key, next_key = jrand.split(rng_key)\n",
" sampled = cat_perturb_fn(rng_key, num_categories, n_samples)\n",
" return next_key, sampled\n",
" \n",
" # cont_masks, immut_masks, num_categories = feats_info\n",
" key_1, key_2 = jrand.split(rng_key)\n",
" perturbed_cont = cont_masks * hyper_sphere_coordindates(\n",
" key_1, x, n_samples, high, low, p_norm\n",
" )\n",
" cat_masks = jnp.where(cont_masks, 0, 1)\n",
" perturbed_cat = cat_masks * jnp.concatenate([\n",
" perturb_cat_feat(key_2, num_cat)[1] for num_cat in num_categories\n",
" ], axis=1)\n",
"\n",
" perturbed = jnp.where(\n",
" immut_masks,\n",
" jnp.repeat(x, n_samples, axis=0),\n",
" perturbed_cont + perturbed_cat\n",
" )\n",
" \n",
" return perturbed"
]
},
{
Expand All @@ -215,17 +218,31 @@
"outputs": [],
"source": [
"#| exporti\n",
"def features_to_infos_and_perturb_fn(features: FeaturesList):\n",
" feats_info = []\n",
"def features_to_infos_and_perturb_fn(\n",
" features: FeaturesList\n",
") -> Tuple[List[Array,Array,Array,Array,Array], Callable]:\n",
" cont_masks = []\n",
" immut_masks = []\n",
" n_categories = []\n",
" cat_transformation_name = None\n",
" for (start, end), feat in zip(features.feature_indices, features):\n",
" if feat.is_categorical:\n",
" feat_info = (start, end, feat.transformation.num_categories)\n",
" cont_mask = jnp.zeros(feat.transformation.num_categories)\n",
" immut_mask = cont_mask * np.array([feat.is_immutable], dtype=np.int32)\n",
" n_categorie = feat.transformation.num_categories\n",
" cat_transformation_name = feat.transformation.name\n",
" else:\n",
" feat_info = (start, end, -1)\n",
" feats_info.append(feat_info)\n",
" return tuple(feats_info), cat_perturb_fn(cat_transformation_name)\n",
" cont_mask = jnp.ones(1)\n",
" immut_mask = cont_mask * np.array([feat.is_immutable], dtype=np.int32)\n",
" n_categorie = 1\n",
" \n",
" cont_masks, immut_masks, n_categories = map(lambda x, y: x + [y], \n",
" [cont_masks, immut_masks, n_categories],\n",
" [cont_mask, immut_mask, n_categorie]\n",
" )\n",
" \n",
" cont_masks, immut_masks = map(lambda x: jnp.concatenate(x, axis=0), [cont_masks, immut_masks])\n",
" return (cont_masks, immut_masks, tuple(n_categories)), cat_perturb_fn(cat_transformation_name)\n",
"\n",
"def cat_perturb_fn(transformation):\n",
" def ohe_perturb_fn(rng_key, num_categories, n_samples):\n",
Expand Down Expand Up @@ -255,7 +272,7 @@
"feats_info, perturb_fn = features_to_infos_and_perturb_fn(dm.features)\n",
"assert x_sliced.ndim == 2\n",
"cfs = perturb_function_with_features(\n",
" jrand.PRNGKey(0), x_sliced, 1000, 1, 0, 2, feats_info, perturb_fn\n",
" jrand.PRNGKey(0), x_sliced, 1000, 1, 0, 2, *feats_info, perturb_fn\n",
")\n",
"assert cfs.shape == (1000, 29)\n",
"assert cfs[:, 2:].sum() == 1000 * 6\n",
Expand All @@ -282,7 +299,8 @@
" step_size: float, # Step size\n",
" p_norm: int, # Norm\n",
" perturb_fn: Callable, # Perturbation function\n",
" apply_constraints_fn: Callable # Apply immutable constraints\n",
" apply_constraints_fn: Callable, # Apply immutable constraints\n",
" dtype: jnp.dtype = jnp.float32, # Data type\n",
"): \n",
" @jit\n",
" def dist_fn(x, cf):\n",
Expand All @@ -296,7 +314,7 @@
" @loop_tqdm(n_steps)\n",
" def step(i, state):\n",
" candidate_cf, count, rng_key = state\n",
" rng_key, subkey_1, subkey_2 = jrand.split(rng_key, num=3)\n",
" rng_key, subkey = jrand.split(rng_key)\n",
" low, high = step_size * count, step_size * (count + 1)\n",
" # Sample around x\n",
" candidates = perturb_fn(rng_key, x, n_samples, high=high, low=low, p_norm=p_norm)\n",
Expand All @@ -309,12 +327,14 @@
" dist = dist_fn(x, candidates)\n",
"\n",
" # Calculate counterfactual labels\n",
" candidate_preds = pred_fn(candidates).argmax(axis=1)\n",
" indices = jnp.where(candidate_preds == y_target, 1, 0).astype(bool)\n",
" candidate_preds = pred_fn(candidates).argmax(axis=1, keepdims=True)\n",
" indices = candidate_preds == y_target\n",
"\n",
" candidates = jnp.where(indices.reshape(-1, 1), \n",
" candidates, jnp.ones_like(candidates) * jnp.inf)\n",
" dist = jnp.where(indices.reshape(-1, 1), dist, jnp.ones_like(dist) * jnp.inf)\n",
" # Select valid candidates and their distances\n",
" candidates, dist = jax.tree_util.tree_map(\n",
" lambda x: jnp.where(indices, x, jnp.ones_like(x) * jnp.inf), \n",
" (candidates, dist)\n",
" )\n",
"\n",
" closest_idx = dist.argmin()\n",
" candidate_cf_update = candidates[closest_idx].reshape(1, -1)\n",
Expand All @@ -324,7 +344,7 @@
" candidate_cf_update, \n",
" candidate_cf\n",
" )\n",
" return candidate_cf, count + 1, rng_key\n",
" return candidate_cf, count + 1, subkey\n",
" \n",
" y_target = y_target.reshape(1, -1).argmax(axis=1)\n",
" candidate_cf = jnp.ones_like(x) * jnp.inf\n",
Expand Down Expand Up @@ -387,11 +407,15 @@
" if self.perturb_fn is None:\n",
" if self.has_data_module():\n",
" feats_info, perturb_fn = features_to_infos_and_perturb_fn(self.data_module.features)\n",
" cont_masks, immut_masks, num_categories = feats_info\n",
" self.perturb_fn = ft.partial(\n",
" perturb_function_with_features, \n",
" feats_info=feats_info,\n",
" cont_masks=cont_masks,\n",
" immut_masks=immut_masks,\n",
" num_categories=num_categories,\n",
" cat_perturb_fn=perturb_fn\n",
" )\n",
" self.apply_constraints = default_apply_constraints_fn\n",
" else:\n",
" self.perturb_fn = default_perturb_function\n",
" \n",
Expand Down Expand Up @@ -447,7 +471,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "54ebe22c890f47e5b50654e5a9461831",
"model_id": "cc5fefb35c7f48308e4d2853687c0178",
"version_major": 2,
"version_minor": 0
},
Expand Down Expand Up @@ -478,7 +502,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "900d02459f67488293174606b72b25a9",
"model_id": "6bc22bdc99d44ef9940b7b3df1887b37",
"version_major": 2,
"version_minor": 0
},
Expand Down Expand Up @@ -514,7 +538,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f3c3eac1cfaa458c9feedba3d0f98783",
"model_id": "e6cff5dd3fe34fb0ab66def5f5368e85",
"version_major": 2,
"version_minor": 0
},
Expand Down
Loading

0 comments on commit d4b70e4

Please sign in to comment.