Demo/PoC only - xgboost/generalized caching support for model_selection #606

Open · wants to merge 1 commit into main
11 changes: 10 additions & 1 deletion dask_ml/model_selection/_search.py
@@ -131,6 +131,7 @@ def build_graph(
     return_train_score=_RETURN_TRAIN_SCORE_DEFAULT,
     cache_cv=True,
     multimetric=False,
+    extract_fn=None  # XXX need to update callers?
 ):
     # This is provided for compatibility with TPOT. Remove
     # once TPOT is updated and requires a dask-ml>=0.13.0
@@ -151,6 +152,7 @@ def decompress_params(fields, params):
         error_score=error_score,
         return_train_score=return_train_score,
         cache_cv=cache_cv,
+        extract_fn=extract_fn,
     )
     cv_name = "cv-split-" + main_token
     if iid:
@@ -193,6 +195,7 @@ def build_cv_graph(
     error_score="raise",
     return_train_score=_RETURN_TRAIN_SCORE_DEFAULT,
     cache_cv=True,
+    extract_fn=None,
 ):
     X, y, groups = to_indexable(X, y, groups)
     cv = check_cv(cv, y, is_classifier(estimator))
@@ -220,7 +223,8 @@
     )

     cv_name = "cv-split-" + main_token
-    dsk[cv_name] = (cv_split, cv, X_name, y_name, groups_name, is_pairwise, cache_cv)
+    dsk[cv_name] = (cv_split, cv, X_name, y_name, groups_name, is_pairwise,
+                    cache_cv, extract_fn)

     if iid:
         weights = "cv-n-samples-" + main_token
@@ -1066,6 +1070,7 @@ def __init__(
         scheduler=None,
         n_jobs=-1,
         cache_cv=True,
+        extract_fn=None,
     ):
         self.scoring = scoring
         self.estimator = estimator
@@ -1077,6 +1082,7 @@
         self.scheduler = scheduler
         self.n_jobs = n_jobs
         self.cache_cv = cache_cv
+        self.extract_fn = extract_fn

     def _check_if_refit(self, attr):
         if not self.refit:
@@ -1231,6 +1237,7 @@ def fit(self, X, y=None, groups=None, **fit_params):
             error_score=error_score,
             return_train_score=self.return_train_score,
             cache_cv=self.cache_cv,
+            extract_fn=self.extract_fn
         )

         n_jobs = _normalize_n_jobs(self.n_jobs)
@@ -1564,6 +1571,7 @@ def __init__(
         scheduler=None,
         n_jobs=-1,
         cache_cv=True,
+        extract_fn=None
     ):
         super(GridSearchCV, self).__init__(
             estimator=estimator,
@@ -1576,6 +1584,7 @@
             scheduler=scheduler,
             n_jobs=n_jobs,
             cache_cv=cache_cv,
+            extract_fn=extract_fn
         )

         _check_param_grid(param_grid)
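Taken together, the _search.py changes thread a single new keyword, extract_fn, from the search estimator's constructor through build_graph and build_cv_graph into the cv_split task. A minimal usage sketch, not part of this diff: it assumes the XGBoostWrapper and extract_dmatrix helpers from the new wrappers.py shown further down, plus xgboost and scikit-learn installed, and is illustrative rather than a tested recipe.

import numpy as np
from sklearn.metrics import accuracy_score

from dask_ml.model_selection import GridSearchCV
from dask_ml.model_selection.wrappers import XGBoostWrapper, extract_dmatrix

X = np.random.RandomState(0).rand(100, 4)
y = (X[:, 0] > 0.5).astype(int)

# With extract_fn set, each CV fold reaches the estimator as an
# xgb.DMatrix rather than a numpy array.
est = XGBoostWrapper(num_boost_round=10, score_function=accuracy_score,
                     max_depth=3, objective="binary:logistic")
search = GridSearchCV(est, {"num_boost_round": [5, 10, 20]}, cv=3,
                      refit=False,  # refitting would pass raw numpy to fit()
                      extract_fn=extract_dmatrix)
search.fit(X, y)
print(search.best_params_)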
18 changes: 12 additions & 6 deletions dask_ml/model_selection/methods.py
@@ -82,11 +82,13 @@ def warn_fit_failure(error_score, e):


 class CVCache:
-    def __init__(self, splits, pairwise=False, cache=True, num_train_samples=None):
+    def __init__(self, splits, pairwise=False, cache=True,
+                 num_train_samples=None, extract_fn=None):
         self.splits = splits
         self.pairwise = pairwise
         self.cache = {} if cache else None
         self.num_train_samples = num_train_samples
+        self.extract_fn = extract_fn

     def __reduce__(self):
         return (
@@ -127,8 +129,11 @@ def _extract(self, X, y, n, is_x=True, is_train=True):
         if self.cache is not None and (n, is_x, is_train) in self.cache:
             return self.cache[n, is_x, is_train]

-        inds = self.splits[n][0] if is_train else self.splits[n][1]
-        result = _safe_indexing(X if is_x else y, inds)
+        if self.extract_fn is not None:
+            result = self.extract_fn(self, X, y, n, is_x, is_train)
+        else:
+            inds = self.splits[n][0] if is_train else self.splits[n][1]
+            result = _safe_indexing(X if is_x else y, inds)

         if self.cache is not None:
             self.cache[n, is_x, is_train] = result
@@ -153,9 +158,9 @@ def _extract_pairwise(self, X, y, n, is_train=True):
         return result


-def cv_split(cv, X, y, groups, is_pairwise, cache):
+def cv_split(cv, X, y, groups, is_pairwise, cache, extract_fn):
     check_consistent_length(X, y, groups)
-    return CVCache(list(cv.split(X, y, groups)), is_pairwise, cache, _num_samples(X))
+    return CVCache(list(cv.split(X, y, groups)), is_pairwise, cache, _num_samples(X), extract_fn=extract_fn)


 def cv_n_samples(cvs):
@@ -318,7 +323,8 @@ def fit_and_score(
     if not return_train_score:
         X_train = y_train = None

-    return score(est_and_time, X_test, y_test, X_train, y_train, scorer, error_score)
+    s = score(est_and_time, X_test, y_test, X_train, y_train, scorer, error_score)
+    return s


 def _store(
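The hook contract in methods.py is simple: when set, extract_fn is called as extract_fn(cache, X, y, n, is_x, is_train), where cache is the CVCache itself (giving access to cache.splits) and n is the fold index, and the return value is memoized under the same (n, is_x, is_train) key as the default path. As an illustration only (as_dataframe_extract is hypothetical and not part of this PR), a custom extractor that hands folds to the estimator as pandas objects could look like:

import pandas as pd

from dask_ml.model_selection.utils import _safe_indexing


def as_dataframe_extract(cache, X, y, n, is_x=True, is_train=True):
    # Same indexing as the default path in CVCache._extract ...
    inds = cache.splits[n][0] if is_train else cache.splits[n][1]
    part = _safe_indexing(X if is_x else y, inds)
    # ... but materialize the fold as a pandas object instead of numpy.
    return pd.DataFrame(part) if is_x else pd.Series(part)

Because results are cached per (n, is_x, is_train), the conversion runs once per fold no matter how many parameter candidates reuse that fold.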
96 changes: 96 additions & 0 deletions dask_ml/model_selection/wrappers.py
@@ -0,0 +1,96 @@
"""Dask-ML model_selection-compatible wrappers - EARLY POC"""

import copy
import numpy as np
from .utils import _safe_indexing

try:
import xgboost as xgb
has_xgboost = True
except ImportError:
xgb = None
has_xgboost = False

class XGBoostWrapper:
"""Lightweight, sklearn-like wrapper for XGBoost training
that takes DMatrix as input.
This is a VERY basic Poc."""
def __init__(self, num_boost_round, score_function, **xgb_params):
if not has_xgboost:
raise ImportError("XGBoost is not installed")
self.xgb_params = xgb_params
self.num_boost_round = num_boost_round
self.booster_ = None
self.score_function = score_function

def fit(self, X_dmat, y=None):
self.booster_ = xgb.train(self.xgb_params,
X_dmat,
num_boost_round=self.num_boost_round)
return self

def predict(self, data, output_margin=False):
if isinstance(data, xgb.DMatrix):
test_dmatrix = data
else:
# XXX: base_margin, missing unsupported
test_dmatrix = xgb.DMatrix(data)
class_probs = self.booster_.predict(test_dmatrix,
output_margin=output_margin)
if output_margin:
# If output_margin is active, simply return the scores
return class_probs

if len(class_probs.shape) > 1:
column_indexes = np.argmax(class_probs, axis=1)
else:
column_indexes = np.repeat(0, class_probs.shape[0])
column_indexes[class_probs > 0.5] = 1

# Note: no label encoding, unlike sklearn version
return column_indexes

def score(self, X, y=None):
y_pred = self.predict(X)
y_label = X.get_label()
if y_label is None:
y_label = y # XXX not sure if this is right
return self.score_function(y_label, y_pred)

def get_params(self, deep=False):
if deep:
params = copy.deepcopy(self.xgb_params)
else:
params = copy.copy(self.xgb_params)
params["num_boost_round"] = self.num_boost_round
return params

def set_params(self, **params):
params_in = copy.copy(params)
if "num_boost_round" in params_in:
self.num_boost_round = params_in["num_boost_round"]
del params_in["num_boost_round"]
if "xgb_params" in params_in:
self.xgb_params.update(copy.copy(params_in["xgb_params"]))
del params_in["xgb_params"]
self.xgb_params.update(params_in)
return self


def extract_dmatrix(cv, X, y, n, is_x=True, is_train=True):
"""Custom dask-ml extract function, returning DMatrix instead of numpy"""
if not has_xgboost:
raise ImportError("XGBoost is not installed")

if not is_x:
return None

# XXX maybe the interface should just pass in splits instead of cv?
inds = cv.splits[n][0] if is_train else cv.splits[n][1]
x_part = _safe_indexing(X, inds)
y_part = _safe_indexing(y, inds)

# TODO: in practice, there may be additional params like weights
result = xgb.DMatrix(x_part, y_part)

return result
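The two pieces can also be exercised without the graph machinery. A standalone sketch, assuming xgboost and scikit-learn are installed and that CVCache.extract dispatches to _extract as in dask-ml's methods.py:

import numpy as np
from sklearn.model_selection import KFold

from dask_ml.model_selection.methods import CVCache
from dask_ml.model_selection.wrappers import extract_dmatrix

X = np.random.RandomState(0).rand(20, 3)
y = (X[:, 0] > 0.5).astype(float)

splits = list(KFold(n_splits=4).split(X, y))
cache = CVCache(splits, pairwise=False, cache=True,
                num_train_samples=len(X), extract_fn=extract_dmatrix)

# Training half of fold 0, materialized as a DMatrix with the labels
# bundled in; the matching y lookup returns None, since extract_dmatrix
# attaches the labels to the DMatrix itself.
dtrain = cache.extract(X, y, 0, is_x=True, is_train=True)
print(type(dtrain), dtrain.num_row())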