
Commit 7d74671

wdevazelhes authored and perimosocordiae committed
Fix random state in algorithms (#234)
1 parent 54c9d89 commit 7d74671

File tree

5 files changed: +29 -8 lines


metric_learn/lmnn.py

Lines changed: 1 addition & 1 deletion
@@ -210,7 +210,7 @@ def fit(self, X, y):
       init = self.init
     self.components_ = _initialize_components(output_dim, X, y, init,
                                                self.verbose,
-                                               self.random_state)
+                                               random_state=self.random_state)
     required_k = np.bincount(label_inds).min()
     if self.k > required_k:
       raise ValueError('not enough class labels for specified k'
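
For context, a minimal sketch of what this fix buys (assuming metric_learn's public LMNN API with init='random', and iris as an arbitrary stand-in dataset): two fits with the same random_state should now start from the same random components and therefore learn the same transformation.

# Sketch only: LMNN, k and init='random' are assumed from the public API;
# the dataset choice is arbitrary.
import numpy as np
from sklearn.datasets import load_iris
from metric_learn import LMNN

X, y = load_iris(return_X_y=True)

lmnn_a = LMNN(k=3, init='random', random_state=42).fit(X, y)
lmnn_b = LMNN(k=3, init='random', random_state=42).fit(X, y)

# With random_state now forwarded to _initialize_components, both runs start
# from identical random components and should converge to the same result.
np.testing.assert_allclose(lmnn_a.components_, lmnn_b.components_)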

metric_learn/lsml.py

Lines changed: 2 additions & 3 deletions
@@ -56,9 +56,8 @@ def _fit(self, quadruplets, weights=None):
     else:
       prior = self.prior
     M, prior_inv = _initialize_metric_mahalanobis(quadruplets, prior,
-                                                   return_inverse=True,
-                                                   strict_pd=True,
-                                                   matrix_name='prior')
+        return_inverse=True, strict_pd=True, matrix_name='prior',
+        random_state=self.random_state)
 
     step_sizes = np.logspace(-10, 0, 10)
     # Keep track of the best step size and the loss at that step.
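
The same idea for LSML, where the random prior is now drawn from the estimator's random_state. A hedged sketch using the supervised wrapper (LSML_Supervised is assumed; its constraint sampling is also seeded by the same random_state):

# Sketch only: LSML_Supervised and prior='random' are assumed from the
# public API; repeated fits with the same seed should yield the same metric.
import numpy as np
from sklearn.datasets import load_iris
from metric_learn import LSML_Supervised

X, y = load_iris(return_X_y=True)

lsml_a = LSML_Supervised(prior='random', random_state=42).fit(X, y)
lsml_b = LSML_Supervised(prior='random', random_state=42).fit(X, y)

np.testing.assert_allclose(lsml_a.get_mahalanobis_matrix(),
                           lsml_b.get_mahalanobis_matrix())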

metric_learn/nca.py

Lines changed: 2 additions & 1 deletion
@@ -174,7 +174,8 @@ def fit(self, X, y):
       init = 'auto'
     else:
       init = self.init
-    A = _initialize_components(n_components, X, labels, init, self.verbose)
+    A = _initialize_components(n_components, X, labels, init, self.verbose,
+                               self.random_state)
 
     # Run NCA
     mask = labels[:, np.newaxis] == labels[np.newaxis, :]
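
For NCA, the previously missing random_state argument is now forwarded, so a random initialization is reproducible. A sketch checking the embedded data rather than the raw components (NCA's init='random' option is assumed from the public API):

# Sketch only: with random_state now passed to _initialize_components,
# the random initialization of A is reproducible, so transform() agrees.
import numpy as np
from sklearn.datasets import load_iris
from metric_learn import NCA

X, y = load_iris(return_X_y=True)

nca_a = NCA(init='random', random_state=42).fit(X, y)
nca_b = NCA(init='random', random_state=42).fit(X, y)

np.testing.assert_allclose(nca_a.transform(X), nca_b.transform(X))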

metric_learn/sdml.py

Lines changed: 2 additions & 3 deletions
@@ -69,9 +69,8 @@ def _fit(self, pairs, y):
     else:
       prior = self.prior
     _, prior_inv = _initialize_metric_mahalanobis(pairs, prior,
-                                                   return_inverse=True,
-                                                   strict_pd=True,
-                                                   matrix_name='prior')
+        return_inverse=True, strict_pd=True, matrix_name='prior',
+        random_state=self.random_state)
     diff = pairs[:, 0] - pairs[:, 1]
     loss_matrix = (diff.T * y).dot(diff)
     emp_cov = prior_inv + self.balance_param * loss_matrix
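
And for SDML, where the random prior fed into the sparse estimation step is now seeded as well. A sketch with the supervised wrapper (SDML_Supervised is assumed; installing skggm is recommended for a more robust solver, and convergence warnings may appear without it):

# Sketch only: SDML_Supervised and prior='random' are assumed from the
# public API; both the random prior and the sampled pairs follow random_state.
import numpy as np
from sklearn.datasets import load_iris
from metric_learn import SDML_Supervised

X, y = load_iris(return_X_y=True)

sdml_a = SDML_Supervised(prior='random', random_state=42).fit(X, y)
sdml_b = SDML_Supervised(prior='random', random_state=42).fit(X, y)

np.testing.assert_allclose(sdml_a.get_mahalanobis_matrix(),
                           sdml_b.get_mahalanobis_matrix())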

test/test_mahalanobis_mixin.py

Lines changed: 22 additions & 0 deletions
@@ -652,3 +652,25 @@ def test_singular_array_init_or_prior(estimator, build_dataset, w0):
   with pytest.raises(LinAlgError) as raised_err:
     model.fit(input_data, labels)
   assert str(raised_err.value) == msg
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize('estimator, build_dataset', metric_learners,
+                         ids=ids_metric_learners)
+def test_deterministic_initialization(estimator, build_dataset):
+  """Test that estimators that have a prior or an init are deterministic
+  when it is set to random and when the random_state is fixed."""
+  input_data, labels, _, X = build_dataset()
+  model = clone(estimator)
+  if hasattr(estimator, 'init'):
+    model.set_params(init='random')
+  if hasattr(estimator, 'prior'):
+    model.set_params(prior='random')
+  model1 = clone(model)
+  set_random_state(model1, 42)
+  model1 = model1.fit(input_data, labels)
+  model2 = clone(model)
+  set_random_state(model2, 42)
+  model2 = model2.fit(input_data, labels)
+  np.testing.assert_allclose(model1.get_mahalanobis_matrix(),
+                             model2.get_mahalanobis_matrix())
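
The new test is marked as an integration test; assuming the repository's usual pytest setup, it can be selected by name:

pytest test/test_mahalanobis_mixin.py -k test_deterministic_initialization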
