MAINT: better error message about one-hot encoded targets w/ loss="auto" #218

Merged · 19 commits · Mar 31, 2021
Changes from 2 commits
9 changes: 9 additions & 0 deletions scikeras/wrappers.py
@@ -1340,6 +1340,15 @@ def _compile_model(self, compile_kwargs: Dict[str, Any]) -> None:
                         "Multi-class targets require the model to have >1 output units."
                     )
                 compile_kwargs["loss"] = "sparse_categorical_crossentropy"
+            elif hasattr(self, "n_classes_"):
+                n_out = self.model_.outputs[0].shape[1]
+                if n_out != self.n_classes_:
+                    raise ValueError(
+                        "loss='categorical_crossentropy' is expecting the model "
+                        f"to have {self.n_classes_} output neurons, one for each "
+                        f"class. However, only {n_out} output neurons were found."
+                    )
+                compile_kwargs["loss"] = "categorical_crossentropy"
             else:
                 raise NotImplementedError(
                     f'`loss="auto"` is not supported for tasks of type {self.target_type_}.'
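For readers of this diff: below is a minimal, hypothetical sketch (not part of the PR) of how the new check surfaces to a user. It assumes TensorFlow/Keras and scikeras are installed; the model builder, array shapes, and class count are illustrative only. With loss="auto" (the default) and one-hot encoded targets, a model whose output width does not match the number of classes should now fail with the clearer ValueError added above, rather than the generic NotImplementedError raised before this change.

import numpy as np
from sklearn.preprocessing import OneHotEncoder
from tensorflow import keras
from scikeras.wrappers import KerasClassifier

def get_model():
    # Hypothetical model: only 1 output unit, even though the targets
    # below are one-hot encoded over 3 classes.
    return keras.Sequential(
        [
            keras.layers.Dense(8, activation="relu", input_shape=(4,)),
            keras.layers.Dense(1, activation="sigmoid"),
        ]
    )

X = np.random.random((100, 4)).astype(np.float32)
y = np.random.choice(3, size=100)
y_onehot = OneHotEncoder(sparse=False).fit_transform(y.reshape(-1, 1))

est = KerasClassifier(model=get_model)  # loss="auto" is the default
# With this PR, fit should raise the new ValueError, along the lines of:
#   loss='categorical_crossentropy' is expecting the model to have 3 output
#   neurons, one for each class. However, only 1 output neurons were found.
est.fit(X, y_onehot)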
24 changes: 9 additions & 15 deletions tests/test_loss_auto.py
@@ -99,15 +99,15 @@ def test_classifier_unsupported_multi_output_tasks(use_case):


 @pytest.mark.parametrize(
-    "use_case,supported",
+    "use_case",
     [
-        ("binary_classification", True),
-        ("binary_classification_w_one_class", True),
-        ("classification_w_1d_targets", True),
-        ("classification_w_onehot_targets", False),
+        "binary_classification",
+        "binary_classification_w_one_class",
+        "classification_w_1d_targets",
+        "classification_w_onehot_targets",
     ],
 )
-def test_classifier_default_loss_only_model_specified(use_case, supported):
+def test_classifier_default_loss_only_model_specified(use_case):
     """Test that KerasClassifier will auto-determine a loss function
     when only the model is specified.
     """
@@ -123,20 +123,14 @@ def test_classifier_default_loss_only_model_specified(use_case, supported):
         exp_loss = "sparse_categorical_crossentropy"
         y = np.random.choice(N_CLASSES, size=(len(X), 1)).astype(int)
     elif use_case == "classification_w_onehot_targets":
+        exp_loss = "categorical_crossentropy"
         y = np.random.choice(N_CLASSES, size=len(X)).astype(int)
         y = OneHotEncoder(sparse=False).fit_transform(y.reshape(-1, 1))

     est = KerasClassifier(model=shallow_net, model__single_output=model__single_output)

-    if supported:
-        est.fit(X, y=y)
-        assert loss_name(est.model_.loss) == exp_loss
-    else:
-        with pytest.raises(
-            NotImplementedError,
-            match='`loss="auto"` is not supported for tasks of type',
-        ):
-            est.fit(X, y=y)
+    est.fit(X, y=y)
+    assert loss_name(est.model_.loss) == exp_loss
     assert est.loss == "auto"

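Not shown in this diff: the test relies on a shallow_net model builder and a loss_name helper defined elsewhere in tests/test_loss_auto.py. Purely as an assumption to make the test easier to follow, shallow_net presumably looks something like the sketch below, where single_output toggles between one output unit and one unit per class; N_CLASSES, the input shape, and the layer sizes are guesses, not taken from the repository.

from tensorflow import keras

N_CLASSES = 3  # assumed value; defined at module level in the real test file

def shallow_net(single_output=False):
    # Build a small, uncompiled model; scikeras compiles it and picks the loss
    # when loss="auto".
    n_out = 1 if single_output else N_CLASSES
    return keras.Sequential(
        [
            keras.layers.Dense(8, activation="relu", input_shape=(10,)),
            keras.layers.Dense(n_out),
        ]
    )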