diff --git a/dislib/classification/csvm/base.py b/dislib/classification/csvm/base.py index ad5dc5f0..8a052e0d 100644 --- a/dislib/classification/csvm/base.py +++ b/dislib/classification/csvm/base.py @@ -192,6 +192,8 @@ def score(self, x, y, collect=False): Test samples. y : ds-array, shape=(n_samples, 1) True labels for x. + collect : bool + When True, a synchronized result is returned. Returns ------- diff --git a/dislib/classification/rf/forest.py b/dislib/classification/rf/forest.py index 8f6c0f2a..dd78b9e0 100644 --- a/dislib/classification/rf/forest.py +++ b/dislib/classification/rf/forest.py @@ -197,7 +197,7 @@ def predict(self, x): return y_pred - def score(self, x, y): + def score(self, x, y, collect=False): """Accuracy classification score. Returns the mean accuracy on the given test data. @@ -209,6 +209,8 @@ def score(self, x, y): The training input samples. y : ds-array, shape (n_samples, 1) The true labels. + collect : bool + When True, a synchronized result is returned. Returns ------- @@ -235,7 +237,9 @@ def score(self, x, y): *tree_predictions) partial_scores.append(subset_score) - return _merge_scores(*partial_scores) + score = _merge_scores(*partial_scores) + + return compss_wait_on(score) if collect else score @task(returns=1) diff --git a/tests/test_rf.py b/tests/test_rf.py index ca111e71..6b4648a9 100644 --- a/tests/test_rf.py +++ b/tests/test_rf.py @@ -1,6 +1,7 @@ import unittest import numpy as np +from parameterized import parameterized from pycompss.api.api import compss_wait_on from sklearn import datasets from sklearn.datasets import make_classification @@ -180,7 +181,8 @@ def test_make_classification_hard_vote_score_mix(self): accuracy = compss_wait_on(rf.score(x_test, y_test)) self.assertGreater(accuracy, 0.7) - def test_iris(self): + @parameterized.expand([(True,), (False,)]) + def test_score_on_iris(self, collect): """Tests RandomForestClassifier with a minimal example.""" x, y = datasets.load_iris(return_X_y=True) ds_fit = ds.array(x[::2], block_size=(30, 2)) @@ -191,7 +193,9 @@ def test_iris(self): rf = RandomForestClassifier(n_estimators=1, max_depth=1, random_state=0) rf.fit(ds_fit, fit_y) - accuracy = compss_wait_on(rf.score(ds_validate, validate_y)) + accuracy = rf.score(ds_validate, validate_y, collect) + if not collect: + accuracy = compss_wait_on(accuracy) # Accuracy should be <= 2/3 for any seed, often exactly equal. self.assertAlmostEqual(accuracy, 2 / 3)