Commit

support pandas categorical (microsoft#193)
* support pandas categorical

* refine logic

* make default=auto

* fix train/valid categorical codes

* add test

* unify set _predictor

* fix tests

* fix warning

* support feature_name=int
wxchan authored and guolinke committed Jan 12, 2017
1 parent 00e5b24 commit 6c248d3
Showing 7 changed files with 189 additions and 116 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
@@ -14,7 +14,7 @@ before_install:

install:
- sudo apt-get install -y libopenmpi-dev openmpi-bin build-essential
- conda install --yes atlas numpy scipy scikit-learn
- conda install --yes atlas numpy scipy scikit-learn pandas
- pip install pep8


43 changes: 26 additions & 17 deletions docs/Python-API.md
@@ -5,8 +5,8 @@
- [Booster](Python-API.md#booster)

* [Training API](Python-API.md#training-api)
- [train](Python-API.md#trainparams-train_set-num_boost_round100-valid_setsnone-valid_namesnone-fobjnone-fevalnone-init_modelnone-feature_namenone-categorical_featurenone-early_stopping_roundsnone-evals_resultnone-verbose_evaltrue-learning_ratesnone-callbacksnone)
- [cv](Python-API.md#cvparams-train_set-num_boost_round10-nfold5-stratifiedfalse-shuffletrue-metricsnone-fobjnone-fevalnone-init_modelnone-feature_namenone-categorical_featurenone-early_stopping_roundsnone-fpreprocnone-verbose_evalnone-show_stdvtrue-seed0-callbacksnone)
- [train](Python-API.md#trainparams-train_set-num_boost_round100-valid_setsnone-valid_namesnone-fobjnone-fevalnone-init_modelnone-feature_nameauto-categorical_featureauto-early_stopping_roundsnone-evals_resultnone-verbose_evaltrue-learning_ratesnone-callbacksnone)
- [cv](Python-API.md#cvparams-train_set-num_boost_round10-nfold5-stratifiedfalse-shuffletrue-metricsnone-fobjnone-fevalnone-init_modelnone-feature_nameauto-categorical_featureauto-early_stopping_roundsnone-fpreprocnone-verbose_evalnone-show_stdvtrue-seed0-callbacksnone)

* [Scikit-learn API](Python-API.md#scikit-learn-api)
- [Common Methods](Python-API.md#common-methods)
@@ -31,7 +31,7 @@ The methods of each Class is in alphabetical order.

###Dataset

####__init__(data, label=None, max_bin=255, reference=None, weight=None, group=None, silent=False, feature_name=None, categorical_feature=None, params=None, free_raw_data=True)
####__init__(data, label=None, max_bin=255, reference=None, weight=None, group=None, silent=False, feature_name='auto', categorical_feature='auto', params=None, free_raw_data=True)

Parameters
----------
@@ -50,12 +50,14 @@ The methods of each Class is in alphabetical order.
Group/query size for dataset
silent : boolean, optional
Whether print messages during construction
feature_name : list of str
feature_name : list of str, or 'auto'
Feature names
categorical_feature : list of str or list of int
If 'auto' and data is pandas DataFrame, use data columns name
categorical_feature : list of str or int, or 'auto'
Categorical features,
type int represents index,
type str represents feature names (need to specify feature_name as well)
If 'auto' and data is pandas DataFrame, use pandas categorical columns
params : dict, optional
Other parameters
free_raw_data : Bool
@@ -445,7 +447,7 @@ The methods of each Class is in alphabetical order.

##Training API

####train(params, train_set, num_boost_round=100, valid_sets=None, valid_names=None, fobj=None, feval=None, init_model=None, feature_name=None, categorical_feature=None, early_stopping_rounds=None, evals_result=None, verbose_eval=True, learning_rates=None, callbacks=None)
####train(params, train_set, num_boost_round=100, valid_sets=None, valid_names=None, fobj=None, feval=None, init_model=None, feature_name='auto', categorical_feature='auto', early_stopping_rounds=None, evals_result=None, verbose_eval=True, learning_rates=None, callbacks=None)

Train with given parameters.

@@ -468,12 +470,14 @@ The methods of each Class is in alphabetical order.
Note: should return (eval_name, eval_result, is_higher_better) or a list of such tuples
init_model : file name of lightgbm model or 'Booster' instance
model used for continued train
feature_name : list of str
feature_name : list of str, or 'auto'
Feature names
categorical_feature : list of str or list of int
If 'auto' and data is pandas DataFrame, use data columns name
categorical_feature : list of str or int, or 'auto'
Categorical features,
type int represents index,
type str represents feature names (need to specify feature_name as well)
If 'auto' and data is pandas DataFrame, use pandas categorical columns
early_stopping_rounds: int
Activates early stopping.
Requires at least one validation data and one metric
@@ -513,7 +517,7 @@ The methods of each Class is in alphabetical order.
booster : a trained booster model


####cv(params, train_set, num_boost_round=10, nfold=5, stratified=False, shuffle=True, metrics=None, fobj=None, feval=None, init_model=None, feature_name=None, categorical_feature=None, early_stopping_rounds=None, fpreproc=None, verbose_eval=None, show_stdv=True, seed=0, callbacks=None)
####cv(params, train_set, num_boost_round=10, nfold=5, stratified=False, shuffle=True, metrics=None, fobj=None, feval=None, init_model=None, feature_name='auto', categorical_feature='auto', early_stopping_rounds=None, fpreproc=None, verbose_eval=None, show_stdv=True, seed=0, callbacks=None)

Cross-validation with given parameters.

@@ -541,11 +545,14 @@ The methods of each Class is in alphabetical order.
Custom evaluation function.
init_model : file name of lightgbm model or 'Booster' instance
model used for continued train
feature_name : list of str
feature_name : list of str, or 'auto'
Feature names
categorical_feature : list of str or int
Categorical features, type int represents index,
If 'auto' and data is pandas DataFrame, use data columns name
categorical_feature : list of str or int, or 'auto'
Categorical features,
type int represents index,
type str represents feature names (need to specify feature_name as well)
If 'auto' and data is pandas DataFrame, use pandas categorical columns
early_stopping_rounds: int
Activates early stopping. CV error needs to decrease at least
every <early_stopping_rounds> round(s) to continue.
@@ -693,7 +700,7 @@ The methods of each Class is in alphabetical order.
X_leaves : array_like, shape=[n_samples, n_trees]


####fit(X, y, sample_weight=None, init_score=None, group=None, eval_set=None, eval_sample_weight=None, eval_init_score=None, eval_group=None, eval_metric=None, early_stopping_rounds=None, verbose=True, feature_name=None, categorical_feature=None, callbacks=None)
####fit(X, y, sample_weight=None, init_score=None, group=None, eval_set=None, eval_sample_weight=None, eval_init_score=None, eval_group=None, eval_metric=None, early_stopping_rounds=None, verbose=True, feature_name='auto', categorical_feature='auto', callbacks=None)

Fit the gradient boosting model.

@@ -724,12 +731,14 @@ The methods of each Class is in alphabetical order.
early_stopping_rounds : int
verbose : bool
If `verbose` and an evaluation set is used, writes the evaluation
feature_name : list of str
feature_name : list of str, or 'auto'
Feature names
categorical_feature : list of str or int
If 'auto' and data is pandas DataFrame, use data columns name
categorical_feature : list of str or int, or 'auto'
Categorical features,
type int represents index,
type str represents feature names (need to specify feature_name as well).
type str represents feature names (need to specify feature_name as well)
If 'auto' and data is pandas DataFrame, use pandas categorical columns
callbacks : list of callback functions
List of callback functions that are applied at each iteration.
See Callbacks in Python-API.md for more information.
@@ -823,7 +832,7 @@ The methods of each Class is in alphabetical order.

###LGBMRanker

####fit(X, y, sample_weight=None, init_score=None, group=None, eval_set=None, eval_sample_weight=None, eval_init_score=None, eval_group=None, eval_metric='ndcg', eval_at=1, early_stopping_rounds=None, verbose=True, feature_name=None, categorical_feature=None, callbacks=None)
####fit(X, y, sample_weight=None, init_score=None, group=None, eval_set=None, eval_sample_weight=None, eval_init_score=None, eval_group=None, eval_metric='ndcg', eval_at=1, early_stopping_rounds=None, verbose=True, feature_name='auto', categorical_feature='auto', callbacks=None)

Most arguments are same as Common Methods except:

51 changes: 41 additions & 10 deletions python-package/lightgbm/basic.py
@@ -454,8 +454,29 @@ def __pred_for_csc(self, csc, num_iteration, predict_type):
'float32': 'float', 'float64': 'float', 'bool': 'int'}


def _data_from_pandas(data):
def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorical):
if isinstance(data, DataFrame):
cat_cols = data.select_dtypes(include=['category']).columns
if not pandas_categorical: # train dataset
pandas_categorical = (data[col].cat.categories for col in cat_cols)
else:
if len(cat_cols) != len(pandas_categorical):
raise ValueError('train and valid dataset categorical_feature do not match.')
for col, category in zip(cat_cols, pandas_categorical):
if data[col].cat.categories != category:
data[col] = data[col].cat.set_categories(category)
if len(cat_cols): # cat_cols is pandas Index object
data = data.copy() # not alter origin DataFrame
data[cat_cols] = data[cat_cols].apply(lambda x: x.cat.codes)
if categorical_feature is not None:
if feature_name is None:
feature_name = data.columns
if categorical_feature == 'auto':
categorical_feature = cat_cols
else:
categorical_feature += cat_cols
if feature_name == 'auto':
feature_name = data.columns
data_dtypes = data.dtypes
if not all(dtype.name in PANDAS_DTYPE_MAPPER for dtype in data_dtypes):
bad_fields = [data.columns[i] for i, dtype in
@@ -464,7 +485,12 @@ def _data_from_pandas(data):
msg = """DataFrame.dtypes for data must be int, float or bool. Did not expect the data types in fields """
raise ValueError(msg + ', '.join(bad_fields))
data = data.values.astype('float')
return data
else:
if feature_name == 'auto':
feature_name = None
if categorical_feature == 'auto':
categorical_feature = None
return data, feature_name, categorical_feature, pandas_categorical
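
Note on the hunk above: this view has lost the original indentation, so the exact nesting of the `'auto'` branches is a reading rather than a certainty. The following is a minimal pure-Python sketch of how the two defaults appear to resolve, under that assumption. `resolve_auto`, `columns`, and `cat_columns` are hypothetical names used only for illustration and are not part of LightGBM; passing `columns=None` stands in for "the input is not a pandas DataFrame".

```python
def resolve_auto(feature_name, categorical_feature, columns=None, cat_columns=None):
    """Illustrative stand-in for the 'auto' handling in _data_from_pandas.

    columns / cat_columns are lists when the input is a DataFrame,
    and None when it is anything else (numpy array, file path, ...).
    """
    if columns is None:
        # Not a DataFrame: 'auto' simply means "nothing was specified".
        if feature_name == 'auto':
            feature_name = None
        if categorical_feature == 'auto':
            categorical_feature = None
        return feature_name, categorical_feature

    cat_columns = list(cat_columns or [])
    if categorical_feature is not None:
        if feature_name is None:
            # Names are needed so categorical columns can be referenced by name.
            feature_name = list(columns)
        if categorical_feature == 'auto':
            categorical_feature = cat_columns
        else:
            # User-specified features plus the detected categorical columns.
            categorical_feature = list(categorical_feature) + cat_columns
    if feature_name == 'auto':
        feature_name = list(columns)
    return feature_name, categorical_feature


not_a_frame = resolve_auto('auto', 'auto')
frame_defaults = resolve_auto('auto', 'auto', columns=['a', 'b'], cat_columns=['b'])
frame_explicit = resolve_auto(None, ['a'], columns=['a', 'b'], cat_columns=['b'])
```

In this sketch the defaults are harmless for non-DataFrame input, which is presumably why the commit could switch every signature from `None` to `'auto'` without changing behaviour for existing callers.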


def _label_from_pandas(label):
@@ -482,7 +508,7 @@ class Dataset(object):
"""Dataset in LightGBM."""
def __init__(self, data, label=None, max_bin=255, reference=None,
weight=None, group=None, silent=False,
feature_name=None, categorical_feature=None, params=None,
feature_name='auto', categorical_feature='auto', params=None,
free_raw_data=True):
"""
Parameters
@@ -502,12 +528,14 @@ def __init__(self, data, label=None, max_bin=255, reference=None,
Group/query size for dataset
silent : boolean, optional
Whether print messages during construction
feature_name : list of str
feature_name : list of str, or 'auto'
Feature names
categorical_feature : list of str or int
If 'auto' and data is pandas DataFrame, use data columns name
categorical_feature : list of str or int, or 'auto'
Categorical features,
type int represents index,
type str represents feature names (need to specify feature_name as well)
If 'auto' and data is pandas DataFrame, use pandas categorical columns
params: dict, optional
Other parameters
free_raw_data: Bool
@@ -527,6 +555,7 @@ def __init__(self, data, label=None, max_bin=255, reference=None,
self.free_raw_data = free_raw_data
self.used_indices = None
self._predictor = None
self.pandas_categorical = None

def __del__(self):
self._free_handle()
@@ -538,12 +567,12 @@ def _free_handle(self):

def _lazy_init(self, data, label=None, max_bin=255, reference=None,
weight=None, group=None, predictor=None,
silent=False, feature_name=None,
categorical_feature=None, params=None):
silent=False, feature_name='auto',
categorical_feature='auto', params=None):
if data is None:
self.handle = None
return
data = _data_from_pandas(data)
data, feature_name, categorical_feature, self.pandas_categorical = _data_from_pandas(data, feature_name, categorical_feature, self.pandas_categorical)
label = _label_from_pandas(label)
self.data_has_header = False
"""process for args"""
@@ -760,7 +789,8 @@ def create_valid(self, data, label=None, weight=None, group=None,
ret = Dataset(data, label=label, max_bin=self.max_bin, reference=self,
weight=weight, group=group, silent=silent, params=params,
free_raw_data=self.free_raw_data)
ret._set_predictor(self._predictor)
ret._predictor = self._predictor
ret.pandas_categorical = self.pandas_categorical
return ret
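
Note on why `create_valid` hands `pandas_categorical` to the validation set: pandas assigns integer codes by position in each column's category list, so a validation frame encoded on its own can map the same value to a different code than the training frame did. The sketch below imitates that with plain lists instead of pandas. `categorical_codes` is a hypothetical helper, and the `-1` for unseen values follows the pandas convention for values outside the category list.

```python
def categorical_codes(values, categories):
    """Map each value to its position in `categories`; unseen values get -1."""
    index = {cat: i for i, cat in enumerate(categories)}
    return [index.get(v, -1) for v in values]


train = ['red', 'green', 'blue', 'green']
valid = ['red', 'green', 'purple']

# Categories as they would be inferred from the training column (sorted).
train_categories = sorted(set(train))
train_codes = categorical_codes(train, train_categories)

# Encoding the validation column on its own yields codes that disagree
# with training for 'green', and gives 'purple' a code that collides
# with the training code for 'green'.
valid_codes_own = categorical_codes(valid, sorted(set(valid)))

# Reusing the training categories keeps the codes consistent and marks
# the unseen value 'purple' as -1.
valid_codes_shared = categorical_codes(valid, train_categories)
```

This corresponds to the "fix train/valid categorical codes" item in the commit message: the validation data is re-expressed in the training categories before the codes are taken.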

def subset(self, used_indices, params=None):
@@ -777,6 +807,7 @@ def subset(self, used_indices, params=None):
ret = Dataset(None, reference=self, feature_name=self.feature_name,
categorical_feature=self.categorical_feature, params=params)
ret._predictor = self._predictor
ret.pandas_categorical = self.pandas_categorical
ret.used_indices = used_indices
return ret

@@ -948,7 +979,7 @@ def set_feature_name(self, feature_name):
if self.handle is not None and feature_name is not None:
if len(feature_name) != self.num_feature():
raise ValueError("Length of feature_name({}) and num_feature({}) don't match".format(len(feature_name), self.num_feature()))
c_feature_name = [c_str(name) for name in feature_name]
c_feature_name = [c_str(str(name)) for name in feature_name]
_safe_call(_LIB.LGBM_DatasetSetFeatureNames(
self.handle,
c_array(ctypes.c_char_p, c_feature_name),
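
Note on the `c_str(str(name))` change in `set_feature_name`: a DataFrame created without explicit column labels has integer column names, and with `feature_name='auto'` those integers would reach the string-encoding step. The sketch below shows why the `str()` wrapper matters. The body of `c_str` is not shown in this diff, so the version here is an assumed stand-in that encodes a `str` to a C `char*` via `ctypes`.

```python
import ctypes


def c_str(string):
    """Assumed stand-in for the c_str helper: encode a str as a C char pointer."""
    return ctypes.c_char_p(string.encode('utf-8'))


# Column labels as they might come from a DataFrame with default integer names.
feature_name = [0, 1, 'age']

# Without str(), an integer label has no .encode() method and the call fails.
try:
    c_str(feature_name[0])
    int_name_failed = False
except AttributeError:
    int_name_failed = True

# With str(), every label is converted before encoding.
c_feature_name = [c_str(str(name)) for name in feature_name]
decoded = [p.value.decode('utf-8') for p in c_feature_name]
```

This matches the "support feature_name=int" item in the commit message.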
19 changes: 12 additions & 7 deletions python-package/lightgbm/engine.py
@@ -17,7 +17,7 @@
def train(params, train_set, num_boost_round=100,
valid_sets=None, valid_names=None,
fobj=None, feval=None, init_model=None,
feature_name=None, categorical_feature=None,
feature_name='auto', categorical_feature='auto',
early_stopping_rounds=None, evals_result=None,
verbose_eval=True, learning_rates=None, callbacks=None):
"""
@@ -42,12 +42,14 @@ def train(params, train_set, num_boost_round=100,
Note: should return (eval_name, eval_result, is_higher_better) or a list of such tuples
init_model : file name of lightgbm model or 'Booster' instance
model used for continued train
feature_name : list of str
feature_name : list of str, or 'auto'
Feature names
categorical_feature : list of str or int
If 'auto' and data is pandas DataFrame, use data columns name
categorical_feature : list of str or int, or 'auto'
Categorical features,
type int represents index,
type str represents feature names (need to specify feature_name as well)
If 'auto' and data is pandas DataFrame, use pandas categorical columns
early_stopping_rounds: int
Activates early stopping.
Requires at least one validation data and one metric
@@ -267,7 +269,7 @@ def _agg_cv_result(raw_results):

def cv(params, train_set, num_boost_round=10, nfold=5, stratified=False,
shuffle=True, metrics=None, fobj=None, feval=None, init_model=None,
feature_name=None, categorical_feature=None,
feature_name='auto', categorical_feature='auto',
early_stopping_rounds=None, fpreproc=None,
verbose_eval=None, show_stdv=True, seed=0,
callbacks=None):
@@ -298,11 +300,14 @@ def cv(params, train_set, num_boost_round=10, nfold=5, stratified=False,
Custom evaluation function.
init_model : file name of lightgbm model or 'Booster' instance
model used for continued train
feature_name : list of str
feature_name : list of str, or 'auto'
Feature names
categorical_feature : list of str or int
Categorical features, type int represents index,
If 'auto' and data is pandas DataFrame, use data columns name
categorical_feature : list of str or int, or 'auto'
Categorical features,
type int represents index,
type str represents feature names (need to specify feature_name as well)
If 'auto' and data is pandas DataFrame, use pandas categorical columns
early_stopping_rounds: int
Activates early stopping. CV error needs to decrease at least
every <early_stopping_rounds> round(s) to continue.
14 changes: 8 additions & 6 deletions python-package/lightgbm/sklearn.py
@@ -280,7 +280,7 @@ def fit(self, X, y,
eval_init_score=None, eval_group=None,
eval_metric=None,
early_stopping_rounds=None, verbose=True,
feature_name=None, categorical_feature=None,
feature_name='auto', categorical_feature='auto',
callbacks=None):
"""
Fit the gradient boosting model
@@ -311,12 +311,14 @@ def fit(self, X, y,
early_stopping_rounds : int
verbose : bool
If `verbose` and an evaluation set is used, writes the evaluation
feature_name : list of str
feature_name : list of str, or 'auto'
Feature names
categorical_feature : list of str or int
If 'auto' and data is pandas DataFrame, use data columns name
categorical_feature : list of str or int, or 'auto'
Categorical features,
type int represents index,
type str represents feature names (need to specify feature_name as well)
If 'auto' and data is pandas DataFrame, use pandas categorical columns
callbacks : list of callback functions
List of callback functions that are applied at each iteration.
See Callbacks in Python-API.md for more information.
@@ -506,7 +508,7 @@ def fit(self, X, y,
eval_init_score=None,
eval_metric="l2",
early_stopping_rounds=None, verbose=True,
feature_name=None, categorical_feature=None, callbacks=None):
feature_name='auto', categorical_feature='auto', callbacks=None):

super(LGBMRegressor, self).fit(X, y, sample_weight=sample_weight,
init_score=init_score, eval_set=eval_set,
@@ -552,7 +554,7 @@ def fit(self, X, y,
eval_init_score=None,
eval_metric="binary_logloss",
early_stopping_rounds=None, verbose=True,
feature_name=None, categorical_feature=None,
feature_name='auto', categorical_feature='auto',
callbacks=None):
self._le = LGBMLabelEncoder().fit(y)
y = self._le.transform(y)
@@ -653,7 +655,7 @@ def fit(self, X, y,
eval_init_score=None, eval_group=None,
eval_metric='ndcg', eval_at=1,
early_stopping_rounds=None, verbose=True,
feature_name=None, categorical_feature=None,
feature_name='auto', categorical_feature='auto',
callbacks=None):
"""
Most arguments like common methods except following: