Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

XGBoost autologging: support per-class importance plots #4523

Merged
merged 6 commits into from
Jul 6, 2021

Conversation

dbczumar
Copy link
Collaborator

@dbczumar dbczumar commented Jul 1, 2021

What changes are proposed in this pull request?

XGBoost 1.15.0-dev introduced support for importance computation on linear estimators. These estimators return importance values for each (feature, class) pair as a num_features-by-num_classes matrix. This PR introduces extends feature importance plotting support in XGBoost autologging to handle this new importance value format.

How is this patch tested?

  • Unit tests
  • Manual tests on XGBoost 1.15.0-dev:
  1. Linear booster training on Iris
from sklearn.datasets import load_iris

iris = load_iris()

import mlflow
mlflow.xgboost.autolog()


import xgboost as xgb
dtrain = xgb.DMatrix(iris.data, label=iris.target)

bst_params = {"objective": "multi:softprob", "num_class": 3, "booster": "gblinear"}
model = xgb.train(bst_params, dtrain)

feature_importance_weight

  1. Tree booster training on Iris
from sklearn.datasets import load_iris

iris = load_iris()

import mlflow
mlflow.xgboost.autolog()


import xgboost as xgb
dtrain = xgb.DMatrix(iris.data, label=iris.target)

bst_params = {
        "objective": "multi:softprob",
        "num_class": 10,
        "eval_metric": "mlogloss",
        "booster": "gbtree",
    }
model = xgb.train(bst_params, dtrain)

feature_importance_weight

  1. Linear booster training on MNIST
from sklearn.datasets import load_digits

digits = load_digits()

import mlflow
mlflow.xgboost.autolog()


import xgboost as xgb
dtrain = xgb.DMatrix(digits.data, label=digits.target)

bst_params = {"objective": "multi:softprob", "num_class": 10, "booster": "gblinear"}
model = xgb.train(bst_params, dtrain)

feature_importance_weight

  1. Tree booster training on MNIST
from sklearn.datasets import load_digits

digits = load_digits()

import mlflow
mlflow.xgboost.autolog()


import xgboost as xgb
dtrain = xgb.DMatrix(digits.data, label=digits.target)

bst_params = {
        "objective": "multi:softprob",
        "num_class": 10,
        "eval_metric": "mlogloss",
        "booster": "gbtree",
    }
model = xgb.train(bst_params, dtrain)

feature_importance_weight

Release Notes

Add XGBoost autologging support for multi-class feature importance plots

Is this a user-facing change?

  • No. You can skip the rest of this section.
  • Yes. Give a description of this change to be included in the release notes for MLflow users.

What component(s), interfaces, languages, and integrations does this PR affect?

Components

  • area/artifacts: Artifact stores and artifact logging
  • area/build: Build and test infrastructure for MLflow
  • area/docs: MLflow documentation pages
  • area/examples: Example code
  • area/model-registry: Model Registry service, APIs, and the fluent client calls for Model Registry
  • area/models: MLmodel format, model serialization/deserialization, flavors
  • area/projects: MLproject format, project running backends
  • area/scoring: Local serving, model deployment tools, spark UDFs
  • area/server-infra: MLflow server, JavaScript dev server
  • area/tracking: Tracking Service, tracking client APIs, autologging

Interface

  • area/uiux: Front-end, user experience, JavaScript, plotting
  • area/docker: Docker use across MLflow's components, such as MLflow Projects and MLflow Models
  • area/sqlalchemy: Use of SQLAlchemy in the Tracking Service or Model Registry
  • area/windows: Windows support

Language

  • language/r: R APIs and clients
  • language/java: Java APIs and clients
  • language/new: Proposals for new client languages

Integrations

  • integrations/azure: Azure and Azure ML integrations
  • integrations/sagemaker: SageMaker integrations
  • integrations/databricks: Databricks integrations

How should the PR be classified in the release notes? Choose one:

  • rn/breaking-change - The PR will be mentioned in the "Breaking Changes" section
  • rn/none - No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" section
  • rn/feature - A new user-facing feature worth mentioning in the release notes
  • rn/bug-fix - A user-facing bug fix worth mentioning in the release notes
  • rn/documentation - A user-facing documentation change worth mentioning in the release notes

@github-actions github-actions bot added area/tracking Tracking service, tracking client APIs, autologging rn/feature Mention under Features in Changelogs. labels Jul 1, 2021
dbczumar added 5 commits July 1, 2021 15:32
Signed-off-by: dbczumar <corey.zumar@databricks.com>
Signed-off-by: dbczumar <corey.zumar@databricks.com>
Signed-off-by: dbczumar <corey.zumar@databricks.com>
Signed-off-by: dbczumar <corey.zumar@databricks.com>
Signed-off-by: dbczumar <corey.zumar@databricks.com>
@dbczumar dbczumar force-pushed the linear_importance_plot branch from 0588f67 to f3ce845 Compare July 1, 2021 22:32
@dbczumar dbczumar requested a review from harupy July 1, 2021 22:32
Comment on lines 450 to 452
importances_per_class_by_feature = np.array(
[[importance] for importance in importances_per_class_by_feature]
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
importances_per_class_by_feature = np.array(
[[importance] for importance in importances_per_class_by_feature]
)
importances_per_class_by_feature = np.array(
[[importance] for importance in importances_per_class_by_feature[indices]]
)

Can we sort importance as well?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for catching this! Done! Here's a screenshot from tree booster on MNIST:

feature_importance_weight

feature_yloc + offset,
class_importance,
align="center",
height=(0.5 / num_classes),
Copy link
Member

@harupy harupy Jul 2, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
height=(0.5 / num_classes),
height=(0.5 / max(num_classes- 1, 1)),
# alternative approaches
height=(0.5 / (num_classes- 1 if num_clalles > 1 else 1)),
height=(0.5 / (num_clalles - 1 or 1)),

Can we divide by num_classes - 1 to remove the gap between bars?

feature_importance_weight

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great suggestion! Done!

for class_idx, (offset, class_importance) in enumerate(
zip(offsets_per_yloc, importances_per_class)
):
(bar,) = ax.barh(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice unpacking :)

Comment on lines 457 to 458
else:
label_classes_on_plot = True
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we sort a 2D importance matrix (that linear boosters generates) as well?

import numpy as np

features = np.array(["a", "b", "c"])
importance = [
    # class0, class1, class2
    [7, 8, 9],  # a
    [4, 5, 6],  # b
    [1, 2, 3],  # c
]
importances_per_class_by_feature = np.array(importance)
abs_sum = np.abs(importances_per_class_by_feature).sum(axis=1)
# or abs_mean = np.abs(importances_per_class_by_feature).mean(axis=1)
indices = np.argsort(abs_sum)

print(importances_per_class_by_feature[indices])
# [[1 2 3]
#  [4 5 6]
#  [7 8 9]]

print(features[indices])
# ['c' 'b' 'a']

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Absolutely! Done! (Chose sum() for the magnitude metric rather than mean()). Here's a screenshot from a linear booster on MNIST:

feature_importance_weight

Signed-off-by: dbczumar <corey.zumar@databricks.com>
Copy link
Collaborator Author

@dbczumar dbczumar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@harupy Thanks for the awesome review feedback! I've addressed your comments. Can you take another look?

feature_yloc + offset,
class_importance,
align="center",
height=(0.5 / num_classes),
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great suggestion! Done!

Comment on lines 450 to 452
importances_per_class_by_feature = np.array(
[[importance] for importance in importances_per_class_by_feature]
)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for catching this! Done! Here's a screenshot from tree booster on MNIST:

feature_importance_weight

Comment on lines 457 to 458
else:
label_classes_on_plot = True
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Absolutely! Done! (Chose sum() for the magnitude metric rather than mean()). Here's a screenshot from a linear booster on MNIST:

feature_importance_weight

Copy link
Member

@harupy harupy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@dbczumar dbczumar merged commit e0e7181 into mlflow:master Jul 6, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
area/tracking Tracking service, tracking client APIs, autologging rn/feature Mention under Features in Changelogs.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants