Skip to content

Commit

Permalink
Merge pull request bsc-wdc#350 from bsc-wdc/rf_score_synchronization
Browse files: browse the repository at this point in the history
Synchronizing the score in RF
  • Loading branch information
michal-choinski authored Jul 30, 2021
2 parents 66e98b4 + 13e766c commit 64faec6
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 4 deletions.
2 changes: 2 additions & 0 deletions dislib/classification/csvm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ def score(self, x, y, collect=False):
Test samples.
y : ds-array, shape=(n_samples, 1)
True labels for x.
collect : bool
When True, a synchronized result is returned.
Returns
-------
Expand Down
8 changes: 6 additions & 2 deletions dislib/classification/rf/forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def predict(self, x):

return y_pred

def score(self, x, y):
def score(self, x, y, collect=False):
"""Accuracy classification score.
Returns the mean accuracy on the given test data.
Expand All @@ -209,6 +209,8 @@ def score(self, x, y):
The test input samples.
y : ds-array, shape (n_samples, 1)
The true labels.
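collect : bool
When True, a synchronized result is returned. When False (the default), the result is returned without synchronization and the caller must apply compss_wait_on to it.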
Returns
-------
Expand All @@ -235,7 +237,9 @@ def score(self, x, y):
*tree_predictions)
partial_scores.append(subset_score)

return _merge_scores(*partial_scores)
score = _merge_scores(*partial_scores)

return compss_wait_on(score) if collect else score


@task(returns=1)
Expand Down
8 changes: 6 additions & 2 deletions tests/test_rf.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import unittest

import numpy as np
from parameterized import parameterized
from pycompss.api.api import compss_wait_on
from sklearn import datasets
from sklearn.datasets import make_classification
Expand Down Expand Up @@ -180,7 +181,8 @@ def test_make_classification_hard_vote_score_mix(self):
accuracy = compss_wait_on(rf.score(x_test, y_test))
self.assertGreater(accuracy, 0.7)

def test_iris(self):
@parameterized.expand([(True,), (False,)])
def test_score_on_iris(self, collect):
"""Tests RandomForestClassifier with a minimal example."""
x, y = datasets.load_iris(return_X_y=True)
ds_fit = ds.array(x[::2], block_size=(30, 2))
Expand All @@ -191,7 +193,9 @@ def test_iris(self):
rf = RandomForestClassifier(n_estimators=1, max_depth=1,
random_state=0)
rf.fit(ds_fit, fit_y)
accuracy = compss_wait_on(rf.score(ds_validate, validate_y))
accuracy = rf.score(ds_validate, validate_y, collect)
if not collect:
accuracy = compss_wait_on(accuracy)

# Accuracy should be <= 2/3 for any seed, often exactly equal.
self.assertAlmostEqual(accuracy, 2 / 3)
Expand Down

0 comments on commit 64faec6

Please sign in to comment.