[BUG] .predict_proba on fitted Pipeline object with a ColumnTransformer step raises exception #4368

Closed
@EFulmer

Description

Describe the bug
When using the ColumnTransformer from cuml.experimental.preprocessing as a step in an already-fitted Pipeline, the predict and predict_proba methods raise exceptions stating that X has a mismatched number of features, even though the data has the same shape as the DataFrame passed to fit.

Steps/Code to reproduce bug

Here's a minimal example that preserves the types and shape of my real data and the structure of the pipeline (same encoders, imputers, and classifier):

import cudf
from cuml.ensemble import RandomForestClassifier
from cuml.experimental.preprocessing import ColumnTransformer
from cuml.preprocessing import OneHotEncoder, SimpleImputer, StandardScaler
from cuml.pipeline import Pipeline
import cupy


id_vars = ["id"]
id_transformer = Pipeline(
    steps=[
        ("imputer", SimpleImputer(strategy="mean", missing_values=cupy.NaN))
    ]
)

categorical_vars = ["cat"]
categorical_transformer = Pipeline(
    [
        ("ordinal", OneHotEncoder(sparse=False)),
        ("imputer", SimpleImputer(strategy="most_frequent")),
    ]
)

numeric_vars = ["num"]
numeric_transformer = Pipeline(
    steps=[
        ("imputer", SimpleImputer(strategy="mean", missing_values=cupy.NaN)),
        ("scaler", StandardScaler()),
    ]
)

preprocessor = ColumnTransformer(
    transformers=[
        ("id", id_transformer, id_vars),
        ("categorical", categorical_transformer, categorical_vars),
        ("numeric", numeric_transformer, numeric_vars),
    ],
    remainder="drop",
)
model = Pipeline(
    steps=[
        ("preprocessor", preprocessor),
        ("classifier", RandomForestClassifier(n_bins=2)),  # n_bins set to 2 because of smaller "toy" input set
    ]
)

df_train = cudf.DataFrame(
    [{"id": 1, "cat": "a", "num": 1.0, "target": 0, "extra": 5},
     {"id": 2, "cat": "a", "num": 2.0, "target": 1, "extra": -1},
     {"id": 3, "cat": "b", "num": 3.0, "target": 1, "extra": 100}]
)

X_train = df_train.drop("target", axis=1)
y_train = df_train["target"]

df_test = cudf.DataFrame(
    [{"id": 4, "cat": "b", "num": 2.0, "target": 1, "extra": 17}]
)

X_test = df_test.drop("target", axis=1)
y_test = df_test["target"]

model.fit(X_train, y_train)
print(model.predict_proba(X_test))

The full stack trace is immediately below. Using the scikit-learn equivalents of these estimators, predict_proba executes without exception; a sketch of that equivalent pipeline follows the trace.

/opt/conda/envs/rapids/lib/python3.8/site-packages/cuml/internals/api_decorators.py:567: UserWarning: To use pickling or GPU-based prediction first train using float32 data to fit the estimator
  ret_val = func(*args, **kwargs)
Traceback (most recent call last):
  File "min_failing_example.py", line 57, in <module>
    model.predict_proba(X)
  File "/opt/conda/envs/rapids/lib/python3.8/site-packages/sklearn/utils/metaestimators.py", line 113, in <lambda>
    out = lambda *args, **kwargs: self.fn(obj, *args, **kwargs)  # noqa
  File "/opt/conda/envs/rapids/lib/python3.8/site-packages/sklearn/pipeline.py", line 535, in predict_proba
    Xt = transform.transform(Xt)
  File "/opt/conda/envs/rapids/lib/python3.8/site-packages/cuml/internals/api_decorators.py", line 586, in inner_get
    ret_val = func(*args, **kwargs)
  File "/opt/conda/envs/rapids/lib/python3.8/site-packages/cuml/_thirdparty/sklearn/preprocessing/_column_transformer.py", line 934, in transform
    Xs = self._fit_transform(X, None, _transform_one, fitted=True)
  File "/opt/conda/envs/rapids/lib/python3.8/site-packages/cuml/_thirdparty/sklearn/preprocessing/_column_transformer.py", line 806, in _fit_transform
    return Parallel(n_jobs=self.n_jobs)(
  File "/opt/conda/envs/rapids/lib/python3.8/site-packages/joblib/parallel.py", line 1044, in __call__
    while self.dispatch_one_batch(iterator):
  File "/opt/conda/envs/rapids/lib/python3.8/site-packages/joblib/parallel.py", line 859, in dispatch_one_batch
    self._dispatch(tasks)
  File "/opt/conda/envs/rapids/lib/python3.8/site-packages/joblib/parallel.py", line 777, in _dispatch
    job = self._backend.apply_async(batch, callback=cb)
  File "/opt/conda/envs/rapids/lib/python3.8/site-packages/joblib/_parallel_backends.py", line 208, in apply_async
    result = ImmediateResult(func)
  File "/opt/conda/envs/rapids/lib/python3.8/site-packages/joblib/_parallel_backends.py", line 572, in __init__
    self.results = batch()
  File "/opt/conda/envs/rapids/lib/python3.8/site-packages/joblib/parallel.py", line 262, in __call__
    return [func(*args, **kwargs)
  File "/opt/conda/envs/rapids/lib/python3.8/site-packages/joblib/parallel.py", line 262, in <listcomp>
    return [func(*args, **kwargs)
  File "/opt/conda/envs/rapids/lib/python3.8/site-packages/cuml/_thirdparty/sklearn/preprocessing/_column_transformer.py", line 361, in __call__
    return self.function(*args, **kwargs)
  File "/opt/conda/envs/rapids/lib/python3.8/site-packages/cuml/_thirdparty/sklearn/preprocessing/_column_transformer.py", line 287, in _transform_one
    res = transformer.transform(X).to_output('cupy')
  File "/opt/conda/envs/rapids/lib/python3.8/site-packages/sklearn/utils/metaestimators.py", line 113, in <lambda>
    out = lambda *args, **kwargs: self.fn(obj, *args, **kwargs)  # noqa
  File "/opt/conda/envs/rapids/lib/python3.8/site-packages/sklearn/pipeline.py", line 647, in transform
    Xt = transform.transform(Xt)
  File "/opt/conda/envs/rapids/lib/python3.8/site-packages/cuml/internals/api_decorators.py", line 586, in inner_get
    ret_val = func(*args, **kwargs)
  File "/opt/conda/envs/rapids/lib/python3.8/site-packages/cuml/_thirdparty/sklearn/preprocessing/_imputation.py", line 415, in transform
    X = self._validate_input(X, in_fit=False)
  File "/opt/conda/envs/rapids/lib/python3.8/site-packages/cuml/_thirdparty/sklearn/preprocessing/_imputation.py", line 280, in _validate_input
    raise ve
  File "/opt/conda/envs/rapids/lib/python3.8/site-packages/cuml/_thirdparty/sklearn/preprocessing/_imputation.py", line 270, in _validate_input
    X = self._validate_data(X, reset=in_fit,
  File "/opt/conda/envs/rapids/lib/python3.8/site-packages/cuml/_thirdparty/sklearn/utils/skl_dependencies.py", line 127, in _validate_data
    self._check_n_features(X, reset=reset)
  File "/opt/conda/envs/rapids/lib/python3.8/site-packages/cuml/_thirdparty/sklearn/utils/skl_dependencies.py", line 68, in _check_n_features
    raise ValueError(
ValueError: X has 2 features, but this BaseMetaClass is expecting 1 features as input.
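
For reference, here is a minimal sketch of one way to build the scikit-learn equivalent of the pipeline above (pandas and NumPy on the host instead of cuDF and CuPy; the cuML-specific n_bins parameter is dropped). With these estimators, fit followed by predict_proba completes without raising. Note that on scikit-learn >= 1.2 the OneHotEncoder argument is sparse_output rather than sparse.

import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestClassifier
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler

# Same three column groups and sub-pipeline structure as the cuML version.
id_transformer = Pipeline(
    steps=[("imputer", SimpleImputer(strategy="mean", missing_values=np.nan))]
)
categorical_transformer = Pipeline(
    steps=[
        ("ordinal", OneHotEncoder(sparse=False)),  # use sparse_output=False on sklearn >= 1.2
        ("imputer", SimpleImputer(strategy="most_frequent")),
    ]
)
numeric_transformer = Pipeline(
    steps=[
        ("imputer", SimpleImputer(strategy="mean", missing_values=np.nan)),
        ("scaler", StandardScaler()),
    ]
)

preprocessor = ColumnTransformer(
    transformers=[
        ("id", id_transformer, ["id"]),
        ("categorical", categorical_transformer, ["cat"]),
        ("numeric", numeric_transformer, ["num"]),
    ],
    remainder="drop",
)
model = Pipeline(
    steps=[
        ("preprocessor", preprocessor),
        ("classifier", RandomForestClassifier()),  # no n_bins; that parameter is cuML-only
    ]
)

df_train = pd.DataFrame(
    [{"id": 1, "cat": "a", "num": 1.0, "target": 0, "extra": 5},
     {"id": 2, "cat": "a", "num": 2.0, "target": 1, "extra": -1},
     {"id": 3, "cat": "b", "num": 3.0, "target": 1, "extra": 100}]
)
df_test = pd.DataFrame(
    [{"id": 4, "cat": "b", "num": 2.0, "target": 1, "extra": 17}]
)

model.fit(df_train.drop("target", axis=1), df_train["target"])
print(model.predict_proba(df_test.drop("target", axis=1)))  # runs without raising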

Expected behavior
Calling predict_proba on a fitted pipeline should return an array of predictions.

If the input to predict or predict_proba is actually invalid in some larger way, a more descriptive error message would be very much appreciated as well. A short diagnostic sketch narrowing the failure to the preprocessing step follows.
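
From the trace, the failure happens while transforming X_test through the fitted preprocessor (the SimpleImputer inside one of the sub-pipelines receives 2 features but was fitted on 1), so the classifier is never reached. Below is a hedged diagnostic sketch, assuming the model and X_test objects from the reproduction above; based on the trace it should hit the same code path and raise the same ValueError without involving the RandomForestClassifier.

# Hypothetical narrowing step: call the fitted ColumnTransformer directly,
# bypassing the classifier. Per the trace, this exercises the same path
# (ColumnTransformer.transform -> sub-pipeline -> SimpleImputer._validate_input).
fitted_preprocessor = model.named_steps["preprocessor"]

try:
    fitted_preprocessor.transform(X_test)
except ValueError as err:
    # Expected, given the trace: "X has 2 features, but ... expecting 1 features as input."
    print(f"ColumnTransformer.transform alone fails: {err}")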

Environment details:

  • Environment location: Google Cloud Platform
  • Linux Distro/Architecture: Ubuntu 18.04.6 LTS
  • GPU Model/Driver: Tesla T4 and driver 455.32.00
  • CUDA: 11.1
  • Method of cuDF & cuML install: Docker; image rapidsai/rapidsai-core:21.08-cuda11.0-base-ubuntu18.04-py3.8
