MAINT: better error message about one-hot encoded targets w/ loss="auto" (#218)
stsievert authored Mar 31, 2021
1 parent ca868f5 commit 8d50bf9
Showing 5 changed files with 124 additions and 52 deletions.
47 changes: 44 additions & 3 deletions docs/source/advanced.rst
@@ -1,6 +1,6 @@
===================================
Advanced Usage of SciKeras Wrappers
===================================
==============
Advanced Usage
==============

Wrapper Classes
---------------
@@ -128,6 +128,43 @@ offer an easy way to compile and tune compilation parameters. Examples:
In all cases, returning an un-compiled model is equivalent to
calling ``model.compile(**compile_kwargs)`` within ``model_build_fn``.

.. _loss-selection:

Loss selection
++++++++++++++

If you do not explicitly define a loss, SciKeras attempts to find a loss
that matches the type of target (see :py:func:`sklearn.utils.multiclass.type_of_target`).

For guidance selecting losses in Keras, please see Jason Brownlee's
excellent article `How to Choose Loss Functions When Training Deep Learning Neural Networks`_
as well as `Keras Losses docs`_.

Default losses are selected as follows:

Classification
..............

+-----------+-----------+----------+---------------------------------+
| # outputs | # classes | encoding | loss                            |
+===========+===========+==========+=================================+
| 1         | <= 2      | any      | binary crossentropy             |
+-----------+-----------+----------+---------------------------------+
| 1         | >=2       | labels   | sparse categorical crossentropy |
+-----------+-----------+----------+---------------------------------+
| 1         | >=2       | one-hot  | unsupported                     |
+-----------+-----------+----------+---------------------------------+
| > 1       | --        | --       | unsupported                     |
+-----------+-----------+----------+---------------------------------+

Note that SciKeras will not automatically infer the loss for one-hot encoded targets;
you need to explicitly specify ``loss="categorical_crossentropy"``.
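
The selection rules in the table above can be sketched as a small standalone function. This is only an illustration that follows the branch structure of ``_compile_model`` in this commit: the helper name ``default_classification_loss`` and its two arguments are hypothetical and are not part of the SciKeras API.

```python
def default_classification_loss(target_type: str, n_output_units: int) -> str:
    """Hypothetical helper mirroring the default-loss table (not SciKeras API).

    target_type is assumed to be a string such as those returned by
    sklearn.utils.multiclass.type_of_target, e.g. "binary" or "multiclass".
    """
    if target_type == "binary":
        if n_output_units != 1:
            raise ValueError(
                "Binary classification expects a model with exactly 1 output unit."
            )
        return "binary_crossentropy"
    if target_type == "multiclass":
        if n_output_units == 1:
            raise ValueError(
                "Multi-class targets require the model to have >1 output units."
            )
        return "sparse_categorical_crossentropy"
    # One-hot encoded targets are reported as "multilabel-indicator" and,
    # like every other remaining target type, have no default loss.
    raise NotImplementedError(
        f'`loss="auto"` is not supported for tasks of type {target_type}.'
    )


print(default_classification_loss("binary", 1))
print(default_classification_loss("multiclass", 3))
```

Any target type other than ``"binary"`` or ``"multiclass"`` falls through to the ``NotImplementedError``, which is where an explicit ``loss=...`` argument becomes necessary.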

Regression
..........

Regression always defaults to mean squared error.
For multi-output models, Keras will use the sum of each output's loss.
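
As a rough illustration of "the sum of each output's loss", the following pure-Python sketch computes mean squared error per output and adds the results. It assumes equal weighting of outputs and is not how Keras implements this internally; the function names are hypothetical.

```python
def mean_squared_error(y_true, y_pred):
    """Mean of squared differences for a single output."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def summed_multi_output_loss(y_true_per_output, y_pred_per_output):
    """Add up the per-output MSE values (equal weights assumed)."""
    return sum(
        mean_squared_error(t, p)
        for t, p in zip(y_true_per_output, y_pred_per_output)
    )


# Two outputs, two examples each.
y_true = [[1.0, 2.0], [0.0, 0.0]]
y_pred = [[1.0, 4.0], [1.0, 1.0]]
total = summed_multi_output_loss(y_true, y_pred)
print(total)  # 3.0 (MSE of 2.0 for the first output plus 1.0 for the second)
```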

Arguments to ``model_build_fn``
-------------------------------
@@ -287,3 +324,7 @@ and :class:`scikeras.wrappers.KerasRegressor` respectively. To override these sc
.. _Keras Callbacks docs: https://www.tensorflow.org/api_docs/python/tf/keras/callbacks

.. _Keras Metrics docs: https://www.tensorflow.org/api_docs/python/tf/keras/metrics

.. _Keras Losses docs: https://www.tensorflow.org/api_docs/python/tf/keras/losses

.. _How to Choose Loss Functions When Training Deep Learning Neural Networks: https://machinelearningmastery.com/how-to-choose-loss-functions-when-training-deep-learning-neural-networks/
19 changes: 14 additions & 5 deletions docs/source/quickstart.rst
@@ -38,16 +38,25 @@ it on a toy classification dataset using SciKeras
model.add(keras.layers.Activation("softmax"))
return model
clf = KerasClassifier(
get_model,
loss="sparse_categorical_crossentropy",
hidden_layer_dim=100,
)
clf = KerasClassifier(get_model, hidden_layer_dim=100)
clf.fit(X, y)
y_proba = clf.predict_proba(X)
Note that SciKeras even chooses a loss function and compiles your model.
To override the default loss, simply specify a loss function:

.. code-block:: diff
-KerasClassifier(get_model, hidden_layer_dim=100)
+KerasClassifier(get_model, loss="categorical_crossentropy")
In this case, you would need to specify the loss since SciKeras
will not default to categorical crossentropy, even for one-hot
encoded targets.
See :ref:`loss-selection` for more details.
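
To see why the loss has to match the target encoding, the sketch below writes both crossentropy variants for a single example in plain Python. These are simplified textbook forms, not the Keras implementations, and the function names are chosen here only for illustration: the sparse variant indexes the predicted probabilities with an integer label, while the categorical variant expects a one-hot vector.

```python
import math


def sparse_categorical_crossentropy(label, probs):
    """Loss for one example whose target is an integer class label."""
    return -math.log(probs[label])


def categorical_crossentropy(one_hot, probs):
    """Loss for one example whose target is a one-hot vector."""
    return -sum(t * math.log(p) for t, p in zip(one_hot, probs))


probs = [0.1, 0.7, 0.2]  # predicted class probabilities for one example
label = 1                # integer label encoding
one_hot = [0, 1, 0]      # the same target, one-hot encoded

sparse_loss = sparse_categorical_crossentropy(label, probs)
onehot_loss = categorical_crossentropy(one_hot, probs)
print(math.isclose(sparse_loss, onehot_loss))
```

Both functions compute the same quantity, but each accepts only its own target format, so a one-hot target needs the categorical variant to be requested explicitly.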

In an sklearn Pipeline
----------------------

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -30,7 +30,7 @@ version = "0.2.1"

[tool.poetry.dependencies]
importlib-metadata = {version = "^3.4.0", python = "<3.8"}
python = ">=3.6.7, <3.9"
python = "^3.11.0"
scikit-learn = "^0.22.0"
tensorflow = "^2.4.0"

24 changes: 18 additions & 6 deletions scikeras/wrappers.py
@@ -22,7 +22,6 @@
from tensorflow.keras import optimizers as optimizers_module
from tensorflow.keras.models import Model
from tensorflow.keras.utils import register_keras_serializable
from tensorflow.python.types.core import Value

from scikeras._utils import (
TFRandomState,
@@ -1328,24 +1327,37 @@ def _compile_model(self, compile_kwargs: Dict[str, Any]) -> None:
raise ValueError(
'Only single-output models are supported with `loss="auto"`'
)
loss = None
hint = ""
if self.target_type_ == "binary":
if self.model_.outputs[0].shape[1] != 1:
raise ValueError(
"Binary classification expects a model with exactly 1 output unit."
)
compile_kwargs["loss"] = "binary_crossentropy"
loss = "binary_crossentropy"
elif self.target_type_ == "multiclass":
if self.model_.outputs[0].shape[1] == 1:
raise ValueError(
"Multi-class targets require the model to have >1 output units."
)
compile_kwargs["loss"] = "sparse_categorical_crossentropy"
else:
raise NotImplementedError(
loss = "sparse_categorical_crossentropy"
elif self.target_type_ == "multilabel-indicator":
# one-hot encoded multiclass problem OR multilabel-indicator problem
hint = (
"For this type of problem, the following may help:"
'\n - If there is only one class per example, loss="categorical_crossentropy" might be appropriate.'
'\n - If there are multiple classes per example, loss="binary_crossentropy" might be appropriate.'
)
if loss is None:
msg = (
f'`loss="auto"` is not supported for tasks of type {self.target_type_}.'
" Instead, you must explicitly pass a loss function, for example:"
"\nInstead, you must compile the model yourself or explicitly pass a loss function, for example:"
'\n clf = KerasClassifier(..., loss="categorical_crossentropy")'
)
if hint:
msg += f"\n\n{hint}"
raise NotImplementedError(msg)
compile_kwargs["loss"] = loss
self.model_.compile(**compile_kwargs)

@staticmethod
84 changes: 47 additions & 37 deletions tests/test_loss_auto.py
@@ -14,13 +14,13 @@
X = np.random.uniform(size=(n_eg, FEATURES)).astype("float32")


def shallow_net(single_output=False, loss=None, compile=False):
def shallow_net(outputs=None, loss=None, compile=False):
model = tf.keras.Sequential()
model.add(tf.keras.layers.Input(shape=(FEATURES,)))
if single_output:
model.add(tf.keras.layers.Dense(1))
else:
if outputs is None:
model.add(tf.keras.layers.Dense(N_CLASSES))
else:
model.add(tf.keras.layers.Dense(outputs))

if compile:
model.compile(loss=loss)
@@ -45,7 +45,7 @@ def test_user_compiled(loss):
"""Test to make sure that user compiled classification models work with all
classification losses.
"""
model__single_output = True if "binary" in loss else False
model__outputs = 1 if "binary" in loss else None
if loss == "binary_crossentropy":
y = np.random.randint(0, 2, size=(n_eg,))
elif loss == "categorical_crossentropy":
@@ -59,7 +59,7 @@ def test_user_compiled(loss):
shallow_net,
model__compile=True,
model__loss=loss,
model__single_output=model__single_output,
model__outputs=model__outputs,
)
est.partial_fit(X, y)

@@ -69,7 +69,7 @@ def test_user_compiled(loss):

class NoEncoderClf(KerasClassifier):
"""A classifier overriding default target encoding.
This simulates a user implementing custom encoding logic in
This simulates a user implementing custom encoding logic in
target_encoder to support multiclass-multioutput or
multilabel-indicator, which by default would raise an error.
"""
@@ -79,40 +79,58 @@ def target_encoder(self):
return FunctionTransformer()


@pytest.mark.parametrize("use_case", ["multilabel-indicator", "multiclass-multioutput"])
def test_classifier_unsupported_multi_output_tasks(use_case):
@pytest.mark.parametrize(
"use_case,wrapper_cls",
[
("multilabel-indicator", NoEncoderClf),
("multiclass-multioutput", NoEncoderClf),
("classification_w_onehot_targets", KerasClassifier),
],
)
def test_classifier_unsupported_multi_output_tasks(use_case, wrapper_cls):
"""Test for an appropriate error for tasks that are not supported
by `loss="auto"`.
"""
extra = ""
fix_loss = None
if use_case == "multiclass-multioutput":
y1 = np.random.randint(0, 1, size=len(X))
y2 = np.random.randint(0, 2, size=len(X))
y = np.column_stack([y1, y2])
elif use_case == "multilabel-indicator":
y1 = np.random.randint(0, 1, size=len(X))
y = np.column_stack([y1, y1])
est = NoEncoderClf(shallow_net, model__compile=False)
with pytest.raises(
NotImplementedError, match='`loss="auto"` is not supported for tasks of type'
):
est.initialize(X, y)
y[0, :] = 1
fix_loss = "binary_crossentropy"
extra = f'loss="{fix_loss}" might be appropriate'
elif use_case == "classification_w_onehot_targets":
y = np.random.choice(N_CLASSES, size=len(X)).astype(int)
y = OneHotEncoder(sparse=False).fit_transform(y.reshape(-1, 1))
fix_loss = "categorical_crossentropy"
extra = f'loss="{fix_loss}" might be appropriate'
match = '`loss="auto"` is not supported for tasks of type'
if extra:
match += f"(.|\n)+{extra}"
with pytest.raises(NotImplementedError, match=match):
wrapper_cls(shallow_net, model__compile=False).initialize(X, y)
if fix_loss:
wrapper_cls(shallow_net, model__compile=False, loss=fix_loss).initialize(X, y)


@pytest.mark.parametrize(
"use_case,supported",
"use_case",
[
("binary_classification", True),
("binary_classification_w_one_class", True),
("classification_w_1d_targets", True),
("classification_w_onehot_targets", False),
"binary_classification",
"binary_classification_w_one_class",
"classification_w_1d_targets",
],
)
def test_classifier_default_loss_only_model_specified(use_case, supported):
def test_classifier_default_loss_only_model_specified(use_case):
"""Test that KerasClassifier will auto-determine a loss function
when only the model is specified.
"""

model__single_output = True if "binary" in use_case else False
model__outputs = 1 if "binary" in use_case else None
if use_case == "binary_classification":
exp_loss = "binary_crossentropy"
y = np.random.choice(2, size=len(X)).astype(int)
@@ -122,21 +140,11 @@ def test_classifier_default_loss_only_model_specified(use_case, supported):
elif use_case == "classification_w_1d_targets":
exp_loss = "sparse_categorical_crossentropy"
y = np.random.choice(N_CLASSES, size=(len(X), 1)).astype(int)
elif use_case == "classification_w_onehot_targets":
y = np.random.choice(N_CLASSES, size=len(X)).astype(int)
y = OneHotEncoder(sparse=False).fit_transform(y.reshape(-1, 1))

est = KerasClassifier(model=shallow_net, model__single_output=model__single_output)
est = KerasClassifier(model=shallow_net, model__outputs=model__outputs)

if supported:
est.fit(X, y=y)
assert loss_name(est.model_.loss) == exp_loss
else:
with pytest.raises(
NotImplementedError,
match='`loss="auto"` is not supported for tasks of type',
):
est.fit(X, y=y)
est.fit(X, y=y)
assert loss_name(est.model_.loss) == exp_loss
assert est.loss == "auto"


@@ -148,7 +156,9 @@ def test_regressor_default_loss_only_model_specified(use_case):
y = np.random.uniform(size=len(X))
if use_case == "multi_output":
y = np.column_stack([y, y])
est = KerasRegressor(model=shallow_net, model__single_output=True)
est = KerasRegressor(
model=shallow_net, model__outputs=1 if "single" in use_case else 2
)
est.fit(X, y)
assert est.loss == "auto"
assert loss_name(est.model_.loss) == "mean_squared_error"
@@ -202,7 +212,7 @@ def test_multi_output_support(user_compiled, est_cls):
def test_multiclass_single_output_unit():
"""Test that multiclass targets require > 1 output units.
"""
est = KerasClassifier(model=shallow_net, model__single_output=True)
est = KerasClassifier(model=shallow_net, model__outputs=1)
y = np.random.choice(N_CLASSES, size=(len(X), 1)).astype(int)
with pytest.raises(
ValueError,
@@ -214,7 +224,7 @@ def test_multiclass_single_output_unit():
def test_binary_multiple_output_units():
"""Test that binary targets require exactly 1 output unit.
"""
est = KerasClassifier(model=shallow_net, model__single_output=False)
est = KerasClassifier(model=shallow_net, model__outputs=2)
y = np.random.choice(2, size=len(X)).astype(int)
with pytest.raises(
ValueError,