Skip to content

Commit 94f8875

Browse files
FIX Remove unnecessary restriction on number of samples in IncrementalPCA (scikit-learn#30224)
1 parent: 5ca2f4f · commit: 94f8875

File tree

3 files changed: +34 −8 lines changed
Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,6 @@
1+
- :class:`~sklearn.decomposition.IncrementalPCA`
2+
will now only raise a ``ValueError`` when the number of samples in the
3+
input data to ``partial_fit`` is less than the number of components
4+
on the first call to ``partial_fit``. Subsequent calls to ``partial_fit``
5+
no longer face this restriction.
6+
By :user:`Thomas Gessey-Jones <ThomasGesseyJonesPX>`

sklearn/decomposition/_incremental_pca.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -306,11 +306,11 @@ def partial_fit(self, X, y=None, check_input=True):
306306
"more rows than columns for IncrementalPCA "
307307
"processing" % (self.n_components, n_features)
308308
)
309-
elif not self.n_components <= n_samples:
309+
elif self.n_components > n_samples and first_pass:
310310
raise ValueError(
311-
"n_components=%r must be less or equal to "
312-
"the batch number of samples "
313-
"%d." % (self.n_components, n_samples)
311+
f"n_components={self.n_components} must be less or equal to "
312+
f"the batch number of samples {n_samples} for the first "
313+
"partial_fit call."
314314
)
315315
else:
316316
self.n_components_ = self.n_components

sklearn/decomposition/tests/test_incremental_pca.py

Lines changed: 24 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -139,14 +139,13 @@ def test_incremental_pca_validation():
139139
):
140140
IncrementalPCA(n_components, batch_size=10).fit(X)
141141

142-
# Tests that n_components is also <= n_samples.
142+
# Test that n_components is also <= n_samples in first call to partial fit.
143143
n_components = 3
144144
with pytest.raises(
145145
ValueError,
146146
match=(
147-
"n_components={} must be"
148-
" less or equal to the batch number of"
149-
" samples {}".format(n_components, n_samples)
147+
f"n_components={n_components} must be less or equal to the batch "
148+
f"number of samples {n_samples} for the first partial_fit call."
150149
),
151150
):
152151
IncrementalPCA(n_components=n_components).partial_fit(X)
@@ -233,6 +232,27 @@ def test_incremental_pca_batch_signs():
233232
assert_almost_equal(np.sign(i), np.sign(j), decimal=6)
234233

235234

235+
def test_incremental_pca_partial_fit_small_batch():
236+
# Test that there is no minimum batch size after the first partial_fit
237+
# Non-regression test
238+
rng = np.random.RandomState(1999)
239+
n, p = 50, 3
240+
X = rng.randn(n, p) # spherical data
241+
X[:, 1] *= 0.00001 # make middle component relatively small
242+
X += [5, 4, 3] # make a large mean
243+
244+
n_components = p
245+
pipca = IncrementalPCA(n_components=n_components)
246+
pipca.partial_fit(X[:n_components])
247+
for idx in range(n_components, n):
248+
pipca.partial_fit(X[idx : idx + 1])
249+
250+
pca = PCA(n_components=n_components)
251+
pca.fit(X)
252+
253+
assert_allclose(pca.components_, pipca.components_, atol=1e-3)
254+
255+
236256
def test_incremental_pca_batch_values():
237257
# Test that components_ values are stable over batch sizes.
238258
rng = np.random.RandomState(1999)

0 commit comments

Comments
 (0)