Skip to content

Commit

Permalink
Merge pull request #348 from bsc-wdc/csvm_score_collect
Browse files Browse the repository at this point in the history
Synchronization of the score in csvm
  • Loading branch information
michal-choinski authored Jul 28, 2021
2 parents 39359ab + 14ddb24 commit 66e98b4
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
6 changes: 4 additions & 2 deletions dislib/classification/csvm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def decision_function(self, x):
reg_shape=(x._reg_shape[0], 1),
shape=(x.shape[0], 1), sparse=False)

def score(self, x, y):
def score(self, x, y, collect=False):
"""
Returns the mean accuracy on the given test data and labels.
Expand All @@ -207,7 +207,9 @@ def score(self, x, y):
partial = _score(x_row._blocks, y_row._blocks, self._clf)
partial_scores.append(partial)

return _merge_scores(*partial_scores)
score = _merge_scores(*partial_scores)

return compss_wait_on(score) if collect else score

def _check_initial_parameters(self):
gamma = self.gamma
Expand Down
8 changes: 6 additions & 2 deletions tests/test_csvm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import unittest

import numpy as np
from parameterized import parameterized
from pycompss.api.api import compss_wait_on

import dislib as ds
Expand Down Expand Up @@ -136,7 +137,8 @@ def test_predict(self):
self.assertTrue(l1 == l2 == l5 == 0)
self.assertTrue(l3 == l4 == l6 == 1)

def test_score(self):
@parameterized.expand([(True,), (False,)])
def test_score(self, collect):
seed = 666

# negative points belong to class 1, positives to 0
Expand All @@ -157,7 +159,9 @@ def test_score(self):
x_test = ds.array(np.array([p1, p2, p3, p4]), (2, 2))
y_test = ds.array(np.array([0, 0, 1, 1]).reshape(-1, 1), (2, 1))

accuracy = compss_wait_on(csvm.score(x_test, y_test))
accuracy = csvm.score(x_test, y_test, collect)
if not collect:
accuracy = compss_wait_on(accuracy)

self.assertEqual(accuracy, 1.0)

Expand Down

0 comments on commit 66e98b4

Please sign in to comment.