[python-package] add support for pandas nullable types (fixes #4173) (#4927)

* map nullable dtypes to regular float dtypes

* cast x3 to float after introducing missing values

* add test for regular dtypes

* use .astype and then values. update nullable_dtypes test and include test for regular numpy dtypes

* more specific allowed dtypes. test no copy when single float dtype df

* use np.find_common_type. set np.float128 to None when it isn't supported

* set default as type(None)

* move tests that use lgb.train to test_engine

* include np.float32 when finding common dtype

* Apply suggestions from code review

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* add linebreak

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
jmoralez and StrikerRUS authored Feb 24, 2022
1 parent 97c8d94 commit f185695
Showing 4 changed files with 111 additions and 16 deletions.
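
The user-visible effect: DataFrames with pandas extension ("nullable") dtypes such as Int32, Float64, and boolean can now be fed to LightGBM directly, with pd.NA entries surfaced to the booster as missing values. A minimal usage sketch (illustrative column names and data, not taken from the diff):

```python
import numpy as np
import pandas as pd
import lightgbm as lgb

rng = np.random.RandomState(0)
# Illustrative frame using pandas nullable dtypes; pd.NA marks missing entries.
X = pd.DataFrame({
    'a': pd.array(rng.randint(0, 10, 50), dtype='Int32'),
    'b': pd.array(rng.rand(50), dtype='Float64'),
})
X.loc[0, 'a'] = pd.NA
y = rng.rand(50)

# Before this commit, building the Dataset failed with
# "DataFrame.dtypes for data must be int, float or bool."
ds = lgb.Dataset(X, y).construct()
```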
27 changes: 14 additions & 13 deletions python-package/lightgbm/basic.py
```diff
@@ -17,8 +17,7 @@
 import numpy as np
 import scipy.sparse
 
-from .compat import (PANDAS_INSTALLED, concat, dt_DataTable, is_dtype_sparse, pd_CategoricalDtype, pd_DataFrame,
-                     pd_Series)
+from .compat import PANDAS_INSTALLED, concat, dt_DataTable, pd_CategoricalDtype, pd_DataFrame, pd_Series
 from .libpath import find_lib_path
 
 ZERO_THRESHOLD = 1e-35
@@ -502,14 +501,15 @@ def c_int_array(data):
 
 
 def _get_bad_pandas_dtypes(dtypes):
-    pandas_dtype_mapper = {'int8': 'int', 'int16': 'int', 'int32': 'int',
-                           'int64': 'int', 'uint8': 'int', 'uint16': 'int',
-                           'uint32': 'int', 'uint64': 'int', 'bool': 'int',
-                           'float16': 'float', 'float32': 'float', 'float64': 'float'}
-    bad_indices = [i for i, dtype in enumerate(dtypes) if (dtype.name not in pandas_dtype_mapper
-                                                           and (not is_dtype_sparse(dtype)
-                                                                or dtype.subtype.name not in pandas_dtype_mapper))]
-    return bad_indices
+    float128 = getattr(np, 'float128', type(None))
+
+    def is_allowed_numpy_dtype(dtype):
+        return (
+            issubclass(dtype, (np.integer, np.floating, np.bool_))
+            and not issubclass(dtype, (np.timedelta64, float128))
+        )
+
+    return [i for i, dtype in enumerate(dtypes) if not is_allowed_numpy_dtype(dtype.type)]
 
 
 def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorical):
@@ -546,9 +546,10 @@ def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorical):
             raise ValueError("DataFrame.dtypes for data must be int, float or bool.\n"
                              "Did not expect the data types in the following fields: "
                              f"{bad_index_cols_str}")
-        data = data.values
-        if data.dtype != np.float32 and data.dtype != np.float64:
-            data = data.astype(np.float32)
+        df_dtypes = [dtype.type for dtype in data.dtypes]
+        df_dtypes.append(np.float32)  # so that the target dtype considers floats
+        target_dtype = np.find_common_type(df_dtypes, [])
+        data = data.astype(target_dtype, copy=False).values
     else:
         if feature_name == 'auto':
             feature_name = None
```
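
The core change in basic.py: instead of eagerly grabbing .values and casting anything that is not float32/float64 down to float32, the column dtypes are promoted to a common dtype (with np.float32 appended so the target is always a float type and NaN can encode missing values), and copy=False skips the copy when no cast is needed. A quick check of the promotion rules this relies on (np.find_common_type was later deprecated; np.result_type applies the same rules):

```python
import numpy as np

# NumPy's standard type promotion, as used to pick target_dtype above.
assert np.result_type(np.int8, np.float32) == np.float32   # small ints fit into float32
assert np.result_type(np.int64, np.float32) == np.float64  # wide ints force float64
assert np.result_type(np.bool_, np.float32) == np.float32  # bools promote cheaply
```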
2 changes: 0 additions & 2 deletions python-package/lightgbm/compat.py
```diff
@@ -6,7 +6,6 @@
     from pandas import DataFrame as pd_DataFrame
     from pandas import Series as pd_Series
     from pandas import concat
-    from pandas.api.types import is_sparse as is_dtype_sparse
     try:
         from pandas import CategoricalDtype as pd_CategoricalDtype
     except ImportError:
@@ -34,7 +33,6 @@ def __init__(self, *args, **kwargs):
             pass
 
     concat = None
-    is_dtype_sparse = None
 
 """matplotlib"""
 try:
```
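
is_dtype_sparse can be dropped because the rewritten _get_bad_pandas_dtypes inspects dtype.type, and pandas' SparseDtype already reports the scalar type of its subtype there, so sparse columns pass the same subclass check without special casing. A short illustration (assuming pandas' SparseDtype.type behavior):

```python
import numpy as np
import pandas as pd

sparse_col = pd.arrays.SparseArray([0, 0, 1], fill_value=0)
print(sparse_col.dtype)  # Sparse[int64, 0]
# SparseDtype.type exposes the underlying scalar type (here numpy.int64),
# which satisfies issubclass(..., np.integer) in the new check.
assert issubclass(sparse_col.dtype.type, np.integer)
```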
14 changes: 13 additions & 1 deletion tests/python_package_test/test_basic.py
```diff
@@ -8,7 +8,7 @@
 import pytest
 from scipy import sparse
 from sklearn.datasets import dump_svmlight_file, load_svmlight_file, make_blobs
-from sklearn.metrics import log_loss
+from sklearn.metrics import log_loss, mean_squared_error
 from sklearn.model_selection import train_test_split
 
 import lightgbm as lgb
@@ -658,3 +658,15 @@ def custom_eval(y_pred, ds):
     _, metric, value, _ = bst.eval(ds, key, feval=custom_eval)[1]  # first element is multi_logloss
     assert metric == 'custom_logloss'
     np.testing.assert_allclose(value, eval_result[key][metric][-1])
+
+
+@pytest.mark.parametrize('dtype', [np.float32, np.float64])
+def test_no_copy_when_single_float_dtype_dataframe(dtype):
+    pd = pytest.importorskip('pandas')
+    X = np.random.rand(10, 2).astype(dtype)
+    df = pd.DataFrame(X)
+    # feature names are required to not make a copy (rename makes a copy)
+    feature_name = ['x1', 'x2']
+    built_data = lgb.basic._data_from_pandas(df, feature_name, None, None)[0]
+    assert built_data.dtype == dtype
+    assert np.shares_memory(X, built_data)
```
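
For contrast with the zero-copy case exercised above, a frame with mixed dtypes necessarily materializes a new array when cast to the common dtype; a hedged sketch of that behavior (not part of the diff):

```python
import numpy as np
import pandas as pd

X = np.random.rand(10, 2)
df = pd.DataFrame(X).astype({0: np.float32, 1: np.float64})  # mixed float widths
target = np.result_type(np.float32, np.float64)              # -> float64
out = df.astype(target, copy=False).values
assert out.dtype == np.float64
assert not np.shares_memory(X, out)  # casting column 0 forced a fresh allocation
```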
84 changes: 84 additions & 0 deletions tests/python_package_test/test_engine.py
```diff
@@ -3291,3 +3291,87 @@ def test_record_evaluation_with_cv(train_metric):
         np.testing.assert_allclose(
             cv_hist[key], eval_result[dataset][f'{metric}-{agg}']
         )
+
+
+def test_pandas_with_numpy_regular_dtypes():
+    pd = pytest.importorskip('pandas')
+    uints = ['uint8', 'uint16', 'uint32', 'uint64']
+    ints = ['int8', 'int16', 'int32', 'int64']
+    bool_and_floats = ['bool', 'float16', 'float32', 'float64']
+    rng = np.random.RandomState(42)
+
+    n_samples = 100
+    # data as float64
+    df = pd.DataFrame({
+        'x1': rng.randint(0, 2, n_samples),
+        'x2': rng.randint(1, 3, n_samples),
+        'x3': 10 * rng.randint(1, 3, n_samples),
+        'x4': 100 * rng.randint(1, 3, n_samples),
+    })
+    df = df.astype(np.float64)
+    y = df['x1'] * (df['x2'] + df['x3'] + df['x4'])
+    ds = lgb.Dataset(df, y)
+    params = {'objective': 'l2', 'num_leaves': 31, 'min_child_samples': 1}
+    bst = lgb.train(params, ds, num_boost_round=5)
+    preds = bst.predict(df)
+
+    # test all features were used
+    assert bst.trees_to_dataframe()['split_feature'].nunique() == df.shape[1]
+    # test the score is better than predicting the mean
+    baseline = np.full_like(y, y.mean())
+    assert mean_squared_error(y, preds) < mean_squared_error(y, baseline)
+
+    # test all predictions are equal using different input dtypes
+    for target_dtypes in [uints, ints, bool_and_floats]:
+        df2 = df.astype({f'x{i}': dtype for i, dtype in enumerate(target_dtypes, start=1)})
+        assert df2.dtypes.tolist() == target_dtypes
+        ds2 = lgb.Dataset(df2, y)
+        bst2 = lgb.train(params, ds2, num_boost_round=5)
+        preds2 = bst2.predict(df2)
+        np.testing.assert_allclose(preds, preds2)
+
+
+def test_pandas_nullable_dtypes():
+    pd = pytest.importorskip('pandas')
+    rng = np.random.RandomState(0)
+    df = pd.DataFrame({
+        'x1': rng.randint(1, 3, size=100),
+        'x2': np.linspace(-1, 1, 100),
+        'x3': pd.arrays.SparseArray(rng.randint(0, 11, size=100)),
+        'x4': rng.rand(100) < 0.5,
+    })
+    # introduce some missing values
+    df.loc[1, 'x1'] = np.nan
+    df.loc[2, 'x2'] = np.nan
+    df.loc[3, 'x4'] = np.nan
+    # the previous line turns x4 into object dtype in recent versions of pandas
+    df['x4'] = df['x4'].astype(np.float64)
+    y = df['x1'] * df['x2'] + df['x3'] * (1 + df['x4'])
+    y = y.fillna(0)
+
+    # train with regular dtypes
+    params = {'objective': 'l2', 'num_leaves': 31, 'min_child_samples': 1}
+    ds = lgb.Dataset(df, y)
+    bst = lgb.train(params, ds, num_boost_round=5)
+    preds = bst.predict(df)
+
+    # convert to nullable dtypes
+    df2 = df.copy()
+    df2['x1'] = df2['x1'].astype('Int32')
+    df2['x2'] = df2['x2'].astype('Float64')
+    df2['x4'] = df2['x4'].astype('boolean')
+
+    # test training succeeds
+    ds_nullable_dtypes = lgb.Dataset(df2, y)
+    bst_nullable_dtypes = lgb.train(params, ds_nullable_dtypes, num_boost_round=5)
+    preds_nullable_dtypes = bst_nullable_dtypes.predict(df2)
+
+    trees_df = bst_nullable_dtypes.trees_to_dataframe()
+    # test all features were used
+    assert trees_df['split_feature'].nunique() == df.shape[1]
+    # test the score is better than predicting the mean
+    baseline = np.full_like(y, y.mean())
+    assert mean_squared_error(y, preds) < mean_squared_error(y, baseline)
+
+    # test equal predictions
+    np.testing.assert_allclose(preds, preds_nullable_dtypes)
```
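
The prediction equality at the end holds because a nullable column, once cast to the common float dtype, turns pd.NA into np.nan, which LightGBM already treats as a missing value. A small illustration with made-up values:

```python
import numpy as np
import pandas as pd

s = pd.Series([1, pd.NA, 3], dtype='Int32')  # nullable integer column
as_float = s.astype(np.float64).to_numpy()   # pd.NA becomes NaN
print(as_float)  # [ 1. nan  3.]
assert np.isnan(as_float[1])
```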
