Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH All sklearn estimators in trusted list #237

Merged
merged 9 commits into from
Dec 12, 2022
4 changes: 2 additions & 2 deletions skops/io/_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np

from ._audit import Node, get_tree
from ._trusted_types import PRIMITIVE_TYPE_NAMES
from ._trusted_types import PRIMITIVE_TYPE_NAMES, SKLEARN_ESTIMATOR_TYPE_NAMES
from ._utils import (
LoadContext,
SaveContext,
Expand Down Expand Up @@ -383,7 +383,7 @@ def __init__(

self.children = {"attrs": attrs}
# TODO: what do we trust?
self.trusted = self._get_trusted(trusted, [])
self.trusted = self._get_trusted(trusted, default=SKLEARN_ESTIMATOR_TYPE_NAMES)

def _construct(self):
cls = gettype(self.module_name, self.class_name)
Expand Down
10 changes: 10 additions & 0 deletions skops/io/_trusted_types.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
from sklearn.utils import all_estimators

from ._utils import get_type_name

PRIMITIVES_TYPES = [int, float, str, bool]

PRIMITIVE_TYPE_NAMES = ["builtins." + t.__name__ for t in PRIMITIVES_TYPES]

SKLEARN_ESTIMATOR_TYPE_NAMES = [
get_type_name(estimator_class)
for _, estimator_class in all_estimators()
if get_type_name(estimator_class).startswith("sklearn.")
]
10 changes: 1 addition & 9 deletions skops/io/tests/test_audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,12 +162,4 @@ def test_complex_pipeline_untrusted_set():

untrusted = get_untrusted_types(data=dumps(clf))
type_names = [x.split(".")[-1] for x in untrusted]
assert type_names == [
"sqrt",
"square",
"LogisticRegression",
"FeatureUnion",
"Pipeline",
"StandardScaler",
"FunctionTransformer",
]
assert type_names == ["sqrt", "square"]
8 changes: 7 additions & 1 deletion skops/io/tests/test_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ def get_input(estimator):
@pytest.mark.parametrize(
"estimator", _tested_estimators(), ids=_get_check_estimator_ids
)
def test_can_persist_fitted(estimator, request):
adrinjalali marked this conversation as resolved.
Show resolved Hide resolved
def test_can_persist_fitted(estimator):
"""Check that fitted estimators can be persisted and return the right results."""
set_random_state(estimator, random_state=0)

Expand All @@ -491,6 +491,12 @@ def test_can_persist_fitted(estimator, request):
loaded = loads(dumped, trusted=untrusted_types)
assert_params_equal(estimator.__dict__, loaded.__dict__)

# test that most sklearn estimators are not in untrusted_types
sklearn_untrusted_types = [
type_ for type_ in untrusted_types if type_.startswith("sklearn.")
]
assert len(sklearn_untrusted_types) == 0

for method in [
"predict",
"predict_proba",
Expand Down