batchstats is a Python package for computing statistics on data that arrives in batches. It is designed for streaming data and for datasets too large to fit into memory.
For detailed information, please check out the full documentation.
Install batchstats using pip:
```bash
pip install batchstats
```

Or with conda:

```bash
conda install -c conda-forge batchstats
```

Here's how to compute the mean and variance of a dataset in batches:
```python
import numpy as np
from batchstats import BatchMean, BatchVar
# Simulate a data stream
data_stream = (np.random.randn(100, 10) for _ in range(10))
# Initialize the stat objects
batch_mean = BatchMean()
batch_var = BatchVar()
# Process each batch
for batch in data_stream:
batch_mean.update_batch(batch)
batch_var.update_batch(batch)
# Get the final result
mean = batch_mean()
variance = batch_var()
print(f"Mean shape: {mean.shape}")
print(f"Variance shape: {variance.shape}")batchstats handles n-dimensional np.ndarray inputs and allows specifying multiple axes for reduction, just like numpy.
```python
import numpy as np
from batchstats import BatchMean
# Create a 3D data stream
data_stream = (np.random.rand(10, 5, 8) for _ in range(5))
# Compute the mean over the last two axes (1 and 2)
batch_mean_3d = BatchMean(axis=(1, 2))
for batch in data_stream:
batch_mean_3d.update_batch(batch)
mean_3d = batch_mean_3d()
print(f"3D Mean shape: {mean_3d.shape}")batchstats provides BatchNan* classes to handle NaN values, similar to numpy's nan* functions.
```python
import numpy as np
from batchstats import BatchNanMean
# Create data with NaNs
data = np.random.randn(1000, 5)
data[::10] = np.nan
# Compute the mean, ignoring NaNs
# (update_batch returns the instance, so the result call can be chained)
nan_mean = BatchNanMean().update_batch(data)()
print(f"NaN-aware mean shape: {nan_mean.shape}")batchstats supports a variety of common statistics:
- BatchSum/BatchNanSum
- BatchMean/BatchNanMean
- BatchMin/BatchNanMin
- BatchMax/BatchNanMax
- BatchPeakToPeak/BatchNanPeakToPeak
- BatchVar
- BatchStd
- BatchCov
- BatchCorr
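As a minimal sketch of how these classes are used, assuming BatchStd mirrors the BatchVar interface demonstrated earlier (constructor, update_batch(), then a call to retrieve the result) and numpy's default ddof=0, the batched result should match a one-shot np.std over the concatenated data up to floating-point tolerance:

```python
import numpy as np
from batchstats import BatchStd

# Hypothetical data split into batches, for illustration only
batches = [np.random.randn(200, 4) for _ in range(5)]

# Accumulate the standard deviation one batch at a time
# (assumes BatchStd follows the update_batch()/call pattern shown above)
batch_std = BatchStd()
for batch in batches:
    batch_std.update_batch(batch)
std = batch_std()

# Compare against numpy's one-shot computation on the full dataset
full_data = np.concatenate(batches, axis=0)
print(np.allclose(std, np.std(full_data, axis=0)))  # expected: True
```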
For more details on each class, see the API Reference.