BUG: generate valid EBMModel when merging #578

Draft · wants to merge 7 commits into develop

Changes from 1 commit
MAINT: split merge_ebms in more functions
Signed-off-by: DerWeh <andreas.weh@web.de>
Weh Andreas authored and DerWeh committed Oct 19, 2024
commit 90e20d0a8b035c3e5e1a5dcc6b6422ec0ce1e4b3
python/interpret-core/interpret/glassbox/_ebm/_merge_ebms.py (245 changes: 126 additions & 119 deletions)
@@ -5,10 +5,12 @@
 import warnings
 from itertools import chain, count
 from math import isnan
+from typing import List

 import numpy as np

 from ...utils._native import Native
+from ._ebm import EBMModel
 from ._utils import (
     convert_categorical_to_continuous,
     deduplicate_bins,
@@ -282,154 +284,49 @@ def _harmonize_tensor(
     return new_tensor.reshape(new_shape)


-def merge_ebms(models):
-    """Merge EBM models trained on similar datasets that have the same set of features.
-
-    Args:
-        models: List of EBM models to be merged.
-
-    Returns:
-        An EBM model with averaged mean and standard deviation of input models.
-    """
-    if len(models) == 0:  # pragma: no cover
-        msg = "0 models to merge."
-        raise Exception(msg)
-
-    model_types = list(set(map(type, models)))
-    if len(model_types) == 2:
-        type_names = [model_type.__name__ for model_type in model_types]
-        if (
-            "ExplainableBoostingClassifier" in type_names
-            and "DPExplainableBoostingClassifier" in type_names
-        ):
-            ebm_type = model_types[type_names.index("ExplainableBoostingClassifier")]
-            is_classification = True
-            is_dp = False
-        elif (
-            "ExplainableBoostingRegressor" in type_names
-            and "DPExplainableBoostingRegressor" in type_names
-        ):
-            ebm_type = model_types[type_names.index("ExplainableBoostingRegressor")]
-            is_classification = False
-            is_dp = False
-        else:
-            msg = "Inconsistent model types attempting to be merged."
-            raise Exception(msg)
-    elif len(model_types) == 1:
-        ebm_type = model_types[0]
-        if ebm_type.__name__ == "ExplainableBoostingClassifier":
-            is_classification = True
-            is_dp = False
-        elif ebm_type.__name__ == "DPExplainableBoostingClassifier":
-            is_classification = True
-            is_dp = True
-        elif ebm_type.__name__ == "ExplainableBoostingRegressor":
-            is_classification = False
-            is_dp = False
-        elif ebm_type.__name__ == "DPExplainableBoostingRegressor":
-            is_classification = False
-            is_dp = True
-        else:
-            msg = f"Invalid EBM model type {ebm_type.__name__} attempting to be merged."
-            raise Exception(msg)
-    else:
-        msg = "Inconsistent model types being merged."
-        raise Exception(msg)
-
-    # TODO: create the ExplainableBoostingClassifier etc, type directly
-    # by name instead of using __new__ from ebm_type
-    ebm = ebm_type.__new__(ebm_type)
-
-    if any(
-        not getattr(model, "has_fitted_", False) for model in models
-    ):  # pragma: no cover
-        msg = "All models must be fitted."
-        raise Exception(msg)
-    ebm.has_fitted_ = True
-
-    link = models[0].link_
-    if any(model.link_ != link for model in models):
-        msg = "Models with different link functions cannot be merged"
-        raise Exception(msg)
-    ebm.link_ = link
-
-    link_param = models[0].link_param_
-    if isnan(link_param):
-        if not all(isnan(model.link_param_) for model in models):
-            msg = "Models with different link param values cannot be merged"
-            raise Exception(msg)
-    elif any(model.link_param_ != link_param for model in models):
-        msg = "Models with different link param values cannot be merged"
-        raise Exception(msg)
-    ebm.link_param_ = link_param
-
+def _assert_model_compatibility(models: List[EBMModel]) -> None:
+    """Check if models can be merged, raise error if not."""
     # self.bins_ is the only feature based attribute that we absolutely require
     n_features = len(models[0].bins_)

     for model in models:
         if n_features != len(model.bins_):  # pragma: no cover
             msg = "Inconsistent numbers of features in the models."
             raise Exception(msg)

-        feature_names_in = getattr(model, "feature_names_in_", None)
-        if feature_names_in is not None and n_features != len(
-            feature_names_in
+        if hasattr(model, "feature_names_in_") and n_features != len(
+            model.feature_names_in_
         ):  # pragma: no cover
             msg = "Inconsistent numbers of features in the models."
             raise Exception(msg)

-        feature_types_in = getattr(model, "feature_types_in_", None)
-        if feature_types_in is not None and n_features != len(
-            feature_types_in
+        if hasattr(model, "feature_types_in_") and n_features != len(
+            model.feature_types_in_
         ):  # pragma: no cover
             msg = "Inconsistent numbers of features in the models."
             raise Exception(msg)

-        feature_bounds = getattr(model, "feature_bounds_", None)
         if (
-            feature_bounds is not None and n_features != feature_bounds.shape[0]
+            hasattr(model, "feature_bounds_")
+            and n_features != model.feature_bounds_.shape[0]
         ):  # pragma: no cover
             msg = "Inconsistent numbers of features in the models."
             raise Exception(msg)

-        histogram_weights = getattr(model, "histogram_weights_", None)
-        if histogram_weights is not None and n_features != len(
-            histogram_weights
+        if hasattr(model, "histogram_weights_") and n_features != len(
+            model.histogram_weights_
         ):  # pragma: no cover
             msg = "Inconsistent numbers of features in the models."
             raise Exception(msg)

-        unique_val_counts = getattr(model, "unique_val_counts_", None)
-        if unique_val_counts is not None and n_features != len(
-            unique_val_counts
+        if hasattr(model, "unique_val_counts_") and n_features != len(
+            model.unique_val_counts_
         ):  # pragma: no cover
             msg = "Inconsistent numbers of features in the models."
             raise Exception(msg)

-    old_bounds = []
-    old_mapping = []
-    old_bins = []
-    for model in models:
-        if any(len(set(map(type, bin_levels))) != 1 for bin_levels in model.bins_):
-            msg = "Inconsistent bin types within a model."
-            raise Exception(msg)
-
-        feature_bounds = getattr(model, "feature_bounds_", None)
-        if feature_bounds is None:
-            old_bounds.append(None)
-        else:
-            old_bounds.append(feature_bounds.copy())
-
-        old_mapping.append([[] for _ in range(n_features)])
-        old_bins.append([[] for _ in range(n_features)])
-
-    # TODO: every time we merge models we fragment the bins more and more and this is undesirable
-    # especially for pairs. When we build models, we store the feature bin cuts for pairs even
-    # if we have no pairs that use that paritcular feature as a pair. We can eliminate these useless
-    # pair feature cuts before merging the bins and that'll give us less resulting cuts. Having less
-    # cuts reduces the number of estimates that we need to make and reduces the complexity of the
-    # tensors, so it's good to have this reduction.
-
+def _get_new_bins(models: List[EBMModel], *, old_mapping, old_bins, old_bounds):
+    n_features = len(models[0].bins_)
     new_feature_types = []
     new_bins = []
     for feature_idx in range(n_features):
@@ -519,6 +416,116 @@ def merge_ebms(models):
             )
             new_leveled_bins.append(merged_bins)
         new_bins.append(new_leveled_bins)
+    return new_bins, new_feature_types
+
+
+def merge_ebms(models):
+    """Merge EBM models trained on similar datasets that have the same set of features.
+
+    Args:
+        models: List of EBM models to be merged.
+
+    Returns:
+        An EBM model with averaged mean and standard deviation of input models.
+    """
+    if len(models) == 0:  # pragma: no cover
+        msg = "0 models to merge."
+        raise Exception(msg)
+
+    model_types = list(set(map(type, models)))
+    if len(model_types) == 2:
+        type_names = [model_type.__name__ for model_type in model_types]
+        if (
+            "ExplainableBoostingClassifier" in type_names
+            and "DPExplainableBoostingClassifier" in type_names
+        ):
+            ebm_type = model_types[type_names.index("ExplainableBoostingClassifier")]
+            is_classification = True
+            is_dp = False
+        elif (
+            "ExplainableBoostingRegressor" in type_names
+            and "DPExplainableBoostingRegressor" in type_names
+        ):
+            ebm_type = model_types[type_names.index("ExplainableBoostingRegressor")]
+            is_classification = False
+            is_dp = False
+        else:
+            msg = "Inconsistent model types attempting to be merged."
+            raise Exception(msg)
+    elif len(model_types) == 1:
+        ebm_type = model_types[0]
+        if ebm_type.__name__ == "ExplainableBoostingClassifier":
+            is_classification = True
+            is_dp = False
+        elif ebm_type.__name__ == "DPExplainableBoostingClassifier":
+            is_classification = True
+            is_dp = True
+        elif ebm_type.__name__ == "ExplainableBoostingRegressor":
+            is_classification = False
+            is_dp = False
+        elif ebm_type.__name__ == "DPExplainableBoostingRegressor":
+            is_classification = False
+            is_dp = True
+        else:
+            msg = f"Invalid EBM model type {ebm_type.__name__} attempting to be merged."
+            raise Exception(msg)
+    else:
+        msg = "Inconsistent model types being merged."
+        raise Exception(msg)
+
+    # TODO: create the ExplainableBoostingClassifier etc, type directly
+    # by name instead of using __new__ from ebm_type
+    ebm = ebm_type.__new__(ebm_type)
+
+    if any(
+        not getattr(model, "has_fitted_", False) for model in models
+    ):  # pragma: no cover
+        msg = "All models must be fitted."
+        raise Exception(msg)
+    ebm.has_fitted_ = True
+
+    link = models[0].link_
+    if any(model.link_ != link for model in models):
+        msg = "Models with different link functions cannot be merged"
+        raise Exception(msg)
+    ebm.link_ = link
+
+    link_param = models[0].link_param_
+    if isnan(link_param):
+        if not all(isnan(model.link_param_) for model in models):
+            msg = "Models with different link param values cannot be merged"
+            raise Exception(msg)
+    elif any(model.link_param_ != link_param for model in models):
+        msg = "Models with different link param values cannot be merged"
+        raise Exception(msg)
+    ebm.link_param_ = link_param
+
+    # self.bins_ is the only feature based attribute that we absolutely require
+    n_features = len(models[0].bins_)
+
+    _assert_model_compatibility(models)
+
+    old_mapping = [[[] for _ in range(n_features)] for _ in models]
+    old_bins = [[[] for _ in range(n_features)] for _ in models]
+    old_bounds = []
+    for model in models:
+        if any(len(set(map(type, bin_levels))) != 1 for bin_levels in model.bins_):
+            msg = "Inconsistent bin types within a model."
+            raise Exception(msg)
+
+        feature_bounds = getattr(model, "feature_bounds_", None)
+        old_bounds.append(None if feature_bounds is None else feature_bounds.copy())
+
+    # TODO: every time we merge models we fragment the bins more and more and this is undesirable
+    # especially for pairs. When we build models, we store the feature bin cuts for pairs even
+    # if we have no pairs that use that particular feature as a pair. We can eliminate these useless
+    # pair feature cuts before merging the bins and that'll give us less resulting cuts. Having less
+    # cuts reduces the number of estimates that we need to make and reduces the complexity of the
+    # tensors, so it's good to have this reduction.
+
+    new_bins, new_feature_types = _get_new_bins(
+        models, old_mapping=old_mapping, old_bins=old_bins, old_bounds=old_bounds
+    )
+    ebm.feature_types_in_ = new_feature_types
+    deduplicate_bins(new_bins)
+    ebm.bins_ = new_bins
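For context, a minimal sketch of how the code refactored in this commit is typically exercised, assuming the public merge_ebms entry point exported from interpret.glassbox; the dataset, random seeds, and variable names are illustrative only, not taken from this PR:

# Minimal usage sketch (assumes interpret and scikit-learn are installed).
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

from interpret.glassbox import ExplainableBoostingClassifier, merge_ebms

X, y = load_breast_cancer(return_X_y=True)
# Simulate two parties that trained on different halves of the same schema.
X1, X2, y1, y2 = train_test_split(X, y, test_size=0.5, random_state=0)

ebm1 = ExplainableBoostingClassifier(random_state=1).fit(X1, y1)
ebm2 = ExplainableBoostingClassifier(random_state=2).fit(X2, y2)

# Both models are fitted and share model type, feature count, link function,
# and link parameter, so the checks this commit factors out into
# _assert_model_compatibility pass and a merged model is returned.
merged = merge_ebms([ebm1, ebm2])
print(merged.predict(X2[:5]))

The merged estimator behaves like any other fitted EBM; the point of this commit is only to split the validation and bin-merging steps of merge_ebms into the _assert_model_compatibility and _get_new_bins helpers without changing that behavior.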