Commit eb90f6c
chore: update security and explainability
crismunoz committed Jul 23, 2024
1 parent 53865d6 commit eb90f6c
Showing 74 changed files with 5,599 additions and 2,241 deletions.
@@ -6,6 +6,7 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
+import pandas as pd
 from holisticai.bias.mitigation.inprocessing.adversarial_debiasing.models import (
     ADModel,
     AdversarialModel,
@@ -20,6 +21,10 @@
 logger = logging.getLogger(__name__)


+def is_numeric(df):
+    return all(pd.api.types.is_numeric_dtype(df[col]) for col in df.columns)
+
+
 class AdversarialDebiasing(BMImp):
     """Adversarial Debiasing
@@ -176,6 +181,9 @@ def fit(

         params = self._load_data(X=X, y=y, group_a=group_a, group_b=group_b)
         x = pd.DataFrame(params["X"])
+        if not is_numeric(x):
+            raise ValueError("Adversarial Debiasing only works with numeric features.")
+
         y = pd.Series(params["y"])
         group_a = pd.Series(params["group_a"])
         group_b = pd.Series(params["group_b"])
@@ -239,6 +247,8 @@ def predict(self, X):
             np.ndarray: Predicted output per sample.
         """
+        if not is_numeric(X):
+            raise ValueError("Adversarial Debiasing only works with numeric features.")
         p = self.predict_proba(X)
         return np.argmax(p, axis=1).ravel()

@@ -260,6 +270,9 @@ def predict_proba(self, X):
             np.ndarray: Predicted matrix probability per sample.
         """
+        if not is_numeric(X):
+            raise ValueError("Adversarial Debiasing only works with numeric features.")
+
         proba = np.empty((X.shape[0], 2))
         proba[:, 1] = self._predict_proba(X)
         proba[:, 0] = 1.0 - proba[:, 1]
@@ -283,5 +296,8 @@ def predict_score(self, X):
             np.ndarray: Predicted probability per sample.
         """
+        if not is_numeric(X):
+            raise ValueError("Adversarial Debiasing only works with numeric features.")
+
         p = self._predict(X).reshape([-1])
         return p
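For reference, the new is_numeric guard used by fit, predict, predict_proba, and predict_score only inspects column dtypes; a minimal standalone sketch of its behavior (the DataFrames are made up):

import pandas as pd

def is_numeric(df):
    return all(pd.api.types.is_numeric_dtype(df[col]) for col in df.columns)

numeric_X = pd.DataFrame({"age": [23, 41], "income": [52_000.0, 61_500.0]})
mixed_X = pd.DataFrame({"age": [23, 41], "city": ["Lima", "Quito"]})

print(is_numeric(numeric_X))  # True: every column has a numeric dtype
print(is_numeric(mixed_X))    # False: "city" is object-typed, so fit/predict would raise ValueError

Note that the guard rejects the whole frame if any single column is non-numeric, so categorical features need encoding before they reach the mitigator.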
@@ -1,3 +1,4 @@
+import numpy as np
 from holisticai.bias.mitigation.inprocessing.commons import Logging


@@ -30,7 +31,7 @@ def loss(self, coef_, X, y, groups):
         loss : float
             loss function value
         """
-        coef = coef_.reshape(self.estimator.nb_group_values, self.estimator.nb_features)
+        coef = coef_.reshape(self.estimator.nb_group_values, self.estimator.nb_features).astype(np.float64)
         X = self.estimator.preprocessing_data(X)
         sigma = self.estimator.sigmoid(X=X, groups=groups, coef=coef)
         loss = self.loss_fn(y=y, sigma=sigma, groups=groups, coef=coef)
@@ -53,7 +54,7 @@ def grad_loss(self, coef_, X, y, groups):
             first derivative of loss function
         """

-        coef = coef_.reshape(self.estimator.nb_group_values, self.estimator.nb_features)
+        coef = coef_.reshape(self.estimator.nb_group_values, self.estimator.nb_features).astype(np.float64)
         X = self.estimator.preprocessing_data(X)
         sigma = self.estimator.sigmoid(X=X, groups=groups, coef=coef)
         return self.loss_fn.gradient(X=X, y=y, sigma=sigma, groups=groups, coef=coef)
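The .astype(np.float64) added in both loss and grad_loss pins the reshaped coefficient matrix to double precision, presumably so downstream math does not inherit a lower-precision dtype from the optimizer's flattened vector. A tiny illustration (the shapes and the float32 input are invented):

import numpy as np

coef_ = np.zeros(6, dtype=np.float32)  # flattened vector as an optimizer might hand it back

nb_group_values, nb_features = 2, 3    # illustrative dimensions
coef = coef_.reshape(nb_group_values, nb_features).astype(np.float64)

print(coef.shape, coef.dtype)  # (2, 3) float64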
13 changes: 7 additions & 6 deletions src/holisticai/datasets/_dataset.py
@@ -1,6 +1,6 @@
 from __future__ import annotations

-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Union

 import pandas as pd
 from sklearn.model_selection import train_test_split
@@ -11,6 +11,7 @@
 from collections.abc import Iterable

 import numpy as np
+from numpy.random import RandomState


 class DatasetDict(dict):
@@ -170,7 +171,7 @@ def __init__(self, dataset, batch_size, dtype):
     def batched(self):
         def batch_generator(batch_size):
             for i in range(self.num_batches):
-                batch = Dataset(data=self.dataset.data.iloc[i * batch_size : (i + 1) * batch_size])
+                batch = Dataset(self.dataset.data.iloc[i * batch_size : (i + 1) * batch_size])
                 yield batch

         if self.dtype == "jax":
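This call site likely had to go positional because of the parameter rename in the __init__ hunk below: with the signature __init__(self, _data=None, **kargs), a data= keyword no longer binds the DataFrame slot but is swallowed by **kargs. A minimal sketch of that binding behavior:

def init(_data=None, **kargs):
    return _data, kargs

print(init(data="df"))  # (None, {'data': 'df'}) -- keyword lands in **kargs, _data stays None
print(init("df"))       # ('df', {})             -- positional argument binds _data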
@@ -242,8 +243,8 @@ def __update_metadata(self):
         features_counts = features_values.value_counts()
         self.features_is_series = {key: (value == 1) for key, value in features_counts.items()}

-    def __init__(self, data: pd.DataFrame | None = None, **kargs):
-        if data is None:
+    def __init__(self, _data: pd.DataFrame | None = None, **kargs):
+        if _data is None:
             self.data = {}
             for name, value in kargs.items():
                 if isinstance(value, pd.DataFrame):
@@ -257,7 +258,7 @@ def __init__(self, data: pd.DataFrame | None = None, **kargs):
             self.data.columns = self.data.columns.set_names(["features", "subfeatures"])
             self.data.reset_index(drop=True)
         else:
-            self.data = data.reset_index(drop=True)
+            self.data = _data.reset_index(drop=True)
         self.__update_metadata()
         self.random_state = np.random.RandomState()

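The data -> _data rename itself appears to free up "data" (and any other name) as a keyword component while keeping a positional slot for a ready-made DataFrame; a standalone toy version of the pattern, not the real Dataset class:

from __future__ import annotations

import pandas as pd

class Toy:
    def __init__(self, _data: pd.DataFrame | None = None, **kargs):
        if _data is None:
            # Assemble the frame from named components, e.g. Toy(X=df, y=series).
            self.data = pd.concat({k: pd.DataFrame(v) for k, v in kargs.items()}, axis=1)
        else:
            self.data = _data.reset_index(drop=True)

whole = Toy(pd.DataFrame({"a": [1, 2]}))                         # positional DataFrame
parts = Toy(X=pd.DataFrame({"a": [1, 2]}), y=pd.Series([0, 1]))  # keyword components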
@@ -419,7 +420,7 @@ def apply_fn_to_multilevel_df(df, fn):
     return result_df


-def sample_n(group, n, random_state=None):
+def sample_n(group: pd.DataFrame, n: int, random_state: Union[RandomState, None] = None) -> pd.DataFrame:
     if len(group) < n:
         return group
     return group.sample(n=n, replace=False, random_state=random_state)
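For context, the newly annotated sample_n caps each group at n rows and returns undersized groups untouched; a typical groupby use (the toy frame is made up):

import pandas as pd

def sample_n(group, n, random_state=None):
    if len(group) < n:
        return group
    return group.sample(n=n, replace=False, random_state=random_state)

df = pd.DataFrame({"group": ["a"] * 5 + ["b"] * 2, "value": range(7)})

# At most 3 rows per group; "b" has only 2 rows, so it comes back whole.
capped = df.groupby("group", group_keys=False).apply(lambda g: sample_n(g, n=3, random_state=42))
print(capped["group"].value_counts())  # a: 3, b: 2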