From 03034e83e92f78df19f2561fb4d4cc57ce1d3aa7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 16 Dec 2024 22:14:33 +0000 Subject: [PATCH 1/2] [pre-commit.ci] pre-commit autoupdate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/pre-commit/pre-commit-hooks: v4.6.0 → v5.0.0](https://github.com/pre-commit/pre-commit-hooks/compare/v4.6.0...v5.0.0) - [github.com/astral-sh/ruff-pre-commit: v0.5.6 → v0.8.3](https://github.com/astral-sh/ruff-pre-commit/compare/v0.5.6...v0.8.3) --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ff00d88..a4f1d12 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: - id: check-added-large-files - id: check-builtin-literals @@ -12,7 +12,7 @@ repos: - id: end-of-file-fixer - id: mixed-line-ending - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.5.6 + rev: v0.8.3 hooks: - id: ruff args: From 858caf293ba06c0090590f7e2323670b37e326bf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 16 Dec 2024 22:14:40 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- notebooks/airbnb.ipynb | 105 ++++++++++++++----------- notebooks/airbnb_tf.ipynb | 98 ++++++++++++++---------- notebooks/bank.ipynb | 140 ++++++++++++++++++++++------------ notebooks/bike.ipynb | 29 ++++--- notebooks/calibration.ipynb | 124 +++++++++++++++++++++--------- notebooks/consistency.ipynb | 55 +++++++------ notebooks/credit.ipynb | 127 +++++++++++++++++------------- notebooks/losses.ipynb | 11 +-- notebooks/mnist.ipynb | 62 +++++++-------- sage/iterated_estimator.py | 4 +- sage/kernel_estimator.py | 4 +- sage/permutation_estimator.py | 2 +- sage/utils.py | 4 +- 13 files changed, 458 insertions(+), 307 deletions(-) diff --git a/notebooks/airbnb.ipynb b/notebooks/airbnb.ipynb index a62741b..2378cd9 100644 --- a/notebooks/airbnb.ipynb +++ b/notebooks/airbnb.ipynb @@ -16,11 +16,13 @@ "outputs": [], "source": [ "import re\n", - "import sage\n", + "\n", + "import gender_guesser.detector as detector\n", "import numpy as np\n", "import pandas as pd\n", - "import gender_guesser.detector as detector\n", - "from sklearn.model_selection import train_test_split" + "from sklearn.model_selection import train_test_split\n", + "\n", + "import sage" ] }, { @@ -217,7 +219,7 @@ "outputs": [], "source": [ "# Categorical features\n", - "categorical_columns = ['neighbourhood_group', 'neighbourhood', 'room_type']\n", + "categorical_columns = [\"neighbourhood_group\", \"neighbourhood\", \"room_type\"]\n", "for column in categorical_columns:\n", " df[column] = pd.Categorical(df[column]).codes" ] @@ -229,7 +231,7 @@ "outputs": [], "source": [ "# Exclude outliers (top 0.5%)\n", - "df = df[df['price'] < df['price'].quantile(0.995)]" + "df = df[df[\"price\"] < df[\"price\"].quantile(0.995)]" ] }, { @@ -239,9 +241,9 @@ "outputs": [], "source": [ "# Features derived from name\n", - "df['name_length'] = df['name'].apply(lambda x: len(x))\n", - "df['name_isupper'] = df['name'].apply(lambda x: int(x.isupper()))\n", - "df['name_words'] = df['name'].apply(lambda x: len(re.findall(r'\\w+', x)))" + "df[\"name_length\"] = df[\"name\"].apply(lambda x: len(x))\n", + "df[\"name_isupper\"] = df[\"name\"].apply(lambda x: int(x.isupper()))\n", + "df[\"name_words\"] = df[\"name\"].apply(lambda x: len(re.findall(r\"\\w+\", x)))" ] }, { @@ -252,8 +254,8 @@ "source": [ "# Host gender guess\n", "guesser = detector.Detector()\n", - "df['host_gender'] = df['host_name'].apply(lambda x: guesser.get_gender(x.split(' ')[0]))\n", - "df['host_gender'] = pd.Categorical(df['host_gender']).codes" + "df[\"host_gender\"] = df[\"host_name\"].apply(lambda x: guesser.get_gender(x.split(\" \")[0]))\n", + "df[\"host_gender\"] = pd.Categorical(df[\"host_gender\"]).codes" ] }, { @@ -263,10 +265,12 @@ "outputs": [], "source": [ "# Number of days since last review\n", - "most_recent = df['last_review'].max()\n", - "df['last_review'] = (most_recent - df['last_review']).dt.days\n", - "df['last_review'] = (df['last_review'] - df['last_review'].mean()) / df['last_review'].std()\n", - "df['last_review'] = df['last_review'].fillna(-5)" + "most_recent = df[\"last_review\"].max()\n", + "df[\"last_review\"] = (most_recent - df[\"last_review\"]).dt.days\n", + "df[\"last_review\"] = (df[\"last_review\"] - df[\"last_review\"].mean()) / df[\n", + " \"last_review\"\n", + "].std()\n", + "df[\"last_review\"] = df[\"last_review\"].fillna(-5)" ] }, { @@ -276,7 +280,7 @@ "outputs": [], "source": [ "# Missing values\n", - "df['reviews_per_month'] = df['reviews_per_month'].fillna(0)" + "df[\"reviews_per_month\"] = df[\"reviews_per_month\"].fillna(0)" ] }, { @@ -286,9 +290,15 @@ "outputs": [], "source": [ "# Normalize other numerical features\n", - "df['number_of_reviews'] = (df['number_of_reviews'] - df['number_of_reviews'].mean()) / df['number_of_reviews'].std()\n", - "df['availability_365'] = (df['availability_365'] - df['availability_365'].mean()) / df['availability_365'].std()\n", - "df['name_length'] = (df['name_length'] - df['name_length'].mean()) / df['name_length'].std()" + "df[\"number_of_reviews\"] = (\n", + " df[\"number_of_reviews\"] - df[\"number_of_reviews\"].mean()\n", + ") / df[\"number_of_reviews\"].std()\n", + "df[\"availability_365\"] = (df[\"availability_365\"] - df[\"availability_365\"].mean()) / df[\n", + " \"availability_365\"\n", + "].std()\n", + "df[\"name_length\"] = (df[\"name_length\"] - df[\"name_length\"].mean()) / df[\n", + " \"name_length\"\n", + "].std()" ] }, { @@ -298,8 +308,8 @@ "outputs": [], "source": [ "# Normalize latitude and longitude\n", - "df['latitude'] = (df['latitude'] - df['latitude'].mean()) / df['latitude'].std()\n", - "df['longitude'] = (df['longitude'] - df['longitude'].mean()) / df['longitude'].std()" + "df[\"latitude\"] = (df[\"latitude\"] - df[\"latitude\"].mean()) / df[\"latitude\"].std()\n", + "df[\"longitude\"] = (df[\"longitude\"] - df[\"longitude\"].mean()) / df[\"longitude\"].std()" ] }, { @@ -309,7 +319,7 @@ "outputs": [], "source": [ "# Drop columns\n", - "df = df.drop(['id', 'host_id', 'host_name', 'name'], axis=1)" + "df = df.drop([\"id\", \"host_id\", \"host_name\", \"name\"], axis=1)" ] }, { @@ -503,7 +513,7 @@ "outputs": [], "source": [ "# Rearrange columns\n", - "target_col = 'price'\n", + "target_col = \"price\"\n", "cols = df.columns.tolist()\n", "del cols[cols.index(target_col)]\n", "cols.append(target_col)\n", @@ -512,9 +522,11 @@ "\n", "# Split data\n", "train, test = train_test_split(\n", - " df.values, test_size=int(0.1 * len(df.values)), random_state=0)\n", + " df.values, test_size=int(0.1 * len(df.values)), random_state=0\n", + ")\n", "train, val = train_test_split(\n", - " train, test_size=int(0.1 * len(df.values)), random_state=0)\n", + " train, test_size=int(0.1 * len(df.values)), random_state=0\n", + ")\n", "Y_train = train[:, -1:].copy()\n", "Y_val = val[:, -1:].copy()\n", "Y_test = test[:, -1:].copy()\n", @@ -536,11 +548,12 @@ "metadata": {}, "outputs": [], "source": [ + "from copy import deepcopy\n", + "\n", "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", - "from copy import deepcopy\n", - "from torch.utils.data import TensorDataset, DataLoader" + "from torch.utils.data import DataLoader, TensorDataset" ] }, { @@ -552,13 +565,14 @@ "outputs": [], "source": [ "# Create model\n", - "device = torch.device('cuda')\n", + "device = torch.device(\"cuda\")\n", "model = nn.Sequential(\n", " nn.Linear(len(feature_names), 512),\n", " nn.ELU(),\n", " nn.Linear(512, 512),\n", " nn.ELU(),\n", - " nn.Linear(512, 1)).to(device)\n", + " nn.Linear(512, 1),\n", + ").to(device)\n", "\n", "# Training parameters\n", "lr = 1e-3\n", @@ -570,8 +584,8 @@ "\n", "# Data loaders\n", "train_set = TensorDataset(\n", - " torch.tensor(train, dtype=torch.float32),\n", - " torch.tensor(Y_train, dtype=torch.float32))\n", + " torch.tensor(train, dtype=torch.float32), torch.tensor(Y_train, dtype=torch.float32)\n", + ")\n", "train_loader = DataLoader(train_set, batch_size=mbsize, shuffle=True)\n", "val_x = torch.tensor(val, dtype=torch.float32, device=device)\n", "val_y = torch.tensor(Y_val, dtype=torch.float32, device=device)\n", @@ -601,8 +615,8 @@ " # Calculate validation loss.\n", " val_loss = loss_fn(model(val_x), val_y).item()\n", " if verbose:\n", - " print('{}Epoch = {}{}'.format('-' * 10, epoch + 1, '-' * 10))\n", - " print('Val loss = {:.4f}'.format(val_loss))\n", + " print(\"{}Epoch = {}{}\".format(\"-\" * 10, epoch + 1, \"-\" * 10))\n", + " print(\"Val loss = {:.4f}\".format(val_loss))\n", "\n", " # Check convergence criterion.\n", " if val_loss < min_criterion:\n", @@ -611,7 +625,7 @@ " best_model = deepcopy(model)\n", " elif (epoch - min_epoch) == lookback:\n", " if verbose:\n", - " print('Stopping early')\n", + " print(\"Stopping early\")\n", " break\n", "\n", "# Keep best model\n", @@ -638,8 +652,8 @@ "base_mse = nn.MSELoss()(mean.repeat(len(test_y), 1), test_y)\n", "mse = nn.MSELoss()(model(test_x), test_y)\n", "\n", - "print('Base rate MSE = {:.2f}'.format(base_mse))\n", - "print('Model MSE = {:.2f}'.format(mse))" + "print(\"Base rate MSE = {:.2f}\".format(base_mse))\n", + "print(\"Model MSE = {:.2f}\".format(mse))" ] }, { @@ -679,7 +693,7 @@ "source": [ "# Setup and calculate\n", "imputer = sage.MarginalImputer(model, test[:512])\n", - "estimator = sage.PermutationEstimator(imputer, 'mse')\n", + "estimator = sage.PermutationEstimator(imputer, \"mse\")\n", "sage_values = estimator(test, Y_test)" ] }, @@ -721,12 +735,17 @@ "source": [ "# Feature groups\n", "feature_groups = group_names = {\n", - " 'location (grouped)': ['latitude', 'longitude', 'neighbourhood', 'neighbourhood_group'],\n", - " 'name (grouped)': ['name_words', 'name_length', 'name_isupper'],\n", - " 'reviews (grouped)': ['last_review', 'reviews_per_month', 'number_of_reviews'],\n", - " 'host (grouped)': ['host_gender', 'calculated_host_listings_count'],\n", - " 'availability': ['availability_365'],\n", - " 'room_type': ['room_type']\n", + " \"location (grouped)\": [\n", + " \"latitude\",\n", + " \"longitude\",\n", + " \"neighbourhood\",\n", + " \"neighbourhood_group\",\n", + " ],\n", + " \"name (grouped)\": [\"name_words\", \"name_length\", \"name_isupper\"],\n", + " \"reviews (grouped)\": [\"last_review\", \"reviews_per_month\", \"number_of_reviews\"],\n", + " \"host (grouped)\": [\"host_gender\", \"calculated_host_listings_count\"],\n", + " \"availability\": [\"availability_365\"],\n", + " \"room_type\": [\"room_type\"],\n", "}\n", "group_names = [group for group in feature_groups]\n", "for col in feature_names:\n", @@ -772,7 +791,7 @@ "source": [ "# Setup and calculate\n", "imputer = sage.GroupedMarginalImputer(model, test[:512], groups)\n", - "estimator = sage.PermutationEstimator(imputer, 'mse')\n", + "estimator = sage.PermutationEstimator(imputer, \"mse\")\n", "sage_values = estimator(test, Y_test)" ] }, diff --git a/notebooks/airbnb_tf.ipynb b/notebooks/airbnb_tf.ipynb index 6d9b0ae..c5de646 100644 --- a/notebooks/airbnb_tf.ipynb +++ b/notebooks/airbnb_tf.ipynb @@ -16,11 +16,13 @@ "outputs": [], "source": [ "import re\n", - "import sage\n", + "\n", + "import gender_guesser.detector as detector\n", "import numpy as np\n", "import pandas as pd\n", - "import gender_guesser.detector as detector\n", - "from sklearn.model_selection import train_test_split" + "from sklearn.model_selection import train_test_split\n", + "\n", + "import sage" ] }, { @@ -217,7 +219,7 @@ "outputs": [], "source": [ "# Categorical features\n", - "categorical_columns = ['neighbourhood_group', 'neighbourhood', 'room_type']\n", + "categorical_columns = [\"neighbourhood_group\", \"neighbourhood\", \"room_type\"]\n", "for column in categorical_columns:\n", " df[column] = pd.Categorical(df[column]).codes" ] @@ -229,7 +231,7 @@ "outputs": [], "source": [ "# Exclude outliers (top 0.5%)\n", - "df = df[df['price'] < df['price'].quantile(0.995)]" + "df = df[df[\"price\"] < df[\"price\"].quantile(0.995)]" ] }, { @@ -239,9 +241,9 @@ "outputs": [], "source": [ "# Features derived from name\n", - "df['name_length'] = df['name'].apply(lambda x: len(x))\n", - "df['name_isupper'] = df['name'].apply(lambda x: int(x.isupper()))\n", - "df['name_words'] = df['name'].apply(lambda x: len(re.findall(r'\\w+', x)))" + "df[\"name_length\"] = df[\"name\"].apply(lambda x: len(x))\n", + "df[\"name_isupper\"] = df[\"name\"].apply(lambda x: int(x.isupper()))\n", + "df[\"name_words\"] = df[\"name\"].apply(lambda x: len(re.findall(r\"\\w+\", x)))" ] }, { @@ -252,8 +254,8 @@ "source": [ "# Host gender guess\n", "guesser = detector.Detector()\n", - "df['host_gender'] = df['host_name'].apply(lambda x: guesser.get_gender(x.split(' ')[0]))\n", - "df['host_gender'] = pd.Categorical(df['host_gender']).codes" + "df[\"host_gender\"] = df[\"host_name\"].apply(lambda x: guesser.get_gender(x.split(\" \")[0]))\n", + "df[\"host_gender\"] = pd.Categorical(df[\"host_gender\"]).codes" ] }, { @@ -263,10 +265,12 @@ "outputs": [], "source": [ "# Number of days since last review\n", - "most_recent = df['last_review'].max()\n", - "df['last_review'] = (most_recent - df['last_review']).dt.days\n", - "df['last_review'] = (df['last_review'] - df['last_review'].mean()) / df['last_review'].std()\n", - "df['last_review'] = df['last_review'].fillna(-5)" + "most_recent = df[\"last_review\"].max()\n", + "df[\"last_review\"] = (most_recent - df[\"last_review\"]).dt.days\n", + "df[\"last_review\"] = (df[\"last_review\"] - df[\"last_review\"].mean()) / df[\n", + " \"last_review\"\n", + "].std()\n", + "df[\"last_review\"] = df[\"last_review\"].fillna(-5)" ] }, { @@ -276,7 +280,7 @@ "outputs": [], "source": [ "# Missing values\n", - "df['reviews_per_month'] = df['reviews_per_month'].fillna(0)" + "df[\"reviews_per_month\"] = df[\"reviews_per_month\"].fillna(0)" ] }, { @@ -286,9 +290,15 @@ "outputs": [], "source": [ "# Normalize other numerical features\n", - "df['number_of_reviews'] = (df['number_of_reviews'] - df['number_of_reviews'].mean()) / df['number_of_reviews'].std()\n", - "df['availability_365'] = (df['availability_365'] - df['availability_365'].mean()) / df['availability_365'].std()\n", - "df['name_length'] = (df['name_length'] - df['name_length'].mean()) / df['name_length'].std()" + "df[\"number_of_reviews\"] = (\n", + " df[\"number_of_reviews\"] - df[\"number_of_reviews\"].mean()\n", + ") / df[\"number_of_reviews\"].std()\n", + "df[\"availability_365\"] = (df[\"availability_365\"] - df[\"availability_365\"].mean()) / df[\n", + " \"availability_365\"\n", + "].std()\n", + "df[\"name_length\"] = (df[\"name_length\"] - df[\"name_length\"].mean()) / df[\n", + " \"name_length\"\n", + "].std()" ] }, { @@ -298,8 +308,8 @@ "outputs": [], "source": [ "# Normalize latitude and longitude\n", - "df['latitude'] = (df['latitude'] - df['latitude'].mean()) / df['latitude'].std()\n", - "df['longitude'] = (df['longitude'] - df['longitude'].mean()) / df['longitude'].std()" + "df[\"latitude\"] = (df[\"latitude\"] - df[\"latitude\"].mean()) / df[\"latitude\"].std()\n", + "df[\"longitude\"] = (df[\"longitude\"] - df[\"longitude\"].mean()) / df[\"longitude\"].std()" ] }, { @@ -309,7 +319,7 @@ "outputs": [], "source": [ "# Drop columns\n", - "df = df.drop(['id', 'host_id', 'host_name', 'name'], axis=1)" + "df = df.drop([\"id\", \"host_id\", \"host_name\", \"name\"], axis=1)" ] }, { @@ -503,7 +513,7 @@ "outputs": [], "source": [ "# Rearrange columns\n", - "target_col = 'price'\n", + "target_col = \"price\"\n", "cols = df.columns.tolist()\n", "del cols[cols.index(target_col)]\n", "cols.append(target_col)\n", @@ -512,9 +522,11 @@ "\n", "# Split data\n", "train, test = train_test_split(\n", - " df.values, test_size=int(0.1 * len(df.values)), random_state=0)\n", + " df.values, test_size=int(0.1 * len(df.values)), random_state=0\n", + ")\n", "train, val = train_test_split(\n", - " train, test_size=int(0.1 * len(df.values)), random_state=0)\n", + " train, test_size=int(0.1 * len(df.values)), random_state=0\n", + ")\n", "Y_train = train[:, -1:].copy()\n", "Y_val = val[:, -1:].copy()\n", "Y_test = test[:, -1:].copy()\n", @@ -560,11 +572,13 @@ "val_dataset = val_dataset.batch(batch_size)\n", "\n", "# Get model\n", - "model = keras.Sequential([\n", - " layers.Dense(256, activation='relu', input_shape=(train.shape[1],)),\n", - " layers.Dense(256, activation='relu'),\n", - " layers.Dense(1)],\n", - " name='airbnb_model'\n", + "model = keras.Sequential(\n", + " [\n", + " layers.Dense(256, activation=\"relu\", input_shape=(train.shape[1],)),\n", + " layers.Dense(256, activation=\"relu\"),\n", + " layers.Dense(1),\n", + " ],\n", + " name=\"airbnb_model\",\n", ")\n", "\n", "# Instantiate optimizer\n", @@ -577,6 +591,7 @@ "train_acc_metric = keras.metrics.MeanSquaredError()\n", "val_acc_metric = keras.metrics.MeanSquaredError()\n", "\n", + "\n", "# Training and validation utils\n", "@tf.function\n", "def train_step(x, y):\n", @@ -588,6 +603,7 @@ " train_acc_metric.update_state(y, preds)\n", " return loss_value\n", "\n", + "\n", "@tf.function\n", "def test_step(x, y):\n", " preds = model(x)\n", @@ -603,7 +619,6 @@ "epochs = 50\n", "\n", "for epoch in range(epochs):\n", - "\n", " # Iterate over data minibatches\n", " for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):\n", " loss_value = train_step(x_batch_train, y_batch_train)\n", @@ -619,7 +634,7 @@ " val_acc = val_acc_metric.result()\n", " val_acc_metric.reset_states()\n", "\n", - "# For classification (which is not the case here): see \n", + "# For classification (which is not the case here): see\n", "# https://github.com/iancovert/sage/blob/master/sage/utils.py#L36,\n", "# as output activations should already be applied properly.\n", "# probability_model = tf.keras.Sequential([model, tf.keras.layers.Softmax()])" @@ -681,7 +696,7 @@ "source": [ "# Setup and calculate\n", "imputer = sage.MarginalImputer(model, test[:512])\n", - "estimator = sage.PermutationEstimator(imputer, 'mse')\n", + "estimator = sage.PermutationEstimator(imputer, \"mse\")\n", "sage_values = estimator(test, Y_test)" ] }, @@ -723,12 +738,17 @@ "source": [ "# Feature groups\n", "feature_groups = group_names = {\n", - " 'location (grouped)': ['latitude', 'longitude', 'neighbourhood', 'neighbourhood_group'],\n", - " 'name (grouped)': ['name_words', 'name_length', 'name_isupper'],\n", - " 'reviews (grouped)': ['last_review', 'reviews_per_month', 'number_of_reviews'],\n", - " 'host (grouped)': ['host_gender', 'calculated_host_listings_count'],\n", - " 'availability': ['availability_365'],\n", - " 'room_type': ['room_type']\n", + " \"location (grouped)\": [\n", + " \"latitude\",\n", + " \"longitude\",\n", + " \"neighbourhood\",\n", + " \"neighbourhood_group\",\n", + " ],\n", + " \"name (grouped)\": [\"name_words\", \"name_length\", \"name_isupper\"],\n", + " \"reviews (grouped)\": [\"last_review\", \"reviews_per_month\", \"number_of_reviews\"],\n", + " \"host (grouped)\": [\"host_gender\", \"calculated_host_listings_count\"],\n", + " \"availability\": [\"availability_365\"],\n", + " \"room_type\": [\"room_type\"],\n", "}\n", "group_names = [group for group in feature_groups]\n", "for col in feature_names:\n", @@ -774,7 +794,7 @@ "source": [ "# Setup and calculate\n", "imputer = sage.GroupedMarginalImputer(model, test[:512], groups)\n", - "estimator = sage.PermutationEstimator(imputer, 'mse')\n", + "estimator = sage.PermutationEstimator(imputer, \"mse\")\n", "sage_values = estimator(test, Y_test)" ] }, diff --git a/notebooks/bank.ipynb b/notebooks/bank.ipynb index 4651595..25e9394 100644 --- a/notebooks/bank.ipynb +++ b/notebooks/bank.ipynb @@ -15,8 +15,9 @@ "metadata": {}, "outputs": [], "source": [ - "import sage\n", - "from sklearn.model_selection import train_test_split" + "from sklearn.model_selection import train_test_split\n", + "\n", + "import sage" ] }, { @@ -30,8 +31,17 @@ "\n", "# Feature names and categorical columns (for CatBoost model)\n", "feature_names = df.columns.tolist()[:-1]\n", - "categorical_cols = ['Job', 'Marital', 'Education', 'Default', 'Housing',\n", - " 'Loan', 'Contact', 'Month', 'Prev Outcome']\n", + "categorical_cols = [\n", + " \"Job\",\n", + " \"Marital\",\n", + " \"Education\",\n", + " \"Default\",\n", + " \"Housing\",\n", + " \"Loan\",\n", + " \"Contact\",\n", + " \"Month\",\n", + " \"Prev Outcome\",\n", + "]\n", "categorical_inds = [feature_names.index(col) for col in categorical_cols]" ] }, @@ -43,9 +53,11 @@ "source": [ "# Split data\n", "train, test = train_test_split(\n", - " df.values, test_size=int(0.1 * len(df.values)), random_state=123)\n", + " df.values, test_size=int(0.1 * len(df.values)), random_state=123\n", + ")\n", "train, val = train_test_split(\n", - " train, test_size=int(0.1 * len(df.values)), random_state=123)\n", + " train, test_size=int(0.1 * len(df.values)), random_state=123\n", + ")\n", "Y_train = train[:, -1].copy().astype(int)\n", "Y_val = val[:, -1].copy().astype(int)\n", "Y_test = test[:, -1].copy().astype(int)\n", @@ -67,10 +79,10 @@ "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", "import matplotlib.pyplot as plt\n", - "from sklearn.metrics import log_loss\n", - "from catboost import CatBoostClassifier" + "import numpy as np\n", + "from catboost import CatBoostClassifier\n", + "from sklearn.metrics import log_loss" ] }, { @@ -79,12 +91,11 @@ "metadata": {}, "outputs": [], "source": [ - "model = CatBoostClassifier(iterations=100,\n", - " learning_rate=0.3,\n", - " depth=10)\n", + "model = CatBoostClassifier(iterations=100, learning_rate=0.3, depth=10)\n", "\n", - "model = model.fit(train, Y_train, categorical_inds, eval_set=(val, Y_val),\n", - " verbose=False)" + "model = model.fit(\n", + " train, Y_train, categorical_inds, eval_set=(val, Y_val), verbose=False\n", + ")" ] }, { @@ -115,19 +126,22 @@ "\n", "# Plot\n", "plt.figure(figsize=(9, 6))\n", - "plt.bar(np.arange(4), [base_ce, train_ce, val_ce, test_ce],\n", - " color=['tab:blue', 'tab:orange', 'tab:purple', 'tab:green'])\n", + "plt.bar(\n", + " np.arange(4),\n", + " [base_ce, train_ce, val_ce, test_ce],\n", + " color=[\"tab:blue\", \"tab:orange\", \"tab:purple\", \"tab:green\"],\n", + ")\n", "\n", "ax = plt.gca()\n", "for i, ce in enumerate([base_ce, train_ce, val_ce, test_ce]):\n", - " ax.text(i - 0.17, ce + 0.005, '{:.3f}'.format(ce), fontsize=16)\n", - " \n", + " ax.text(i - 0.17, ce + 0.005, \"{:.3f}\".format(ce), fontsize=16)\n", + "\n", "plt.ylim(0, 0.4)\n", "\n", - "plt.xticks(np.arange(4), ['Base Rate', 'Train', 'Val', 'Test'])\n", + "plt.xticks(np.arange(4), [\"Base Rate\", \"Train\", \"Val\", \"Test\"])\n", "plt.tick_params(labelsize=16)\n", - "plt.ylabel('Cross Entropy Loss', fontsize=18)\n", - "plt.title('Performance comparison', fontsize=20)\n", + "plt.ylabel(\"Cross Entropy Loss\", fontsize=18)\n", + "plt.title(\"Performance comparison\", fontsize=20)\n", "\n", "plt.tight_layout()\n", "plt.show()" @@ -193,7 +207,7 @@ "source": [ "# Setup and calculate\n", "imputer = sage.MarginalImputer(model, train[:512])\n", - "estimator = sage.KernelEstimator(imputer, 'cross entropy')\n", + "estimator = sage.KernelEstimator(imputer, \"cross entropy\")\n", "sage_train = estimator(train, Y_train, thresh=0.025)\n", "sage_val = estimator(val, Y_val, thresh=0.025)\n", "sage_test = estimator(test, Y_test, thresh=0.025)" @@ -218,11 +232,13 @@ } ], "source": [ - "sage.comparison_plot((sage_train, sage_val, sage_test),\n", - " ('Train', 'Val', 'Test'),\n", - " feature_names,\n", - " colors=('tab:orange', 'tab:purple', 'tab:green'),\n", - " title='Train vs. Val vs. Test')" + "sage.comparison_plot(\n", + " (sage_train, sage_val, sage_test),\n", + " (\"Train\", \"Val\", \"Test\"),\n", + " feature_names,\n", + " colors=(\"tab:orange\", \"tab:purple\", \"tab:green\"),\n", + " title=\"Train vs. Val vs. Test\",\n", + ")" ] }, { @@ -243,7 +259,7 @@ "source": [ "# Convert duration to seconds\n", "test_seconds = test.copy()\n", - "duration_index = feature_names.index('Duration')\n", + "duration_index = feature_names.index(\"Duration\")\n", "test_seconds[:, duration_index] = test_seconds[:, duration_index] * 60\n", "\n", "# Convert duration to hours\n", @@ -252,12 +268,24 @@ "\n", "# Shift months by one\n", "test_month = test.copy()\n", - "month_index = feature_names.index('Month')\n", - "months = ['jan', 'feb', 'mar', 'apr', 'may', 'jun',\n", - " 'jul', 'aug', 'sep', 'oct', 'nov', 'dec']\n", + "month_index = feature_names.index(\"Month\")\n", + "months = [\n", + " \"jan\",\n", + " \"feb\",\n", + " \"mar\",\n", + " \"apr\",\n", + " \"may\",\n", + " \"jun\",\n", + " \"jul\",\n", + " \"aug\",\n", + " \"sep\",\n", + " \"oct\",\n", + " \"nov\",\n", + " \"dec\",\n", + "]\n", "test_month[:, month_index] = list(\n", - " map(lambda x: months[(months.index(x) + 1) % 12],\n", - " test_month[:, month_index]))" + " map(lambda x: months[(months.index(x) + 1) % 12], test_month[:, month_index])\n", + ")" ] }, { @@ -289,22 +317,34 @@ "\n", "# Plot\n", "plt.figure(figsize=(9, 7))\n", - "plt.bar(np.arange(5), [base_ce, val_ce, seconds_ce, hours_ce, month_ce],\n", - " color=['tab:blue', 'tab:purple', 'crimson', 'firebrick', 'indianred'])\n", + "plt.bar(\n", + " np.arange(5),\n", + " [base_ce, val_ce, seconds_ce, hours_ce, month_ce],\n", + " color=[\"tab:blue\", \"tab:purple\", \"crimson\", \"firebrick\", \"indianred\"],\n", + ")\n", "\n", "ax = plt.gca()\n", "for i, ce in enumerate([base_ce, val_ce, seconds_ce, hours_ce, month_ce]):\n", - " ax.text(i - 0.17, ce + 0.007, '{:.3f}'.format(ce), fontsize=16)\n", - " \n", + " ax.text(i - 0.17, ce + 0.007, \"{:.3f}\".format(ce), fontsize=16)\n", + "\n", "plt.ylim(0, 0.94)\n", "\n", - "plt.xticks(np.arange(5),\n", - " ['Base Rate', 'Validation', r'Duration $\\rightarrow$ Secs',\n", - " r'Duration $\\rightarrow$ Hours', r'Month $\\rightarrow$ + 1'],\n", - " rotation=45, rotation_mode='anchor', ha='right')\n", + "plt.xticks(\n", + " np.arange(5),\n", + " [\n", + " \"Base Rate\",\n", + " \"Validation\",\n", + " r\"Duration $\\rightarrow$ Secs\",\n", + " r\"Duration $\\rightarrow$ Hours\",\n", + " r\"Month $\\rightarrow$ + 1\",\n", + " ],\n", + " rotation=45,\n", + " rotation_mode=\"anchor\",\n", + " ha=\"right\",\n", + ")\n", "plt.tick_params(labelsize=16)\n", - "plt.ylabel('Cross Entropy Loss', fontsize=18)\n", - "plt.title('Performance comparison', fontsize=20)\n", + "plt.ylabel(\"Cross Entropy Loss\", fontsize=18)\n", + "plt.title(\"Performance comparison\", fontsize=20)\n", "\n", "plt.tight_layout()\n", "plt.show()" @@ -370,7 +410,7 @@ "source": [ "# Calculate feature importance for perturbed data\n", "imputer = sage.MarginalImputer(model, val[:512])\n", - "estimator = sage.SignEstimator(imputer, 'cross entropy')\n", + "estimator = sage.SignEstimator(imputer, \"cross entropy\")\n", "sage_seconds = estimator(test_seconds, Y_test)\n", "sage_hours = estimator(test_hours, Y_test)\n", "sage_month = estimator(test_month, Y_test)" @@ -395,7 +435,9 @@ } ], "source": [ - "sage_seconds.plot_sign(feature_names, title=r'Feature Importance (Duration $\\rightarrow$ Seconds)')" + "sage_seconds.plot_sign(\n", + " feature_names, title=r\"Feature Importance (Duration $\\rightarrow$ Seconds)\"\n", + ")" ] }, { @@ -417,7 +459,9 @@ } ], "source": [ - "sage_hours.plot_sign(feature_names, title=r'Feature Importance (Duration $\\rightarrow$ Hours)')" + "sage_hours.plot_sign(\n", + " feature_names, title=r\"Feature Importance (Duration $\\rightarrow$ Hours)\"\n", + ")" ] }, { @@ -439,7 +483,9 @@ } ], "source": [ - "sage_month.plot_sign(feature_names, title=r'Feature Importance (Month $\\rightarrow$ + 1)')" + "sage_month.plot_sign(\n", + " feature_names, title=r\"Feature Importance (Month $\\rightarrow$ + 1)\"\n", + ")" ] }, { diff --git a/notebooks/bike.ipynb b/notebooks/bike.ipynb index 05b44bb..dddbaf6 100644 --- a/notebooks/bike.ipynb +++ b/notebooks/bike.ipynb @@ -15,9 +15,10 @@ "metadata": {}, "outputs": [], "source": [ - "import sage\n", "import numpy as np\n", - "from sklearn.model_selection import train_test_split" + "from sklearn.model_selection import train_test_split\n", + "\n", + "import sage" ] }, { @@ -39,9 +40,11 @@ "source": [ "# Split data, with total count serving as regression target\n", "train, test = train_test_split(\n", - " df.values, test_size=int(0.1 * len(df.values)), random_state=123)\n", + " df.values, test_size=int(0.1 * len(df.values)), random_state=123\n", + ")\n", "train, val = train_test_split(\n", - " train, test_size=int(0.1 * len(df.values)), random_state=123)\n", + " train, test_size=int(0.1 * len(df.values)), random_state=123\n", + ")\n", "Y_train = train[:, -1].copy()\n", "Y_val = val[:, -1].copy()\n", "Y_test = test[:, -1].copy()\n", @@ -77,12 +80,8 @@ "dval = xgb.DMatrix(val, label=Y_val)\n", "\n", "# Parameters\n", - "param = {\n", - " 'max_depth' : 10,\n", - " 'objective': 'reg:squarederror',\n", - " 'nthread': 4\n", - "}\n", - "evallist = [(dtrain, 'train'), (dval, 'val')]\n", + "param = {\"max_depth\": 10, \"objective\": \"reg:squarederror\", \"nthread\": 4}\n", + "evallist = [(dtrain, \"train\"), (dval, \"val\")]\n", "num_round = 50\n", "\n", "# Train\n", @@ -109,8 +108,8 @@ "base_mse = np.mean((mean - Y_test) ** 2)\n", "mse = np.mean((model.predict(xgb.DMatrix(test)) - Y_test) ** 2)\n", "\n", - "print('Base rate MSE = {:.2f}'.format(base_mse))\n", - "print('Model MSE = {:.2f}'.format(mse))" + "print(\"Base rate MSE = {:.2f}\".format(base_mse))\n", + "print(\"Model MSE = {:.2f}\".format(mse))" ] }, { @@ -143,7 +142,7 @@ "source": [ "# Setup and calculate\n", "imputer = sage.MarginalImputer(model, test[:512])\n", - "estimator = sage.PermutationEstimator(imputer, 'mse')\n", + "estimator = sage.PermutationEstimator(imputer, \"mse\")\n", "sage_values = estimator(test, Y_test)" ] }, @@ -207,7 +206,7 @@ "source": [ "# Setup and calculate\n", "imputer = sage.MarginalImputer(model, test[:512])\n", - "estimator = sage.PermutationEstimator(imputer, 'mse')\n", + "estimator = sage.PermutationEstimator(imputer, \"mse\")\n", "sensitivity = estimator(test)" ] }, @@ -231,7 +230,7 @@ ], "source": [ "# Plot results\n", - "sensitivity.plot(feature_names, title='Model Sensitivity')" + "sensitivity.plot(feature_names, title=\"Model Sensitivity\")" ] }, { diff --git a/notebooks/calibration.ipynb b/notebooks/calibration.ipynb index 316cbb0..4f3f090 100644 --- a/notebooks/calibration.ipynb +++ b/notebooks/calibration.ipynb @@ -17,10 +17,11 @@ "metadata": {}, "outputs": [], "source": [ - "import sage\n", - "import numpy as np\n", "import matplotlib.pyplot as plt\n", - "from sklearn.model_selection import train_test_split" + "import numpy as np\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "import sage" ] }, { @@ -44,9 +45,11 @@ "source": [ "# Split data, with total count serving as regression target\n", "train, test = train_test_split(\n", - " df.values, test_size=int(0.1 * len(df.values)), random_state=123)\n", + " df.values, test_size=int(0.1 * len(df.values)), random_state=123\n", + ")\n", "train, val = train_test_split(\n", - " train, test_size=int(0.1 * len(df.values)), random_state=123)\n", + " train, test_size=int(0.1 * len(df.values)), random_state=123\n", + ")\n", "Y_train = train[:, -1].copy()\n", "Y_val = val[:, -1].copy()\n", "Y_test = test[:, -1].copy()\n", @@ -85,12 +88,8 @@ "dval = xgb.DMatrix(val, label=Y_val)\n", "\n", "# Parameters\n", - "param = {\n", - " 'max_depth' : 10,\n", - " 'objective': 'reg:squarederror',\n", - " 'nthread': 4\n", - "}\n", - "evallist = [(dtrain, 'train'), (dval, 'val')]\n", + "param = {\"max_depth\": 10, \"objective\": \"reg:squarederror\", \"nthread\": 4}\n", + "evallist = [(dtrain, \"train\"), (dval, \"val\")]\n", "num_round = 50\n", "\n", "# Train\n", @@ -118,8 +117,8 @@ "base_mse = np.mean((mean - Y_test) ** 2)\n", "mse = np.mean((model.predict(xgb.DMatrix(test)) - Y_test) ** 2)\n", "\n", - "print('Base rate MSE = {:.2f}'.format(base_mse))\n", - "print('Model MSE = {:.2f}'.format(mse))" + "print(\"Base rate MSE = {:.2f}\".format(base_mse))\n", + "print(\"Model MSE = {:.2f}\".format(mse))" ] }, { @@ -139,12 +138,16 @@ "source": [ "# Setup and calculate\n", "imputer = sage.MarginalImputer(model, test[:512])\n", - "estimator = sage.PermutationEstimator(imputer, 'mse')\n", + "estimator = sage.PermutationEstimator(imputer, \"mse\")\n", "\n", "# Run explainer multiple times\n", "sage_list = []\n", "for i in range(50):\n", - " sage_list.append(estimator(test, Y_test, n_permutations=512*5, bar=False, detect_convergence=False))" + " sage_list.append(\n", + " estimator(\n", + " test, Y_test, n_permutations=512 * 5, bar=False, detect_convergence=False\n", + " )\n", + " )" ] }, { @@ -176,11 +179,19 @@ "\n", "# Y = X line\n", "m = max(std_est.max(), std_real.max())\n", - "plt.plot([0, m,], [0, m], color='black', linestyle=':')\n", + "plt.plot(\n", + " [\n", + " 0,\n", + " m,\n", + " ],\n", + " [0, m],\n", + " color=\"black\",\n", + " linestyle=\":\",\n", + ")\n", "\n", - "plt.xlabel('StdDev Estimated')\n", - "plt.ylabel('StdDev Observed')\n", - "plt.title('Permutation Estimator Calibration')\n", + "plt.xlabel(\"StdDev Estimated\")\n", + "plt.ylabel(\"StdDev Observed\")\n", + "plt.title(\"Permutation Estimator Calibration\")\n", "\n", "plt.tight_layout()\n", "plt.show()" @@ -203,13 +214,22 @@ "source": [ "# Setup and calculate\n", "imputer = sage.MarginalImputer(model, test[:512])\n", - "estimator = sage.PermutationEstimator(imputer, 'mse')\n", + "estimator = sage.PermutationEstimator(imputer, \"mse\")\n", "\n", "# Run explainer multiple times\n", "sage_list = []\n", "for i in range(50):\n", - " sage_list.append(estimator(test, Y_test, n_permutations=512*5, bar=False, detect_convergence=False,\n", - " min_coalition=3, max_coalition=9))" + " sage_list.append(\n", + " estimator(\n", + " test,\n", + " Y_test,\n", + " n_permutations=512 * 5,\n", + " bar=False,\n", + " detect_convergence=False,\n", + " min_coalition=3,\n", + " max_coalition=9,\n", + " )\n", + " )" ] }, { @@ -241,11 +261,19 @@ "\n", "# Y = X line\n", "m = max(std_est.max(), std_real.max())\n", - "plt.plot([0, m,], [0, m], color='black', linestyle=':')\n", + "plt.plot(\n", + " [\n", + " 0,\n", + " m,\n", + " ],\n", + " [0, m],\n", + " color=\"black\",\n", + " linestyle=\":\",\n", + ")\n", "\n", - "plt.xlabel('StdDev Estimated')\n", - "plt.ylabel('StdDev Observed')\n", - "plt.title('Relaxed Permutation Estimator Calibration')\n", + "plt.xlabel(\"StdDev Estimated\")\n", + "plt.ylabel(\"StdDev Observed\")\n", + "plt.title(\"Relaxed Permutation Estimator Calibration\")\n", "\n", "plt.tight_layout()\n", "plt.show()" @@ -268,12 +296,14 @@ "source": [ "# Setup and calculate\n", "imputer = sage.MarginalImputer(model, test[:512])\n", - "estimator = sage.IteratedEstimator(imputer, 'mse')\n", + "estimator = sage.IteratedEstimator(imputer, \"mse\")\n", "\n", "# Run explainer multiple times\n", "sage_list = []\n", "for i in range(50):\n", - " sage_list.append(estimator(test, Y_test, n_samples=512, bar=False, detect_convergence=False))" + " sage_list.append(\n", + " estimator(test, Y_test, n_samples=512, bar=False, detect_convergence=False)\n", + " )" ] }, { @@ -305,11 +335,19 @@ "\n", "# Y = X line\n", "m = max(std_est.max(), std_real.max())\n", - "plt.plot([0, m,], [0, m], color='black', linestyle=':')\n", + "plt.plot(\n", + " [\n", + " 0,\n", + " m,\n", + " ],\n", + " [0, m],\n", + " color=\"black\",\n", + " linestyle=\":\",\n", + ")\n", "\n", - "plt.xlabel('StdDev Estimated')\n", - "plt.ylabel('StdDev Observed')\n", - "plt.title('Iterated Estimator Calibration')\n", + "plt.xlabel(\"StdDev Estimated\")\n", + "plt.ylabel(\"StdDev Observed\")\n", + "plt.title(\"Iterated Estimator Calibration\")\n", "\n", "plt.tight_layout()\n", "plt.show()" @@ -332,12 +370,14 @@ "source": [ "# Setup and calculate\n", "imputer = sage.MarginalImputer(model, test[:512])\n", - "estimator = sage.KernelEstimator(imputer, 'mse')\n", + "estimator = sage.KernelEstimator(imputer, \"mse\")\n", "\n", "# Run explainer multiple times\n", "sage_list = []\n", "for i in range(50):\n", - " sage_list.append(estimator(test, Y_test, n_samples=20000, bar=False, detect_convergence=False))" + " sage_list.append(\n", + " estimator(test, Y_test, n_samples=20000, bar=False, detect_convergence=False)\n", + " )" ] }, { @@ -369,11 +409,19 @@ "\n", "# Y = X line\n", "m = max(std_est.max(), std_real.max())\n", - "plt.plot([0, m,], [0, m], color='black', linestyle=':')\n", + "plt.plot(\n", + " [\n", + " 0,\n", + " m,\n", + " ],\n", + " [0, m],\n", + " color=\"black\",\n", + " linestyle=\":\",\n", + ")\n", "\n", - "plt.xlabel('StdDev Estimated')\n", - "plt.ylabel('StdDev Observed')\n", - "plt.title('Kernel Estimator Calibration')\n", + "plt.xlabel(\"StdDev Estimated\")\n", + "plt.ylabel(\"StdDev Observed\")\n", + "plt.title(\"Kernel Estimator Calibration\")\n", "\n", "plt.tight_layout()\n", "plt.show()" diff --git a/notebooks/consistency.ipynb b/notebooks/consistency.ipynb index 1c73315..643070d 100644 --- a/notebooks/consistency.ipynb +++ b/notebooks/consistency.ipynb @@ -17,10 +17,11 @@ "metadata": {}, "outputs": [], "source": [ - "import sage\n", - "import numpy as np\n", "import matplotlib.pyplot as plt\n", - "from sklearn.model_selection import train_test_split" + "import numpy as np\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "import sage" ] }, { @@ -44,9 +45,11 @@ "source": [ "# Split data, with total count serving as regression target\n", "train, test = train_test_split(\n", - " df.values, test_size=int(0.1 * len(df.values)), random_state=123)\n", + " df.values, test_size=int(0.1 * len(df.values)), random_state=123\n", + ")\n", "train, val = train_test_split(\n", - " train, test_size=int(0.1 * len(df.values)), random_state=123)\n", + " train, test_size=int(0.1 * len(df.values)), random_state=123\n", + ")\n", "Y_train = train[:, -1].copy()\n", "Y_val = val[:, -1].copy()\n", "Y_test = test[:, -1].copy()\n", @@ -85,12 +88,8 @@ "dval = xgb.DMatrix(val, label=Y_val)\n", "\n", "# Parameters\n", - "param = {\n", - " 'max_depth' : 10,\n", - " 'objective': 'reg:squarederror',\n", - " 'nthread': 4\n", - "}\n", - "evallist = [(dtrain, 'train'), (dval, 'val')]\n", + "param = {\"max_depth\": 10, \"objective\": \"reg:squarederror\", \"nthread\": 4}\n", + "evallist = [(dtrain, \"train\"), (dval, \"val\")]\n", "num_round = 50\n", "\n", "# Train\n", @@ -118,8 +117,8 @@ "base_mse = np.mean((mean - Y_test) ** 2)\n", "mse = np.mean((model.predict(xgb.DMatrix(test)) - Y_test) ** 2)\n", "\n", - "print('Base rate MSE = {:.2f}'.format(base_mse))\n", - "print('Model MSE = {:.2f}'.format(mse))" + "print(\"Base rate MSE = {:.2f}\".format(base_mse))\n", + "print(\"Model MSE = {:.2f}\".format(mse))" ] }, { @@ -149,10 +148,10 @@ "imputer = sage.MarginalImputer(model, test[:512])\n", "\n", "# Set up estimators\n", - "permutation_estimator = sage.PermutationEstimator(imputer, 'mse')\n", - "parallel_permutation_estimator = sage.PermutationEstimator(imputer, 'mse', n_jobs=-1)\n", - "iterated_estimator = sage.IteratedEstimator(imputer, 'mse')\n", - "kernel_estimator = sage.KernelEstimator(imputer, 'mse')" + "permutation_estimator = sage.PermutationEstimator(imputer, \"mse\")\n", + "parallel_permutation_estimator = sage.PermutationEstimator(imputer, \"mse\", n_jobs=-1)\n", + "iterated_estimator = sage.IteratedEstimator(imputer, \"mse\")\n", + "kernel_estimator = sage.KernelEstimator(imputer, \"mse\")" ] }, { @@ -314,14 +313,19 @@ ], "source": [ "explanations = [explanation1, explanation2, explanation3, explanation4]\n", - "names = ['Permutation Estimator', 'Parallel Permutation Estimator', 'Iterated Estimator', 'Kernel Estimator']\n", + "names = [\n", + " \"Permutation Estimator\",\n", + " \"Parallel Permutation Estimator\",\n", + " \"Iterated Estimator\",\n", + " \"Kernel Estimator\",\n", + "]\n", "\n", "for i in range(len(explanations)):\n", " for j in range(i + 1, len(explanations)):\n", " plt.figure()\n", - " \n", + "\n", " plt.scatter(explanations[i].values, explanations[j].values)\n", - " plt.plot([0, 18000], [0, 18000], linestyle=':', color='black')\n", + " plt.plot([0, 18000], [0, 18000], linestyle=\":\", color=\"black\")\n", " plt.xlabel(names[i])\n", " plt.ylabel(names[j])\n", " plt.tight_layout()\n", @@ -515,14 +519,19 @@ ], "source": [ "explanations = [explanation1, explanation2, explanation3, explanation4]\n", - "names = ['Permutation Estimator', 'Parallel Permutation Estimator', 'Iterated Estimator', 'Kernel Estimator']\n", + "names = [\n", + " \"Permutation Estimator\",\n", + " \"Parallel Permutation Estimator\",\n", + " \"Iterated Estimator\",\n", + " \"Kernel Estimator\",\n", + "]\n", "\n", "for i in range(len(explanations)):\n", " for j in range(i + 1, len(explanations)):\n", " plt.figure()\n", - " \n", + "\n", " plt.scatter(explanations[i].values, explanations[j].values)\n", - " plt.plot([0, 18000], [0, 18000], linestyle=':', color='black')\n", + " plt.plot([0, 18000], [0, 18000], linestyle=\":\", color=\"black\")\n", " plt.xlabel(names[i])\n", " plt.ylabel(names[j])\n", " plt.tight_layout()\n", diff --git a/notebooks/credit.ipynb b/notebooks/credit.ipynb index 0864b5f..dabda4d 100644 --- a/notebooks/credit.ipynb +++ b/notebooks/credit.ipynb @@ -15,8 +15,9 @@ "metadata": {}, "outputs": [], "source": [ - "import sage\n", - "from sklearn.model_selection import train_test_split" + "from sklearn.model_selection import train_test_split\n", + "\n", + "import sage" ] }, { @@ -31,10 +32,20 @@ "# Feature names and categorical columns (for CatBoost model)\n", "feature_names = df.columns.tolist()[:-1]\n", "categorical_columns = [\n", - " 'Checking Status', 'Credit History', 'Purpose', 'Credit Amount',\n", - " 'Savings Account/Bonds', 'Employment Since', 'Personal Status',\n", - " 'Debtors/Guarantors', 'Property Type', 'Other Installment Plans',\n", - " 'Housing Ownership', 'Job', 'Telephone', 'Foreign Worker'\n", + " \"Checking Status\",\n", + " \"Credit History\",\n", + " \"Purpose\",\n", + " \"Credit Amount\",\n", + " \"Savings Account/Bonds\",\n", + " \"Employment Since\",\n", + " \"Personal Status\",\n", + " \"Debtors/Guarantors\",\n", + " \"Property Type\",\n", + " \"Other Installment Plans\",\n", + " \"Housing Ownership\",\n", + " \"Job\",\n", + " \"Telephone\",\n", + " \"Foreign Worker\",\n", "]\n", "categorical_inds = [feature_names.index(col) for col in categorical_columns]" ] @@ -47,9 +58,11 @@ "source": [ "# Split data\n", "train, test = train_test_split(\n", - " df.values, test_size=int(0.1 * len(df.values)), random_state=0)\n", + " df.values, test_size=int(0.1 * len(df.values)), random_state=0\n", + ")\n", "train, val = train_test_split(\n", - " train, test_size=int(0.1 * len(df.values)), random_state=0)\n", + " train, test_size=int(0.1 * len(df.values)), random_state=0\n", + ")\n", "Y_train = train[:, -1].copy().astype(int)\n", "Y_val = val[:, -1].copy().astype(int)\n", "Y_test = test[:, -1].copy().astype(int)\n", @@ -72,8 +85,8 @@ "outputs": [], "source": [ "import numpy as np\n", - "from sklearn.metrics import log_loss\n", - "from catboost import CatBoostClassifier" + "from catboost import CatBoostClassifier\n", + "from sklearn.metrics import log_loss" ] }, { @@ -91,12 +104,11 @@ } ], "source": [ - "model = CatBoostClassifier(iterations=50,\n", - " learning_rate=0.3,\n", - " depth=3)\n", + "model = CatBoostClassifier(iterations=50, learning_rate=0.3, depth=3)\n", "\n", - "model = model.fit(train, Y_train, categorical_inds, eval_set=(val, Y_val),\n", - " verbose=False)" + "model = model.fit(\n", + " train, Y_train, categorical_inds, eval_set=(val, Y_val), verbose=False\n", + ")" ] }, { @@ -119,8 +131,8 @@ "base_ce = log_loss(Y_test.astype(int), p[np.newaxis].repeat(len(test), 0))\n", "ce = log_loss(Y_test.astype(int), model.predict_proba(test))\n", "\n", - "print('Base rate cross entropy = {:.3f}'.format(base_ce))\n", - "print('Model cross entropy = {:.3f}'.format(ce))" + "print(\"Base rate cross entropy = {:.3f}\".format(base_ce))\n", + "print(\"Model cross entropy = {:.3f}\".format(ce))" ] }, { @@ -153,7 +165,7 @@ "source": [ "# Setup and calculate\n", "imputer = sage.MarginalImputer(model, train[:512])\n", - "estimator = sage.PermutationEstimator(imputer, 'cross entropy')\n", + "estimator = sage.PermutationEstimator(imputer, \"cross entropy\")\n", "sage_values = estimator(test, Y_test)" ] }, @@ -175,7 +187,7 @@ ], "source": [ "# Plot results\n", - "sage_values.plot(feature_names, title='Feature Importance (Marginal Sampling)')" + "sage_values.plot(feature_names, title=\"Feature Importance (Marginal Sampling)\")" ] }, { @@ -193,12 +205,13 @@ "metadata": {}, "outputs": [], "source": [ + "from copy import deepcopy\n", + "\n", + "import matplotlib.pyplot as plt\n", "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", - "from torch.utils.data import TensorDataset, DataLoader, RandomSampler, BatchSampler\n", - "from copy import deepcopy\n", - "import matplotlib.pyplot as plt" + "from torch.utils.data import BatchSampler, DataLoader, RandomSampler, TensorDataset" ] }, { @@ -209,13 +222,14 @@ "source": [ "# Class for applying binary mask at input layer.\n", "class MaskLayer1d(nn.Module):\n", - " '''\n", + " \"\"\"\n", " Masking for 1d inputs.\n", "\n", " Args:\n", " append: whether to append the mask along channels dim.\n", " value: replacement value for held out features.\n", - " '''\n", + " \"\"\"\n", + "\n", " def __init__(self, append=True, value=0):\n", " super().__init__()\n", " self.append = append\n", @@ -230,7 +244,7 @@ "\n", "\n", "# Setup\n", - "device = torch.device('cuda')\n", + "device = torch.device(\"cuda\")\n", "num_features = len(feature_names)\n", "\n", "# Create surrogate model (outputs logits)\n", @@ -240,7 +254,8 @@ " nn.ELU(inplace=True),\n", " nn.Linear(128, 128),\n", " nn.ELU(inplace=True),\n", - " nn.Linear(128, 2)).to(device)" + " nn.Linear(128, 2),\n", + ").to(device)" ] }, { @@ -265,12 +280,8 @@ "outputs": [], "source": [ "# Prepare validation dataset\n", - "X_val_surrogate = torch.tensor(\n", - " ((val - mean) / std),\n", - " dtype=torch.float32)\n", - "Y_val_surrogate = torch.tensor(\n", - " model.predict_proba(val),\n", - " dtype=torch.float32)\n", + "X_val_surrogate = torch.tensor(((val - mean) / std), dtype=torch.float32)\n", + "Y_val_surrogate = torch.tensor(model.predict_proba(val), dtype=torch.float32)\n", "\n", "# Replicate rows\n", "X_val_surrogate = X_val_surrogate.repeat(1000, 1)\n", @@ -282,11 +293,12 @@ " num_included = np.random.choice(num_features + 1)\n", " S_val[i, num_included:] = 0\n", " S_val[i] = S_val[i, torch.randperm(num_features)]\n", - " \n", + "\n", "# Create dataset iterator\n", "val_set = TensorDataset(X_val_surrogate, Y_val_surrogate, S_val)\n", "val_loader = DataLoader(val_set, batch_size=25000)\n", "\n", + "\n", "# Function to measure validation performance\n", "def validate(model, loss_fn):\n", " with torch.no_grad():\n", @@ -319,8 +331,7 @@ " super().__init__()\n", "\n", " def forward(self, pred, target):\n", - " return - torch.mean(\n", - " torch.sum(pred.log_softmax(dim=1) * target, dim=1))" + " return -torch.mean(torch.sum(pred.log_softmax(dim=1) * target, dim=1))" ] }, { @@ -542,8 +553,10 @@ "\n", " # Data loader\n", " random_sampler = RandomSampler(\n", - " train_set, replacement=True,\n", - " num_samples=int(np.ceil(len(X_surrogate) / mbsize))*mbsize)\n", + " train_set,\n", + " replacement=True,\n", + " num_samples=int(np.ceil(len(X_surrogate) / mbsize)) * mbsize,\n", + " )\n", " batch_sampler = BatchSampler(random_sampler, batch_size=mbsize, drop_last=True)\n", " train_loader = DataLoader(train_set, batch_sampler=batch_sampler)\n", "\n", @@ -564,7 +577,7 @@ " S = torch.ones(mbsize, num_features, dtype=torch.float32, device=device)\n", " num_included = np.random.choice(num_features + 1, size=mbsize)\n", " for j in range(mbsize):\n", - " S[j, num_included[j]:] = 0\n", + " S[j, num_included[j] :] = 0\n", " S[j] = S[j, torch.randperm(num_features)]\n", "\n", " # Make predictions\n", @@ -579,9 +592,9 @@ " # End of epoch progress message\n", " val_loss = validate(surrogate, loss_fn).item()\n", " loss_list.append(val_loss)\n", - " print('----- Epoch = {} -----'.format(epoch + 1))\n", - " print('Val loss = {:.4f}'.format(val_loss))\n", - " print('')\n", + " print(\"----- Epoch = {} -----\".format(epoch + 1))\n", + " print(\"Val loss = {:.4f}\".format(val_loss))\n", + " print(\"\")\n", "\n", " # Check if best model\n", " if epoch >= min_epoch:\n", @@ -589,12 +602,12 @@ " best_epoch_loss = val_loss\n", " best_model = deepcopy(surrogate)\n", " best_epoch = epoch\n", - " print('New best epoch, val loss = {:.4f}'.format(val_loss))\n", - " print('')\n", + " print(\"New best epoch, val loss = {:.4f}\".format(val_loss))\n", + " print(\"\")\n", " else:\n", " # Check for early stopping\n", " if epoch - best_epoch == early_stop_epochs:\n", - " print('Stopping early')\n", + " print(\"Stopping early\")\n", " break\n", "\n", " surrogate = best_model" @@ -620,10 +633,10 @@ "# Plot loss during training\n", "plt.figure(figsize=(9, 6))\n", "plt.plot(loss_list)\n", - "plt.xlabel('Epochs', fontsize=18)\n", - "plt.ylabel('Cross entropy loss', fontsize=18)\n", + "plt.xlabel(\"Epochs\", fontsize=18)\n", + "plt.ylabel(\"Cross entropy loss\", fontsize=18)\n", "plt.tick_params(labelsize=16)\n", - "plt.title('Surrogate training', fontsize=20)\n", + "plt.title(\"Surrogate training\", fontsize=20)\n", "plt.show()" ] }, @@ -661,18 +674,24 @@ "class Imputer:\n", " def __init__(self):\n", " self.num_groups = num_features\n", - " \n", + "\n", " def __call__(self, x, S):\n", " # Call surrogate model (with data normalization)\n", - " return surrogate(\n", - " (torch.tensor((x - mean) / std, dtype=torch.float32, device=device),\n", - " torch.tensor(S, dtype=torch.float32, device=device))\n", - " ).softmax(dim=1).cpu().data.numpy()\n", + " return (\n", + " surrogate((\n", + " torch.tensor((x - mean) / std, dtype=torch.float32, device=device),\n", + " torch.tensor(S, dtype=torch.float32, device=device),\n", + " ))\n", + " .softmax(dim=1)\n", + " .cpu()\n", + " .data.numpy()\n", + " )\n", + "\n", "\n", "imputer = Imputer()\n", "\n", "# Calculate SAGE values\n", - "estimator = sage.PermutationEstimator(imputer, 'cross entropy')\n", + "estimator = sage.PermutationEstimator(imputer, \"cross entropy\")\n", "sage_values = estimator(test, Y_test)" ] }, @@ -694,7 +713,7 @@ ], "source": [ "# Plot results\n", - "sage_values.plot(feature_names, title='Feature Importance (Surrogate)')" + "sage_values.plot(feature_names, title=\"Feature Importance (Surrogate)\")" ] }, { diff --git a/notebooks/losses.ipynb b/notebooks/losses.ipynb index 7ccca95..05c0836 100644 --- a/notebooks/losses.ipynb +++ b/notebooks/losses.ipynb @@ -46,11 +46,12 @@ } ], "source": [ - "from catboost import CatBoostClassifier\n", "import numpy as np\n", - "import sage\n", + "from catboost import CatBoostClassifier\n", + "from sklearn.dummy import DummyClassifier\n", "from sklearn.metrics import accuracy_score, log_loss, zero_one_loss\n", - "from sklearn.dummy import DummyClassifier" + "\n", + "import sage" ] }, { @@ -256,7 +257,7 @@ "source": [ "np.testing.assert_approx_equal(\n", " catboost_zero_one_loss - base_zero_one_loss,\n", - " - np.sum(sage_values.values),\n", + " -np.sum(sage_values.values),\n", " significant=2,\n", ")\n", "np.testing.assert_approx_equal(\n", @@ -324,7 +325,7 @@ "outputs": [], "source": [ "np.testing.assert_approx_equal(\n", - " catboost_log_loss - base_log_loss, - np.sum(sage_values.values), significant=2\n", + " catboost_log_loss - base_log_loss, -np.sum(sage_values.values), significant=2\n", ")" ] }, diff --git a/notebooks/mnist.ipynb b/notebooks/mnist.ipynb index ee0bdee..5089ea9 100644 --- a/notebooks/mnist.ipynb +++ b/notebooks/mnist.ipynb @@ -24,7 +24,7 @@ "outputs": [], "source": [ "# Load train set\n", - "train = dsets.MNIST('../data', train=True, download=True)\n", + "train = dsets.MNIST(\"../data\", train=True, download=True)\n", "imgs = train.data.reshape(-1, 784) / 255.0\n", "labels = train.targets\n", "\n", @@ -36,7 +36,7 @@ "train, Y_train = imgs[6000:], labels[6000:]\n", "\n", "# Load test set\n", - "test = dsets.MNIST('../data', train=False, download=True)\n", + "test = dsets.MNIST(\"../data\", train=False, download=True)\n", "test, Y_test = test.data.reshape(-1, 784) / 255.0, test.targets\n", "\n", "# Move test data to numpy\n", @@ -57,11 +57,12 @@ "metadata": {}, "outputs": [], "source": [ + "from copy import deepcopy\n", + "\n", "import numpy as np\n", "import torch.nn as nn\n", "import torch.optim as optim\n", - "from copy import deepcopy\n", - "from torch.utils.data import TensorDataset, DataLoader" + "from torch.utils.data import DataLoader, TensorDataset" ] }, { @@ -73,13 +74,10 @@ "outputs": [], "source": [ "# Create model\n", - "device = torch.device('cuda', 1)\n", + "device = torch.device(\"cuda\", 1)\n", "model = nn.Sequential(\n", - " nn.Linear(784, 256),\n", - " nn.ELU(),\n", - " nn.Linear(256, 256),\n", - " nn.ELU(),\n", - " nn.Linear(256, 10)).to(device)\n", + " nn.Linear(784, 256), nn.ELU(), nn.Linear(256, 256), nn.ELU(), nn.Linear(256, 10)\n", + ").to(device)\n", "\n", "# Training parameters\n", "lr = 1e-3\n", @@ -124,8 +122,8 @@ " # Calculate validation loss.\n", " val_loss = loss_fn(model(val), Y_val).item()\n", " if verbose:\n", - " print('{}Epoch = {}{}'.format('-' * 10, epoch + 1, '-' * 10))\n", - " print('Val loss = {:.4f}'.format(val_loss))\n", + " print(\"{}Epoch = {}{}\".format(\"-\" * 10, epoch + 1, \"-\" * 10))\n", + " print(\"Val loss = {:.4f}\".format(val_loss))\n", "\n", " # Check convergence criterion.\n", " if val_loss < min_criterion:\n", @@ -134,7 +132,7 @@ " best_model = deepcopy(model)\n", " elif (epoch - min_epoch) == lookback:\n", " if verbose:\n", - " print('Stopping early')\n", + " print(\"Stopping early\")\n", " break\n", "\n", "# Keep best model\n", @@ -170,8 +168,8 @@ "base_ce = loss_fn(torch.log(p.repeat(len(Y_test), 1)), Y_test)\n", "ce = loss_fn(model(test), Y_test)\n", "\n", - "print('Base rate cross entropy = {:.4f}'.format(base_ce))\n", - "print('Model cross entropy = {:.4f}'.format(ce))" + "print(\"Base rate cross entropy = {:.4f}\".format(base_ce))\n", + "print(\"Model cross entropy = {:.4f}\".format(ce))" ] }, { @@ -189,8 +187,9 @@ "metadata": {}, "outputs": [], "source": [ - "import sage\n", - "import matplotlib.pyplot as plt" + "import matplotlib.pyplot as plt\n", + "\n", + "import sage" ] }, { @@ -206,7 +205,7 @@ "for i in range(num_superpixels):\n", " for j in range(num_superpixels):\n", " img = np.zeros((28, 28), dtype=int)\n", - " img[width*i:width*(i+1), width*j:width*(j+1)] = 1\n", + " img[width * i : width * (i + 1), width * j : width * (j + 1)] = 1\n", " img = img.reshape((784,))\n", " groups.append(np.where(img)[0])" ] @@ -248,7 +247,7 @@ "source": [ "# Setup and calculate\n", "imputer = sage.GroupedMarginalImputer(model_activation, test_np[:512], groups)\n", - "estimator = sage.PermutationEstimator(imputer, 'cross entropy')\n", + "estimator = sage.PermutationEstimator(imputer, \"cross entropy\")\n", "sage_values = estimator(test_np, Y_test_np, batch_size=128, thresh=0.05)" ] }, @@ -272,8 +271,7 @@ "# Plot\n", "plt.figure(figsize=(6, 6))\n", "m = np.max(np.abs(sage_values.values))\n", - "plt.imshow(- sage_values.values.reshape(7, 7),\n", - " cmap='seismic', vmin=-m, vmax=m)\n", + "plt.imshow(-sage_values.values.reshape(7, 7), cmap=\"seismic\", vmin=-m, vmax=m)\n", "plt.xticks([])\n", "plt.yticks([])\n", "plt.show()" @@ -301,7 +299,7 @@ "for i in range(num_superpixels):\n", " for j in range(num_superpixels):\n", " img = np.zeros((28, 28), dtype=int)\n", - " img[width*i:width*(i+1), width*j:width*(j+1)] = 1\n", + " img[width * i : width * (i + 1), width * j : width * (j + 1)] = 1\n", " img = img.reshape((784,))\n", " groups.append(np.where(img)[0])" ] @@ -343,7 +341,7 @@ "source": [ "# Setup and calculate\n", "imputer = sage.GroupedMarginalImputer(model_activation, test_np[:512], groups)\n", - "estimator = sage.PermutationEstimator(imputer, 'cross entropy')\n", + "estimator = sage.PermutationEstimator(imputer, \"cross entropy\")\n", "sage_values = estimator(test_np, Y_test_np, batch_size=128, thresh=0.05)" ] }, @@ -367,8 +365,7 @@ "# Plot\n", "plt.figure(figsize=(6, 6))\n", "m = np.max(np.abs(sage_values.values))\n", - "plt.imshow(- sage_values.values.reshape(14, 14),\n", - " cmap='seismic', vmin=-m, vmax=m)\n", + "plt.imshow(-sage_values.values.reshape(14, 14), cmap=\"seismic\", vmin=-m, vmax=m)\n", "plt.xticks([])\n", "plt.yticks([])\n", "plt.show()" @@ -420,7 +417,7 @@ "source": [ "# Setup and calculate\n", "imputer = sage.MarginalImputer(model_activation, test_np[:512])\n", - "estimator = sage.PermutationEstimator(imputer, 'cross entropy')\n", + "estimator = sage.PermutationEstimator(imputer, \"cross entropy\")\n", "sage_values = estimator(test_np, Y_test_np, batch_size=128, thresh=0.05)" ] }, @@ -444,8 +441,7 @@ "# Plot\n", "plt.figure(figsize=(6, 6))\n", "m = np.max(np.abs(sage_values.values))\n", - "plt.imshow(- sage_values.values.reshape(28, 28),\n", - " cmap='seismic', vmin=-m, vmax=m)\n", + "plt.imshow(-sage_values.values.reshape(28, 28), cmap=\"seismic\", vmin=-m, vmax=m)\n", "plt.xticks([])\n", "plt.yticks([])\n", "plt.show()" @@ -497,7 +493,7 @@ "source": [ "# Setup and calculate\n", "imputer = sage.MarginalImputer(model_activation, test_np[:128])\n", - "estimator = sage.PermutationEstimator(imputer, 'cross entropy')\n", + "estimator = sage.PermutationEstimator(imputer, \"cross entropy\")\n", "sage_values = estimator(test_np, Y_test_np, batch_size=512, thresh=0.05)" ] }, @@ -521,8 +517,7 @@ "# Plot\n", "plt.figure(figsize=(6, 6))\n", "m = np.max(np.abs(sage_values.values))\n", - "plt.imshow(- sage_values.values.reshape(28, 28),\n", - " cmap='seismic', vmin=-m, vmax=m)\n", + "plt.imshow(-sage_values.values.reshape(28, 28), cmap=\"seismic\", vmin=-m, vmax=m)\n", "plt.xticks([])\n", "plt.yticks([])\n", "plt.show()" @@ -574,7 +569,7 @@ "source": [ "# Setup and calculate\n", "imputer = sage.DefaultImputer(model_activation, np.zeros(784))\n", - "estimator = sage.PermutationEstimator(imputer, 'cross entropy')\n", + "estimator = sage.PermutationEstimator(imputer, \"cross entropy\")\n", "sage_values = estimator(test_np, Y_test_np, batch_size=512, thresh=0.05)" ] }, @@ -598,8 +593,7 @@ "# Plot\n", "plt.figure(figsize=(6, 6))\n", "m = np.max(np.abs(sage_values.values))\n", - "plt.imshow(- sage_values.values.reshape(28, 28),\n", - " cmap='seismic', vmin=-m, vmax=m)\n", + "plt.imshow(-sage_values.values.reshape(28, 28), cmap=\"seismic\", vmin=-m, vmax=m)\n", "plt.xticks([])\n", "plt.yticks([])\n", "plt.show()" diff --git a/sage/iterated_estimator.py b/sage/iterated_estimator.py index d3c33b5..83002ea 100644 --- a/sage/iterated_estimator.py +++ b/sage/iterated_estimator.py @@ -200,9 +200,7 @@ def __call__( # Print progress message. if verbose: if detect_convergence: - print( - f"StdDev Ratio = {ratio:.4f} " f"(Converge at {thresh:.4f})" - ) + print(f"StdDev Ratio = {ratio:.4f} (Converge at {thresh:.4f})") else: print("StdDev Ratio = {:.4f}".format(ratio)) diff --git a/sage/kernel_estimator.py b/sage/kernel_estimator.py index 1da6228..0df49b4 100644 --- a/sage/kernel_estimator.py +++ b/sage/kernel_estimator.py @@ -226,9 +226,7 @@ def __call__( # Print progress message. if verbose: if detect_convergence: - print( - f"StdDev Ratio = {ratio:.4f} " f"(Converge at {thresh:.4f})" - ) + print(f"StdDev Ratio = {ratio:.4f} (Converge at {thresh:.4f})") else: print(f"StdDev Ratio = {ratio:.4f}") diff --git a/sage/permutation_estimator.py b/sage/permutation_estimator.py index 58557f6..bb8660e 100644 --- a/sage/permutation_estimator.py +++ b/sage/permutation_estimator.py @@ -133,7 +133,7 @@ def __call__( # Print progress message. if verbose: if detect_convergence: - print(f"StdDev Ratio = {ratio:.4f} " f"(Converge at {thresh:.4f})") + print(f"StdDev Ratio = {ratio:.4f} (Converge at {thresh:.4f})") else: print(f"StdDev Ratio = {ratio:.4f}") diff --git a/sage/utils.py b/sage/utils.py index 0acddf6..4e340b2 100644 --- a/sage/utils.py +++ b/sage/utils.py @@ -118,7 +118,7 @@ def verify_model_data(imputer, X, Y, loss, batch_size): Y[Y == -1] = 0 else: raise ValueError( - "labels for binary classification must be " "[0, 1] or [-1, 1]" + "labels for binary classification must be [0, 1] or [-1, 1]" ) # Check for valid probability outputs. @@ -208,7 +208,7 @@ def __call__(self, pred, target): target = np.expand_dims(target, -1) elif not target.shape == pred.shape: raise ValueError( - "shape mismatch, pred has shape {} and target " "has shape {}".format( + "shape mismatch, pred has shape {} and target has shape {}".format( pred.shape, target.shape ) )