From 8d50bf974a424e410580a75881222245a29b21a3 Mon Sep 17 00:00:00 2001 From: Scott Sievert Date: Wed, 31 Mar 2021 14:43:25 -0500 Subject: [PATCH] MAINT: better error message about one-hot encoded targets w/ loss="auto" (#218) --- docs/source/advanced.rst | 47 +++++++++++++++++++-- docs/source/quickstart.rst | 19 ++++++--- pyproject.toml | 2 +- scikeras/wrappers.py | 24 ++++++++--- tests/test_loss_auto.py | 84 +++++++++++++++++++++----------------- 5 files changed, 124 insertions(+), 52 deletions(-) diff --git a/docs/source/advanced.rst b/docs/source/advanced.rst index c771d394c..624b6357b 100644 --- a/docs/source/advanced.rst +++ b/docs/source/advanced.rst @@ -1,6 +1,6 @@ -=================================== -Advanced Usage of SciKeras Wrappers -=================================== +============== +Advanced Usage +============== Wrapper Classes --------------- @@ -128,6 +128,43 @@ offer an easy way to compile and tune compilation parameters. Examples: In all cases, returning an un-compiled model is equivalent to calling ``model.compile(**compile_kwargs)`` within ``model_build_fn``. +.. _loss-selection: + +Loss selection +++++++++++++++ + +If you do not explicitly define a loss, SciKeras attempts to find a loss +that matches the type of target (see :py:func:`sklearn.utils.multiclass.type_of_target`). + +For guidance selecting losses in Keras, please see Jason Brownlee's +excellent article `How to Choose Loss Functions When Training Deep Learning Neural Networks`_ +as well as `Keras Losses docs`_. + +Default losses are selected as follows: + +Classification +.............. + ++-----------+-----------+----------+---------------------------------+ +| # outputs | # classes | encoding | loss | ++===========+===========+==========+=================================+ +| 1 | <= 2 | any | binary crossentropy | ++-----------+-----------+----------+---------------------------------+ +| 1 | >=2 | labels | sparse categorical crossentropy | ++-----------+-----------+----------+---------------------------------+ +| 1 | >=2 | one-hot | unsupported | ++-----------+-----------+----------+---------------------------------+ +| > 1 | -- | -- | unsupported | ++-----------+-----------+----------+---------------------------------+ + +Note that SciKeras will not automatically infer the loss for one-hot encoded targets, +you would need to explicitly specify `loss="categorical_crossentropy"`. + +Regression +.......... + +Regression always defaults to mean squared error. +For multi-output models, Keras will use the sum of each output's loss. Arguments to ``model_build_fn`` ------------------------------- @@ -287,3 +324,7 @@ and :class:`scikeras.wrappers.KerasRegressor` respectively. To override these sc .. _Keras Callbacks docs: https://www.tensorflow.org/api_docs/python/tf/keras/callbacks .. _Keras Metrics docs: https://www.tensorflow.org/api_docs/python/tf/keras/metrics + +.. _Keras Losses docs: https://www.tensorflow.org/api_docs/python/tf/keras/losses + +.. _How to Choose Loss Functions When Training Deep Learning Neural Networks: https://machinelearningmastery.com/how-to-choose-loss-functions-when-training-deep-learning-neural-networks/ \ No newline at end of file diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index 8555ea291..66ba19b11 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -38,16 +38,25 @@ it on a toy classification dataset using SciKeras model.add(keras.layers.Activation("softmax")) return model - clf = KerasClassifier( - get_model, - loss="sparse_categorical_crossentropy", - hidden_layer_dim=100, - ) + clf = KerasClassifier(get_model, hidden_layer_dim=100) clf.fit(X, y) y_proba = clf.predict_proba(X) +Note that SciKeras even chooses a loss function and compiles your model. +To override the default loss, simply specify a loss function: + +.. code-block:: diff + + -KerasClassifier(get_model, hidden_layer_dim=100) + +KerasClassifier(get_model, loss="categorical_crossentropy") + +In this case, you would need to specify the loss since SciKeras +will not default to categorical crossentropy, even for one-hot +encoded targets. +See :ref:`loss-selection` for more details. + In an sklearn Pipeline ---------------------- diff --git a/pyproject.toml b/pyproject.toml index 4245a16cb..01e5f126c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ version = "0.2.1" [tool.poetry.dependencies] importlib-metadata = {version = "^3.4.0", python = "<3.8"} -python = ">=3.6.7, <3.9" +python = "^3.11.0" scikit-learn = "^0.22.0" tensorflow = "^2.4.0" diff --git a/scikeras/wrappers.py b/scikeras/wrappers.py index 1a7e217d2..1adbeaead 100644 --- a/scikeras/wrappers.py +++ b/scikeras/wrappers.py @@ -22,7 +22,6 @@ from tensorflow.keras import optimizers as optimizers_module from tensorflow.keras.models import Model from tensorflow.keras.utils import register_keras_serializable -from tensorflow.python.types.core import Value from scikeras._utils import ( TFRandomState, @@ -1328,24 +1327,37 @@ def _compile_model(self, compile_kwargs: Dict[str, Any]) -> None: raise ValueError( 'Only single-output models are supported with `loss="auto"`' ) + loss = None + hint = "" if self.target_type_ == "binary": if self.model_.outputs[0].shape[1] != 1: raise ValueError( "Binary classification expects a model with exactly 1 output unit." ) - compile_kwargs["loss"] = "binary_crossentropy" + loss = "binary_crossentropy" elif self.target_type_ == "multiclass": if self.model_.outputs[0].shape[1] == 1: raise ValueError( "Multi-class targets require the model to have >1 output units." ) - compile_kwargs["loss"] = "sparse_categorical_crossentropy" - else: - raise NotImplementedError( + loss = "sparse_categorical_crossentropy" + elif self.target_type_ == "multilabel-indicator": + # one-hot encoded multiclass problem OR multilabel-indicator problem + hint = ( + "For this type of problem, the following may help:" + '\n - If there is only one class per example, loss="categorical_crossentropy" might be appropriate.' + '\n - If there are multiple classes per example, loss="binary_crossentropy" might be appropriate.' + ) + if loss is None: + msg = ( f'`loss="auto"` is not supported for tasks of type {self.target_type_}.' - " Instead, you must explicitly pass a loss function, for example:" + "\nInstead, you must compile the model yourself or explicitly pass a loss function, for example:" '\n clf = KerasClassifier(..., loss="categorical_crossentropy")' ) + if hint: + msg += f"\n\n{hint}" + raise NotImplementedError(msg) + compile_kwargs["loss"] = loss self.model_.compile(**compile_kwargs) @staticmethod diff --git a/tests/test_loss_auto.py b/tests/test_loss_auto.py index 008d9d4f7..35f12efed 100644 --- a/tests/test_loss_auto.py +++ b/tests/test_loss_auto.py @@ -14,13 +14,13 @@ X = np.random.uniform(size=(n_eg, FEATURES)).astype("float32") -def shallow_net(single_output=False, loss=None, compile=False): +def shallow_net(outputs=None, loss=None, compile=False): model = tf.keras.Sequential() model.add(tf.keras.layers.Input(shape=(FEATURES,))) - if single_output: - model.add(tf.keras.layers.Dense(1)) - else: + if outputs is None: model.add(tf.keras.layers.Dense(N_CLASSES)) + else: + model.add(tf.keras.layers.Dense(outputs)) if compile: model.compile(loss=loss) @@ -45,7 +45,7 @@ def test_user_compiled(loss): """Test to make sure that user compiled classification models work with all classification losses. """ - model__single_output = True if "binary" in loss else False + model__outputs = 1 if "binary" in loss else None if loss == "binary_crosentropy": y = np.random.randint(0, 2, size=(n_eg,)) elif loss == "categorical_crossentropy": @@ -59,7 +59,7 @@ def test_user_compiled(loss): shallow_net, model__compile=True, model__loss=loss, - model__single_output=model__single_output, + model__outputs=model__outputs, ) est.partial_fit(X, y) @@ -69,7 +69,7 @@ def test_user_compiled(loss): class NoEncoderClf(KerasClassifier): """A classifier overriding default target encoding. - This simulates a user implementing custom encoding logic in + This simulates a user implementing custom encoding logic in target_encoder to support multiclass-multioutput or multilabel-indicator, which by default would raise an error. """ @@ -79,11 +79,20 @@ def target_encoder(self): return FunctionTransformer() -@pytest.mark.parametrize("use_case", ["multilabel-indicator", "multiclass-multioutput"]) -def test_classifier_unsupported_multi_output_tasks(use_case): +@pytest.mark.parametrize( + "use_case,wrapper_cls", + [ + ("multilabel-indicator", NoEncoderClf), + ("multiclass-multioutput", NoEncoderClf), + ("classification_w_onehot_targets", KerasClassifier), + ], +) +def test_classifier_unsupported_multi_output_tasks(use_case, wrapper_cls): """Test for an appropriate error for tasks that are not supported by `loss="auto"`. """ + extra = "" + fix_loss = None if use_case == "multiclass-multioutput": y1 = np.random.randint(0, 1, size=len(X)) y2 = np.random.randint(0, 2, size=len(X)) @@ -91,28 +100,37 @@ def test_classifier_unsupported_multi_output_tasks(use_case): elif use_case == "multilabel-indicator": y1 = np.random.randint(0, 1, size=len(X)) y = np.column_stack([y1, y1]) - est = NoEncoderClf(shallow_net, model__compile=False) - with pytest.raises( - NotImplementedError, match='`loss="auto"` is not supported for tasks of type' - ): - est.initialize(X, y) + y[0, :] = 1 + fix_loss = "binary_crossentropy" + extra = f'loss="{fix_loss}" might be appropriate' + elif use_case == "classification_w_onehot_targets": + y = np.random.choice(N_CLASSES, size=len(X)).astype(int) + y = OneHotEncoder(sparse=False).fit_transform(y.reshape(-1, 1)) + fix_loss = "categorical_crossentropy" + extra = f'loss="{fix_loss}" might be appropriate' + match = '`loss="auto"` is not supported for tasks of type' + if extra: + match += f"(.|\n)+{extra}" + with pytest.raises(NotImplementedError, match=match): + wrapper_cls(shallow_net, model__compile=False).initialize(X, y) + if fix_loss: + wrapper_cls(shallow_net, model__compile=False, loss=fix_loss).initialize(X, y) @pytest.mark.parametrize( - "use_case,supported", + "use_case", [ - ("binary_classification", True), - ("binary_classification_w_one_class", True), - ("classification_w_1d_targets", True), - ("classification_w_onehot_targets", False), + "binary_classification", + "binary_classification_w_one_class", + "classification_w_1d_targets", ], ) -def test_classifier_default_loss_only_model_specified(use_case, supported): +def test_classifier_default_loss_only_model_specified(use_case): """Test that KerasClassifier will auto-determine a loss function when only the model is specified. """ - model__single_output = True if "binary" in use_case else False + model__outputs = 1 if "binary" in use_case else None if use_case == "binary_classification": exp_loss = "binary_crossentropy" y = np.random.choice(2, size=len(X)).astype(int) @@ -122,21 +140,11 @@ def test_classifier_default_loss_only_model_specified(use_case, supported): elif use_case == "classification_w_1d_targets": exp_loss = "sparse_categorical_crossentropy" y = np.random.choice(N_CLASSES, size=(len(X), 1)).astype(int) - elif use_case == "classification_w_onehot_targets": - y = np.random.choice(N_CLASSES, size=len(X)).astype(int) - y = OneHotEncoder(sparse=False).fit_transform(y.reshape(-1, 1)) - est = KerasClassifier(model=shallow_net, model__single_output=model__single_output) + est = KerasClassifier(model=shallow_net, model__outputs=model__outputs) - if supported: - est.fit(X, y=y) - assert loss_name(est.model_.loss) == exp_loss - else: - with pytest.raises( - NotImplementedError, - match='`loss="auto"` is not supported for tasks of type', - ): - est.fit(X, y=y) + est.fit(X, y=y) + assert loss_name(est.model_.loss) == exp_loss assert est.loss == "auto" @@ -148,7 +156,9 @@ def test_regressor_default_loss_only_model_specified(use_case): y = np.random.uniform(size=len(X)) if use_case == "multi_output": y = np.column_stack([y, y]) - est = KerasRegressor(model=shallow_net, model__single_output=True) + est = KerasRegressor( + model=shallow_net, model__outputs=1 if "single" in use_case else 2 + ) est.fit(X, y) assert est.loss == "auto" assert loss_name(est.model_.loss) == "mean_squared_error" @@ -202,7 +212,7 @@ def test_multi_output_support(user_compiled, est_cls): def test_multiclass_single_output_unit(): """Test that multiclass targets requires > 1 output units. """ - est = KerasClassifier(model=shallow_net, model__single_output=True) + est = KerasClassifier(model=shallow_net, model__outputs=1) y = np.random.choice(N_CLASSES, size=(len(X), 1)).astype(int) with pytest.raises( ValueError, @@ -214,7 +224,7 @@ def test_multiclass_single_output_unit(): def test_binary_multiple_output_units(): """Test that binary targets requires exactly 1 output unit. """ - est = KerasClassifier(model=shallow_net, model__single_output=False) + est = KerasClassifier(model=shallow_net, model__outputs=2) y = np.random.choice(2, size=len(X)).astype(int) with pytest.raises( ValueError,