Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: generate valid EBMModel when merging #578

Draft
wants to merge 18 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
BUG: correctly clean up exclude attribute
  • Loading branch information
Weh Andreas committed Jan 4, 2025
commit cda91a70dca45509903ac5d62c970978b7f72e20
15 changes: 9 additions & 6 deletions python/interpret-core/interpret/glassbox/_ebm/_merge_ebms.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
EBMModel,
ExplainableBoostingClassifier,
ExplainableBoostingRegressor,
_clean_exclude,
is_private,
)
from ._utils import (
Expand Down Expand Up @@ -450,14 +451,16 @@ def _initialize_ebm(models: List[EBMModel], ebm_type=EBMModel) -> EBMModel:
# none of the models contains all feature_idxs
# merged EBM should exclude features included by none of the models
# -> overlap of all features
# following algorithm works only if all models use the same feature names
# interactions are probably not handled correctly
excluded = {(feat,) for feat in models[0].feature_names_in_}
clean_excludes = []
for model in models:
if model.exclude == "mains":
excluded &= set(models[0].feature_names_in_)
clean_excludes.append({(idx,) for idx in range(model.n_features_in_)})
continue
excluded &= set(model.exclude)
feature_map = {
name: idx for idx, name in enumerate(model.feature_names_in_)
}
clean_excludes.append(_clean_exclude(model.exclude, feature_map))
excluded = set.intersection(*clean_excludes)
manual_kdws["exclude"] = list(excluded) if excluded else None

# handle `interactions`
Expand Down Expand Up @@ -498,7 +501,6 @@ def monotone(args) -> int:
for item in zip(*(model.monotone_constraints for model in models))
]

# TODO: treat special cases: exclude, interactions
for key in kdws:
values = np.array([getattr(model, key, np.nan) for model in models])
nan_weight = np.copy(weights)
Expand Down Expand Up @@ -544,6 +546,7 @@ def merge_ebms(models):
if any(not getattr(model, "has_fitted_", False) for model in models):
msg = "All models must be fitted."
raise Exception(msg)

ebm = _initialize_ebm(models, ebm_type=ebm_type)
ebm.has_fitted_ = True

Expand Down
51 changes: 51 additions & 0 deletions python/interpret-core/tests/glassbox/ebm/test_merge_ebms.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,3 +241,54 @@ def test_unfitted():
# ebm2 is not fitted
with pytest.raises(Exception, match="All models must be fitted."):
merge_ebms([ebm1, ebm2])


@pytest.mark.skip(reason="Extremely slow")
def test_merge_monotone():
    """Verify how `merge_ebms` combines `monotone_constraints`.

    Two classifiers with partially conflicting constraints are merged first;
    afterwards a third, unconstrained classifier is added to the merge.
    """
    features, target, feature_names, _ = make_synthetic(
        classes=2, missing=True, output_type="str"
    )
    make_classifier = partial(
        ExplainableBoostingClassifier,
        feature_names=feature_names,
        random_state=42,
        **_fast_kwds,
    )

    def fitted(constraints):
        # Build a classifier with the given constraints and fit it on the
        # shared synthetic data set.
        model = make_classifier(monotone_constraints=constraints)
        model.fit(features, target)
        return model

    first = fitted([0, 0, 0, 1, 1, -1, 0, 0, 0, 0])
    second = fitted([0, 1, -1, 1, -1, -1, 0, 0, 0, 0])

    # Only positions where both models carry the same constraint are
    # expected to keep it; every other position is expected to be 0.
    merged = merge_ebms([first, second])
    assert merged.monotone_constraints == [0, 0, 0, 1, 0, -1, 0, 0, 0, 0]

    # Adding a model without any constraints is expected to drop the
    # constraints of the merged model altogether.
    unconstrained = fitted(None)
    merged = merge_ebms([first, second, unconstrained])
    assert merged.monotone_constraints is None


def test_merge_exclude():
    """Check merging of features with `exclude`.

    The merged model is expected to exclude only the terms that every
    input model excludes. Three scenarios are covered:

    * one model excludes nothing -> merged `exclude` is ``None``;
    * both models exclude explicit index lists -> their overlap remains;
    * one model uses ``exclude="mains"`` -> the other model's list remains.

    The merged `exclude` list is created from a set, so its element order
    is not defined. The assertions therefore compare sorted lists instead
    of relying on a particular ordering.
    """
    X, y, names, _ = make_synthetic(classes=2, missing=True, output_type="str")
    TestEBM = partial(
        ExplainableBoostingClassifier,
        feature_names=names,
        random_state=42,
        **_fast_kwds,
    )

    # A model without exclusions shares no excluded term with the other one.
    ebm1 = TestEBM(exclude=None)
    ebm1.fit(X, y)
    ebm2 = TestEBM(exclude=[0, 1, 2])
    ebm2.fit(X, y)
    merged_ebm = merge_ebms([ebm1, ebm2])
    assert merged_ebm.exclude is None

    # Only the features excluded by both models stay excluded.
    ebm1 = TestEBM(exclude=[0, 2])
    ebm1.fit(X, y)
    ebm2 = TestEBM(exclude=[0, 1, 2])
    ebm2.fit(X, y)
    merged_ebm = merge_ebms([ebm1, ebm2])
    assert sorted(merged_ebm.exclude) == [(0,), (2,)]

    # "mains" excludes every main term, so the explicit list of the second
    # model determines the overlap.
    ebm1 = TestEBM(exclude="mains")
    ebm1.fit(X, y)
    ebm2 = TestEBM(exclude=[0, 1, 2])
    ebm2.fit(X, y)
    merged_ebm = merge_ebms([ebm1, ebm2])
    assert sorted(merged_ebm.exclude) == [(0,), (1,), (2,)]
Loading