Skip to content
This repository was archived by the owner on Sep 11, 2023. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 72 additions & 37 deletions nbs/01_data.module.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@
"from sklearn.preprocessing import StandardScaler, MinMaxScaler, OneHotEncoder\n",
"from sklearn.base import TransformerMixin\n",
"from urllib.request import urlretrieve\n",
"from relax.data.loader import Dataset, ArrayDataset, DataLoader, DataloaderBackends"
"from relax.data.loader import Dataset, ArrayDataset, DataLoader, DataloaderBackends\n",
"from copy import deepcopy"
]
},
{
Expand Down Expand Up @@ -213,13 +214,23 @@
"def _process_data(\n",
" df: pd.DataFrame | None, configs: TabularDataModuleConfigs\n",
") -> pd.DataFrame:\n",
" \"\"\"\n",
" This function does the following:\n",
" * Check and load data.\n",
" * Select first `sample_frac` of the data.\n",
" * Check and load only specified columns.\n",
" \"\"\"\n",
" if df is None:\n",
" df = pd.read_csv(configs.data_dir)\n",
" elif isinstance(df, pd.DataFrame):\n",
" df = df\n",
" else:\n",
" raise ValueError(f\"{type(df).__name__} is not supported as an input type for `TabularDataModule`.\")\n",
"\n",
" if configs.sample_frac is not None:\n",
" sample_size = int(len(df) * configs.sample_frac)\n",
" df = df.iloc[:sample_size]\n",
" \n",
" df = _check_cols(df, configs)\n",
" return df"
]
Expand Down Expand Up @@ -401,7 +412,7 @@
" ge=0., le=1.0\n",
" )\n",
" backend: str = Field(\n",
" \"jax\", description=f\"`Dataloader` backend. Currently supports: {_supported_backends()}\"\n",
" \"jax\", description=f\"`Dataloader` backend. Currently supports: {DataloaderBackends.supported()}\"\n",
" )\n"
]
},
Expand All @@ -415,7 +426,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/birkhoffg/relax/tree/master/blob/master/relax/data/module.py#L187){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/birkhoffg/relax/tree/master/blob/master/relax/data/module.py#L196){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### TabularDataModuleConfigs\n",
"\n",
Expand Down Expand Up @@ -445,7 +456,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/birkhoffg/relax/tree/master/blob/master/relax/data/module.py#L187){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/birkhoffg/relax/tree/master/blob/master/relax/data/module.py#L196){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### TabularDataModuleConfigs\n",
"\n",
Expand Down Expand Up @@ -504,7 +515,7 @@
" imutable_cols=[\"age\", \"workclass\", \"marital_status\"],\n",
" normalizer=None, \n",
" encoder=None, \n",
" sample_frac=0.1,\n",
" sample_frac=None,\n",
" backend='jax'\n",
")"
]
Expand Down Expand Up @@ -555,9 +566,9 @@
" train_X, test_X, train_y, test_y = map(\n",
" lambda x: x.astype(float), train_test_tuple\n",
" )\n",
" if self._configs.sample_frac:\n",
" train_size = int(len(train_X) * self._configs.sample_frac)\n",
" train_X, train_y = train_X[:train_size], train_y[:train_size]\n",
" # if self._configs.sample_frac:\n",
" # train_size = int(len(train_X) * self._configs.sample_frac)\n",
" # train_X, train_y = train_X[:train_size], train_y[:train_size]\n",
" \n",
" self._train_dataset = ArrayDataset(train_X, train_y)\n",
" self._val_dataset = ArrayDataset(test_X, test_y)\n",
Expand Down Expand Up @@ -719,6 +730,31 @@
"assert len(dm.data) == 1000 # dm contains `df`"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"If `TabularDataModuleConfigs.sample_frac` is set to a float value,\n",
"internally, the `TabularDataModule` will load the first `sample_frac` \n",
"of data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"configs01 = deepcopy(configs)\n",
"configs01.sample_frac = 0.1\n",
"dm = TabularDataModule(configs01, data=df)\n",
"assert len(dm.data) == 100 # dm contains 10% of `df`\n",
"assert dm.data[configs.discret_cols].equals(\n",
" df[:100][configs.discret_cols]\n",
") # dm contains the same 10% of `df` data"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -729,7 +765,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/birkhoffg/cfnet/tree/master/blob/master/relax/data/module.py#L268){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/birkhoffg/relax/tree/master/blob/master/relax/data/module.py#L277){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### TabularDataModule.data\n",
"\n",
Expand All @@ -740,7 +776,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/birkhoffg/cfnet/tree/master/blob/master/relax/data/module.py#L268){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/birkhoffg/relax/tree/master/blob/master/relax/data/module.py#L277){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### TabularDataModule.data\n",
"\n",
Expand Down Expand Up @@ -901,7 +937,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/birkhoffg/cfnet/tree/master/blob/master/relax/data/module.py#L300){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/birkhoffg/relax/tree/master/blob/master/relax/data/module.py#L317){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### TabularDataModule.transform\n",
"\n",
Expand All @@ -917,7 +953,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/birkhoffg/cfnet/tree/master/blob/master/relax/data/module.py#L300){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/birkhoffg/relax/tree/master/blob/master/relax/data/module.py#L317){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### TabularDataModule.transform\n",
"\n",
Expand Down Expand Up @@ -978,7 +1014,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/birkhoffg/cfnet/tree/master/blob/master/relax/data/module.py#L317){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/birkhoffg/relax/tree/master/blob/master/relax/data/module.py#L334){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### TabularDataModule.inverse_transform\n",
"\n",
Expand All @@ -998,7 +1034,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/birkhoffg/cfnet/tree/master/blob/master/relax/data/module.py#L317){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/birkhoffg/relax/tree/master/blob/master/relax/data/module.py#L334){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### TabularDataModule.inverse_transform\n",
"\n",
Expand Down Expand Up @@ -1276,7 +1312,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/birkhoffg/cfnet/tree/master/blob/master/relax/data/module.py#L338){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/birkhoffg/relax/tree/master/blob/master/relax/data/module.py#L355){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### TabularDataModule.apply_constraints\n",
"\n",
Expand All @@ -1297,7 +1333,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/birkhoffg/cfnet/tree/master/blob/master/relax/data/module.py#L338){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/birkhoffg/relax/tree/master/blob/master/relax/data/module.py#L355){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### TabularDataModule.apply_constraints\n",
"\n",
Expand Down Expand Up @@ -1340,15 +1376,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
]
}
],
"outputs": [],
"source": [
"x, y = next(iter(dm.test_dataloader(batch_size=128)))\n",
"# unnormalized counterfactuals\n",
Expand All @@ -1369,7 +1397,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/birkhoffg/cfnet/tree/master/blob/master/relax/data/module.py#L357){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/birkhoffg/relax/tree/master/blob/master/relax/data/module.py#L374){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### TabularDataModule.apply_regularization\n",
"\n",
Expand All @@ -1389,7 +1417,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/birkhoffg/cfnet/tree/master/blob/master/relax/data/module.py#L357){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/birkhoffg/relax/tree/master/blob/master/relax/data/module.py#L374){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### TabularDataModule.apply_regularization\n",
"\n",
Expand Down Expand Up @@ -1479,6 +1507,10 @@
" x, y = next(iter(dl))\n",
" assert x.shape[0] == batch_size\n",
"\n",
" assert len(dm.data) == len(dm.train_dataset) + len(dm.val_dataset)\n",
" sample_frac = 1.0 if data_configs.sample_frac is None else data_configs.sample_frac\n",
" assert len(dm.data) == int(len(pd.read_csv(data_configs.data_dir)) * sample_frac)\n",
"\n",
" ############################################################\n",
" # test `transform` and `inverse_transform`\n",
" ############################################################\n",
Expand All @@ -1505,16 +1537,6 @@
" assert jnp.count_nonzero(cf == 1) == len(cf) * n_cat_feat\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"from copy import deepcopy"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -1535,6 +1557,19 @@
"dm = TabularDataModule(data_configs)\n",
"check_datamodule(dm, data_configs)\n",
"\n",
"# sample_frac=None\n",
"_data_configs = deepcopy(data_configs)\n",
"_data_configs[\"sample_frac\"] = None\n",
"dm = TabularDataModule(data_configs)\n",
"check_datamodule(dm, data_configs)\n",
"\n",
"# sample_frac=0\n",
"_data_configs = deepcopy(data_configs)\n",
"_data_configs[\"sample_frac\"] = 0.\n",
"dm = TabularDataModule(data_configs)\n",
"check_datamodule(dm, data_configs)\n",
"\n",
"\n",
"# immutable\n",
"_data_configs = deepcopy(data_configs)\n",
"_data_configs[\"imutable_cols\"] = [\"race\",\"gender\"]\n",
Expand Down
27 changes: 19 additions & 8 deletions relax/data/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from sklearn.base import TransformerMixin
from urllib.request import urlretrieve
from .loader import Dataset, ArrayDataset, DataLoader, DataloaderBackends
from copy import deepcopy

# %% auto 0
__all__ = ['BaseDataModule', 'find_imutable_idx_list', 'TabularDataModuleConfigs', 'TabularDataModule', 'sample', 'load_data']
Expand Down Expand Up @@ -127,13 +128,23 @@ def _check_cols(data: pd.DataFrame, configs: TabularDataModuleConfigs) -> pd.Dat
def _process_data(
    df: pd.DataFrame | None, configs: TabularDataModuleConfigs
) -> pd.DataFrame:
    """Resolve the raw data frame that backs a `TabularDataModule`.

    Steps, in order:
      1. Use `df` when it is a `DataFrame`; when it is `None`, read the CSV
         located at `configs.data_dir` instead.
      2. When `configs.sample_frac` is not `None`, keep only the leading
         `int(len(data) * sample_frac)` rows (no shuffling is performed).
      3. Hand the result to `_check_cols` and return whatever it yields.

    Raises:
        ValueError: if `df` is neither `None` nor a `pd.DataFrame`.
    """
    if df is None:
        frame = pd.read_csv(configs.data_dir)
    elif not isinstance(df, pd.DataFrame):
        raise ValueError(f"{type(df).__name__} is not supported as an input type for `TabularDataModule`.")
    else:
        frame = df

    frac = configs.sample_frac
    if frac is not None:
        # Truncating with `int` means the kept row count is rounded down.
        n_keep = int(len(frame) * frac)
        frame = frame.iloc[:n_keep]

    return _check_cols(frame, configs)

Expand Down Expand Up @@ -216,7 +227,7 @@ class TabularDataModuleConfigs(BaseParser):
ge=0., le=1.0
)
backend: str = Field(
"jax", description=f"`Dataloader` backend. Currently supports: {_supported_backends()}"
"jax", description=f"`Dataloader` backend. Currently supports: {DataloaderBackends.supported()}"
)


Expand Down Expand Up @@ -260,9 +271,9 @@ def prepare_data(self):
train_X, test_X, train_y, test_y = map(
lambda x: x.astype(float), train_test_tuple
)
if self._configs.sample_frac:
train_size = int(len(train_X) * self._configs.sample_frac)
train_X, train_y = train_X[:train_size], train_y[:train_size]
# if self._configs.sample_frac:
# train_size = int(len(train_X) * self._configs.sample_frac)
# train_X, train_y = train_X[:train_size], train_y[:train_size]

self._train_dataset = ArrayDataset(train_X, train_y)
self._val_dataset = ArrayDataset(test_X, test_y)
Expand Down Expand Up @@ -387,13 +398,13 @@ def apply_regularization(
return reg_loss


# %% ../../nbs/01_data.module.ipynb 41
# %% ../../nbs/01_data.module.ipynb 43
def sample(datamodule: BaseDataModule, frac: float = 1.0):
    """Return the leading `frac` portion of the training set as `(X, y)`.

    The full training set is fetched via `datamodule.train_dataset[:]`, and
    the first `int(len(X) * frac)` entries of both `X` and `y` are returned.
    """
    features, labels = datamodule.train_dataset[:]
    # One slice object so that features and labels are cut identically.
    head = slice(int(len(features) * frac))
    return features[head], labels[head]

# %% ../../nbs/01_data.module.ipynb 46
# %% ../../nbs/01_data.module.ipynb 47
DEFAULT_DATA_CONFIGS = {
'adult': {
'data' :'assets/data/s_adult.csv',
Expand All @@ -409,13 +420,13 @@ def sample(datamodule: BaseDataModule, frac: float = 1.0):
}
}

# %% ../../nbs/01_data.module.ipynb 47
# %% ../../nbs/01_data.module.ipynb 48
def _validate_dataname(data_name: str):
    """Raise `ValueError` unless `data_name` is a key of `DEFAULT_DATA_CONFIGS`."""
    if data_name in DEFAULT_DATA_CONFIGS:
        return
    known_names = DEFAULT_DATA_CONFIGS.keys()
    raise ValueError(f'`data_name` must be one of {known_names}, '
                     f'but got data_name={data_name}.')

# %% ../../nbs/01_data.module.ipynb 48
# %% ../../nbs/01_data.module.ipynb 49
def load_data(
data_name: str, # The name of data
return_config: bool = False, # Return `data_config `or not
Expand Down