Skip to content

Commit

Permalink
Change sklearn testing classifier to models
Browse files Browse the repository at this point in the history
Signed-off-by: Beat Buesser <beat.buesser@ie.ibm.com>
  • Loading branch information
Beat Buesser committed Oct 1, 2020
1 parent 06aaa2f commit 20a5fec
Show file tree
Hide file tree
Showing 39 changed files with 37 additions and 23 deletions.
60 changes: 37 additions & 23 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1236,47 +1236,61 @@ def train_step(model, images, labels):


def get_tabular_classifier_scikit_list(clipped=False):
model_list_names = [
"decisionTreeClassifier",
"extraTreeClassifier",
"adaBoostClassifier",
"baggingClassifier",
"extraTreesClassifier",
"gradientBoostingClassifier",
"randomForestClassifier",
"logisticRegression",
"svc",
"linearSVC",
]

from art.estimators.classification.scikitlearn import (
ScikitlearnDecisionTreeClassifier,
ScikitlearnExtraTreeClassifier,
ScikitlearnAdaBoostClassifier,
ScikitlearnBaggingClassifier,
ScikitlearnExtraTreesClassifier,
ScikitlearnGradientBoostingClassifier,
ScikitlearnRandomForestClassifier,
ScikitlearnLogisticRegression,
ScikitlearnSVC,
ScikitlearnSVC,
)

model_list_names = {
"decisionTreeClassifier": ScikitlearnDecisionTreeClassifier,
# "extraTreeClassifier": ScikitlearnExtraTreeClassifier,
"adaBoostClassifier": ScikitlearnAdaBoostClassifier,
"baggingClassifier": ScikitlearnBaggingClassifier,
"extraTreesClassifier": ScikitlearnExtraTreesClassifier,
"gradientBoostingClassifier": ScikitlearnGradientBoostingClassifier,
"randomForestClassifier": ScikitlearnRandomForestClassifier,
"logisticRegression": ScikitlearnLogisticRegression,
"svc": ScikitlearnSVC,
"linearSVC": ScikitlearnSVC,
}

classifier_list = list()

if clipped:
classifier_list = [
# os.path.join(os.path.dirname(os.path.dirname(__file__)),'utils/resources/models', 'W_DENSE3_IRIS.npy')
pickle.load(
for model_name, model_class in model_list_names.items():
model = pickle.load(
open(
os.path.join(
os.path.dirname(os.path.dirname(__file__)),
"utils/resources/models/scikit/",
model_name + "iris_clipped.sav",
"scikit-" + model_name + "-iris-clipped.pickle",
),
"rb",
)
)
for model_name in model_list_names
]
classifier_list.append(model_class(model=model, clip_values=(0, 1)))
else:
classifier_list = [
pickle.load(
for model_name, model_class in model_list_names.items():
model = pickle.load(
open(
os.path.join(
os.path.dirname(os.path.dirname(__file__)),
"utils/resources/models/scikit/",
model_name + "iris_unclipped.sav",
"scikit-" + model_name + "-iris-unclipped.pickle",
),
"rb",
)
)
for model_name in model_list_names
]
classifier_list.append(model_class(model=model, clip_values=None))

return classifier_list

Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file removed utils/resources/models/scikit/svciris_clipped.sav
Binary file not shown.
Binary file removed utils/resources/models/scikit/svciris_unclipped.sav
Binary file not shown.

0 comments on commit 20a5fec

Please sign in to comment.