[python-package] fix mypy errors about custom eval and metric functions #5790
```diff
@@ -33,32 +33,50 @@
     scipy.sparse.spmatrix
 ]
 _LGBM_ScikitCustomObjectiveFunction = Union[
     # f(labels, preds)
     Callable[
-        [np.ndarray, np.ndarray],
+        [Optional[np.ndarray], np.ndarray],
         Tuple[np.ndarray, np.ndarray]
     ],
     # f(labels, preds, weights)
     Callable[
-        [np.ndarray, np.ndarray, np.ndarray],
+        [Optional[np.ndarray], np.ndarray, Optional[np.ndarray]],
         Tuple[np.ndarray, np.ndarray]
     ],
     # f(labels, preds, weights, group)
     Callable[
-        [np.ndarray, np.ndarray, np.ndarray, np.ndarray],
+        [Optional[np.ndarray], np.ndarray, Optional[np.ndarray], Optional[np.ndarray]],
         Tuple[np.ndarray, np.ndarray]
     ],
 ]
 _LGBM_ScikitCustomEvalFunction = Union[
     # f(labels, preds)
     Callable[
-        [np.ndarray, np.ndarray],
-        Union[_LGBM_EvalFunctionResultType, List[_LGBM_EvalFunctionResultType]]
+        [Optional[np.ndarray], np.ndarray],
+        _LGBM_EvalFunctionResultType
+    ],
+    Callable[
+        [Optional[np.ndarray], np.ndarray],
+        List[_LGBM_EvalFunctionResultType]
     ],
     # f(labels, preds, weights)
     Callable[
-        [np.ndarray, np.ndarray, np.ndarray],
-        Union[_LGBM_EvalFunctionResultType, List[_LGBM_EvalFunctionResultType]]
+        [Optional[np.ndarray], np.ndarray, Optional[np.ndarray]],
+        _LGBM_EvalFunctionResultType
+    ],
+    Callable[
+        [Optional[np.ndarray], np.ndarray, Optional[np.ndarray]],
+        List[_LGBM_EvalFunctionResultType]
     ],
     # f(labels, preds, weights, group)
-    Callable[
-        [np.ndarray, np.ndarray, np.ndarray, np.ndarray],
-        Union[_LGBM_EvalFunctionResultType, List[_LGBM_EvalFunctionResultType]]
-    ],
+    Callable[
+        [Optional[np.ndarray], np.ndarray, Optional[np.ndarray], Optional[np.ndarray]],
+        _LGBM_EvalFunctionResultType
+    ],
+    Callable[
+        [Optional[np.ndarray], np.ndarray, Optional[np.ndarray], Optional[np.ndarray]],
+        List[_LGBM_EvalFunctionResultType]
+    ]
```
Comment on lines +72 to +79:
Splitting cases like

```python
Callable[
    [np.ndarray, np.ndarray],
    Union[_LGBM_EvalFunctionResultType, List[_LGBM_EvalFunctionResultType]]
]
```

into

```python
Union[
    Callable[
        [np.ndarray, np.ndarray],
        _LGBM_EvalFunctionResultType
    ],
    Callable[
        [np.ndarray, np.ndarray],
        List[_LGBM_EvalFunctionResultType]
    ]
]
```

helps. And I think it's slightly more correct... I'd expect people to provide a custom metric function that returns a single tuple on every iteration or one that returns a list of tuples on each iteration, but not one that could return either of those.
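A quick sketch of that distinction, using hypothetical toy types rather than the PR's own aliases: a function whose annotated return type is the `Union` is assignable to the combined `Callable` but to neither split member, which is exactly the "could return either" case the split signatures intentionally exclude.

```python
from typing import Callable, List, Tuple, Union

Result = Tuple[str, float, bool]

def sometimes_list(flag: bool) -> Union[Result, List[Result]]:
    # Unusual at runtime: sometimes one result, sometimes a list of them
    return ("metric", 0.0, False) if flag else [("metric", 0.0, False)]

# Assignable to the combined form...
combined: Callable[[bool], Union[Result, List[Result]]] = sometimes_list
# ...but not to either split member; mypy rejects this assignment:
# single_only: Callable[[bool], Result] = sometimes_list
```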
```diff
 ]
 _LGBM_ScikitEvalMetricType = Union[
     str,
```
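To make these signatures concrete, here is a minimal sketch (not part of the PR; `_EvalResult` stands in for `_LGBM_EvalFunctionResultType`) of the two shapes of custom metric the split hints describe:

```python
import numpy as np
from typing import List, Optional, Tuple

# Stand-in for _LGBM_EvalFunctionResultType: (metric_name, value, is_higher_better)
_EvalResult = Tuple[str, float, bool]

def rmse_metric(labels: Optional[np.ndarray], preds: np.ndarray) -> _EvalResult:
    # labels is Optional because Dataset.get_label() can return None,
    # though in practice it is almost always a populated array
    assert labels is not None
    return "rmse", float(np.sqrt(np.mean((labels - preds) ** 2))), False

def multi_metric(labels: Optional[np.ndarray], preds: np.ndarray) -> List[_EvalResult]:
    # The List[...] branch of the Union: several metrics per iteration
    assert labels is not None
    err = labels - preds
    return [
        ("mae", float(np.mean(np.abs(err))), False),
        ("rmse", float(np.sqrt(np.mean(err ** 2))), False),
    ]
```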
```diff
@@ -135,11 +153,11 @@ def __call__(self, preds: np.ndarray, dataset: Dataset) -> Tuple[np.ndarray, np.ndarray]:
         labels = dataset.get_label()
         argc = len(signature(self.func).parameters)
         if argc == 2:
-            grad, hess = self.func(labels, preds)
+            grad, hess = self.func(labels, preds)  # type: ignore[call-arg]
         elif argc == 3:
-            grad, hess = self.func(labels, preds, dataset.get_weight())
+            grad, hess = self.func(labels, preds, dataset.get_weight())  # type: ignore[call-arg]
         elif argc == 4:
-            grad, hess = self.func(labels, preds, dataset.get_weight(), dataset.get_group())
+            grad, hess = self.func(labels, preds, dataset.get_weight(), dataset.get_group())  # type: ignore[call-arg]
         else:
             raise TypeError(f"Self-defined objective function should have 2, 3 or 4 arguments, got {argc}")
         return grad, hess
```
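For reference, a custom objective with the 3-argument form that this wrapper dispatches to might look like the following sketch (illustrative only; the parameter types follow the updated hints above):

```python
import numpy as np
from typing import Optional, Tuple

def weighted_l2_objective(
    labels: Optional[np.ndarray],
    preds: np.ndarray,
    weights: Optional[np.ndarray],
) -> Tuple[np.ndarray, np.ndarray]:
    # Gradient and hessian of (weighted) squared error; per the new hints,
    # labels and weights may technically be None
    assert labels is not None
    grad = preds - labels
    hess = np.ones_like(preds)
    if weights is not None:
        grad = grad * weights
        hess = hess * weights
    return grad, hess
```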
```diff
@@ -213,11 +231,11 @@ def __call__(
         labels = dataset.get_label()
         argc = len(signature(self.func).parameters)
         if argc == 2:
-            return self.func(labels, preds)
+            return self.func(labels, preds)  # type: ignore[call-arg]
         elif argc == 3:
-            return self.func(labels, preds, dataset.get_weight())
+            return self.func(labels, preds, dataset.get_weight())  # type: ignore[call-arg]
         elif argc == 4:
-            return self.func(labels, preds, dataset.get_weight(), dataset.get_group())
+            return self.func(labels, preds, dataset.get_weight(), dataset.get_group())  # type: ignore[call-arg]
         else:
             raise TypeError(f"Self-defined eval function should have 2, 3 or 4 arguments, got {argc}")
```
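Why the `# type: ignore[call-arg]` comments are needed at all: mypy cannot connect the runtime arity check to a particular member of the `Union` of `Callable` types, so each branch's call is flagged against the members with a different arity. A self-contained toy reproduction (hypothetical names) of the same pattern:

```python
from inspect import signature
from typing import Callable, Union

Func2 = Callable[[int, int], int]
Func3 = Callable[[int, int, int], int]

def dispatch(func: Union[Func2, Func3]) -> int:
    argc = len(signature(func).parameters)
    if argc == 2:
        # mypy cannot infer from argc that func is Func2 here, so it reports
        # a call-arg error against the Func3 member of the Union
        return func(1, 2)  # type: ignore[call-arg]
    return func(1, 2, 3)  # type: ignore[call-arg]
```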
```diff
@@ -819,7 +837,7 @@ def _get_meta_data(collection, name, i):
             num_boost_round=self.n_estimators,
             valid_sets=valid_sets,
             valid_names=eval_names,
-            feval=eval_metrics_callable,
+            feval=eval_metrics_callable,  # type: ignore[arg-type]
```
This addition silences an `arg-type` error mypy raises on this line. I'm not sure why, but I think that's maybe related to some issues it has with comparing lists to union types that include lists (e.g. python/mypy#6463). This PR proposes just ignoring that error for now... the use of custom eval metric functions is well-covered by the unit tests in …
```diff
             init_model=init_model,
             feature_name=feature_name,
             callbacks=callbacks
```
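For context, this `feval` path is what ultimately receives a user-supplied `eval_metric`. A hedged usage sketch (model and data names are illustrative, not from the PR):

```python
import numpy as np
import lightgbm as lgb

def rmse_metric(labels, preds):
    # single-result form: (eval_name, eval_result, is_higher_better)
    return "rmse", float(np.sqrt(np.mean((labels - preds) ** 2))), False

X, y = np.random.rand(200, 5), np.random.rand(200)
model = lgb.LGBMRegressor(n_estimators=10)
# eval_metric accepts a string, a callable, or a list mixing both; the
# scikit-learn wrapper converts these into the feval passed to train()
model.fit(X, y, eval_set=[(X, y)], eval_metric=rmse_metric)
```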
Why add all these `Optional`s?

mypy is struggling with the fact that `Dataset.get_field()` can return `None`:

- LightGBM/python-package/lightgbm/basic.py, line 2704 in 2fe2bf0
- LightGBM/python-package/lightgbm/sklearn.py, line 127 in 2fe2bf0
- LightGBM/python-package/lightgbm/basic.py, line 2393 in 2fe2bf0

This PR proposes updating the type hints for custom metric and objective functions to match that behavior.

I intentionally chose not to update the user-facing docs about custom metric and objective functions to reflect that the `label`, `group`, and `weights` passed to these functions can technically be `None`... in almost all situations, they should be non-`None`. I don't think complicating the docs is worth it.
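A small self-contained illustration of the mismatch the `Optional`s resolve (all names here are hypothetical stand-ins; only the `Optional[np.ndarray]` return hint mirrors `basic.py`):

```python
import numpy as np
from typing import Optional, Tuple

def get_field(name: str) -> Optional[np.ndarray]:
    # stand-in mirroring Dataset.get_field()'s Optional[np.ndarray] hint
    return None if name == "weight" else np.zeros(3)

def strict_metric(labels: np.ndarray, preds: np.ndarray) -> Tuple[str, float, bool]:
    return "zeros", float(np.abs(labels - preds).sum()), False

labels = get_field("label")
# mypy: Argument 1 has incompatible type "Optional[ndarray]"; expected "ndarray"
# strict_metric(labels, np.zeros(3))
# Hinting the parameter Optional[np.ndarray], as this PR does, resolves it.
```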