chore: Add Benchmark Scripts and Performance Comparison of LitData vs FFCV for Streaming ImageNet #572

Merged 34 commits on May 17, 2025
fb4fa8d
docs: Add performance comparison for streaming Imagenet dataset
bhimrazy Apr 27, 2025
b4cb1c3
Merge branch 'main' into docs/litdata-vs-ffcv-benchmarks
bhimrazy May 12, 2025
2f504cb
add optimize script for imagenet
bhimrazy May 14, 2025
bd8a087
update optimize script
bhimrazy May 14, 2025
7cc6df7
fix: update print statements for clarity in optimize_imagenet script
bhimrazy May 14, 2025
b636c00
feat: add streaming benchmark for ImageNet dataset with support for J…
bhimrazy May 14, 2025
abd3e15
update optimize script
bhimrazy May 14, 2025
05a24e3
add readme
bhimrazy May 14, 2025
0bb5a61
Merge branch 'main' into docs/litdata-vs-ffcv-benchmarks
bhimrazy May 14, 2025
9106307
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 14, 2025
bf78b3f
update litdata optimize
bhimrazy May 15, 2025
d42be4f
add script to convert dataset
bhimrazy May 15, 2025
f3e9123
add write and stream script
bhimrazy May 15, 2025
b6150c3
add install script
bhimrazy May 15, 2025
1bddc9d
update to include dropt last
bhimrazy May 15, 2025
df12160
update readme
bhimrazy May 15, 2025
611c256
add readme
bhimrazy May 15, 2025
efc4411
update readme
bhimrazy May 15, 2025
9a96239
update readme
bhimrazy May 15, 2025
b1d0b7d
update readme
bhimrazy May 15, 2025
0d78a16
update readme
bhimrazy May 15, 2025
92dbc26
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 15, 2025
be6b007
add missing docstring
bhimrazy May 15, 2025
2d06e53
Merge branch 'main' into docs/litdata-vs-ffcv-benchmarks
bhimrazy May 15, 2025
d7b1b5a
update ffcv stream
bhimrazy May 17, 2025
1c6adff
update stream litdata
bhimrazy May 17, 2025
e27f29f
update
bhimrazy May 17, 2025
1f4c5ba
update stream for ffcv
bhimrazy May 17, 2025
3804b5c
update script
bhimrazy May 17, 2025
6da072d
add gitignore
bhimrazy May 17, 2025
09bfe23
update
bhimrazy May 17, 2025
34141b7
update
bhimrazy May 17, 2025
eea69e8
update
bhimrazy May 17, 2025
d2bc95d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 17, 2025
10 changes: 10 additions & 0 deletions README.md
@@ -1288,6 +1288,16 @@ Speed to stream Imagenet 1.2M from other cloud storage providers:
|---|---|---|---|
| Cloudflare R2 | LitData | **5335** | **5630** |

Speed to stream Imagenet 1.2M from local disk with ffcv vs LitData:
| Framework | Dataset Mode | Dataset Size @ 256px | Images / sec 1st Epoch (float32) | Images / sec 2nd Epoch (float32) |
|---|---|---|---|---|
| LitData | PIL RAW | 168 GB | 6647 | 6398 |
| LitData | JPEG 90% | 12 GB | 6553 | 6537 |
| ffcv (os_cache=True) | RAW | 170 GB | 7263 | 6698 |
| ffcv (os_cache=False) | RAW | 170 GB | 7556 | 8169 |
| ffcv (os_cache=True) | JPEG 90% | 20 GB | 7653 | 8051 |
| ffcv (os_cache=False) | JPEG 90% | 20 GB | 8149 | 8607 |
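
To read the table as relative throughput, the second-epoch figures can be divided directly. The snippet below is a small sketch that recomputes one such ratio from the numbers in the table; the dictionary layout and the `relative_speed` helper are illustrative names, not part of LitData or FFCV.

```python
# Second-epoch throughput (images/sec), copied from the table above.
second_epoch = {
    ("LitData", "PIL RAW"): 6398,
    ("LitData", "JPEG 90%"): 6537,
    ("ffcv os_cache=True", "RAW"): 6698,
    ("ffcv os_cache=False", "RAW"): 8169,
    ("ffcv os_cache=True", "JPEG 90%"): 8051,
    ("ffcv os_cache=False", "JPEG 90%"): 8607,
}


def relative_speed(candidate, baseline):
    """Return candidate throughput divided by baseline throughput."""
    return second_epoch[candidate] / second_epoch[baseline]


# Compare the two JPEG 90% configurations on the second epoch.
ratio = relative_speed(("ffcv os_cache=False", "JPEG 90%"), ("LitData", "JPEG 90%"))
print(f"ffcv (os_cache=False) vs LitData, JPEG 90%: {ratio:.2f}x")
```

A ratio only compares the two rows as measured; dataset size on disk (12 GB vs 20 GB for JPEG 90%) is a separate axis that the ratio does not capture.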

 

## Time to optimize data
1 change: 1 addition & 0 deletions benchmarks/.gitignore
@@ -0,0 +1 @@
data
24 changes: 24 additions & 0 deletions benchmarks/README.md
@@ -0,0 +1,24 @@
# LitData Benchmarks

This directory contains simple, ready-to-use scripts for benchmarking data loading and optimization with LitData and FFCV.

- Use the `litdata/` folder for LitData-based optimization and streaming.
- Use the `ffcv/` folder for FFCV-based dataset conversion and streaming.

You can compare both approaches for your own datasets and training pipelines.

## How to use

- See the README in each subfolder for step-by-step instructions.
- All scripts are CLI-based and easy to run.

## Why benchmarks?

Benchmarks help you:
- Measure data loading speed and efficiency
- Compare different formats and pipelines
- Choose the best setup for your training

---

For more details, check the README in each subfolder.
95 changes: 95 additions & 0 deletions benchmarks/ffcv/README.md
@@ -0,0 +1,95 @@
# FFCV Benchmarks

This folder contains scripts to convert, write, and stream datasets using FFCV for benchmarking.

## 1. Prepare the Dataset

First, copy the raw ImageNet dataset to your machine (if not already present):

```sh
s5cmd cp "s3://imagenet-1m-template/raw/train/*" data/imagenet-1m-raw/train
```

Convert the raw ImageNet synset folders to PyTorch ImageFolder format (class index folders):

```sh
python convert_imagenet.py --data_dir data/imagenet-1m-raw/train
```
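
The conversion relies on a synset-to-index mapping. As a rough sketch of the idea, assuming the class index JSON has the shape `{"0": ["n01440764", "tench"], ...}` (only two entries are shown, and `invert_class_index` is a hypothetical helper name), the mapping from synset folder to class index folder can be built like this:

```python
import json

# A two-entry excerpt in the assumed shape of imagenet_class_index.json:
# class index (as a string) -> [synset ID, human-readable label].
SAMPLE_CLASS_INDEX = json.loads(
    '{"0": ["n01440764", "tench"], "1": ["n01443537", "goldfish"]}'
)


def invert_class_index(class_index_data):
    """Map each synset ID (the raw folder name) to its integer class index."""
    return {synset: int(index) for index, (synset, _label) in class_index_data.items()}


synset_to_index = invert_class_index(SAMPLE_CLASS_INDEX)
# With this mapping, files under n01443537/ would be moved to the folder named "1".
print(synset_to_index)
```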

## 2. Install FFCV

Install the FFCV library (if not already installed):

```sh
sh install_ffcv.sh
```
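
After installation, it can help to confirm which versions ended up in the environment. The following is a minimal sketch using only the standard library; the package list is an assumption based on what the install script touches, and `installed_versions` and `numpy_is_below_2` are illustrative helper names.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names assumed from the packages the install script installs via pip.
PACKAGES = ("ffcv", "numpy", "numba", "opencv-python-headless")


def installed_versions(packages=PACKAGES):
    """Return {package: version string, or None if the package is not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def numpy_is_below_2(numpy_version):
    """Check only the upper bound of the "numpy>=1.21,<2" pin (major version < 2)."""
    return numpy_version is not None and int(numpy_version.split(".")[0]) < 2


versions = installed_versions()
for name, found_version in versions.items():
    print(f"{name}: {found_version or 'not installed'}")
print("numpy < 2:", numpy_is_below_2(versions["numpy"]))
```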

## 3. Write FFCV Dataset

Convert the ImageNet dataset to FFCV format. Examples for different settings:

- (max 256px, 0% JPEG, quality 100)

```sh
python write_imagenet.py \
--cfg.dataset=imagenet \
--cfg.split=train \
--cfg.data_dir=/path/to/imagenet/train \
--cfg.write_path=/your/output/path/train_256_0.0_100.ffcv \
--cfg.max_resolution=256 \
--cfg.write_mode=proportion \
--cfg.compress_probability=0.0 \
--cfg.jpeg_quality=100
```

- (max 256px, 100% JPEG, quality 90)

```sh
python write_imagenet.py \
--cfg.dataset=imagenet \
--cfg.split=train \
--cfg.data_dir=/path/to/imagenet/train \
--cfg.write_path=/your/output/path/train_256_100.0_90.ffcv \
--cfg.max_resolution=256 \
--cfg.write_mode=proportion \
--cfg.compress_probability=100.0 \
--cfg.jpeg_quality=90
```
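
The output file names in the two examples follow a `<split>_<max_resolution>_<compress_probability>_<jpeg_quality>.ffcv` pattern. When scripting several configurations, a small helper along these lines can keep the names consistent. This is a sketch only: `ffcv_file_name` is a hypothetical function and is not part of `write_imagenet.py`.

```python
def ffcv_file_name(split: str, max_resolution: int, compress_probability: float, jpeg_quality: int) -> str:
    """Build an output file name matching the pattern used in the examples above."""
    return f"{split}_{max_resolution}_{compress_probability}_{jpeg_quality}.ffcv"


# The two configurations shown above.
for settings in [("train", 256, 0.0, 100), ("train", 256, 100.0, 90)]:
    print(ffcv_file_name(*settings))
```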

## 4. Stream FFCV Dataset

Stream an FFCV .ffcv dataset for benchmarking or training:

```sh
python stream_imagenet.py \
--cfg.data_path=/path/to/train_256_0.0_100.ffcv \
--cfg.batch_size=256 \
--cfg.num_workers=32 \
--cfg.epochs=2
```
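
The streaming script reports samples divided by wall-clock time for each epoch. The sketch below shows that measurement in isolation, with a plain Python generator standing in for the FFCV loader so it runs without any dataset; `fake_loader` and `measure_epoch` are illustrative names.

```python
import time


def fake_loader(num_batches, batch_size):
    """Stand-in for a data loader: yields lists of `batch_size` dummy samples."""
    for _ in range(num_batches):
        yield [0] * batch_size


def measure_epoch(loader):
    """Iterate once over `loader` and return (num_samples, elapsed_seconds, samples_per_second)."""
    num_samples = 0
    start = time.perf_counter()
    for batch in loader:
        num_samples += len(batch)
    elapsed = time.perf_counter() - start
    # Guard against a zero reading from a very coarse clock.
    rate = num_samples / elapsed if elapsed > 0 else float("inf")
    return num_samples, elapsed, rate


num_samples, elapsed, rate = measure_epoch(fake_loader(num_batches=10, batch_size=256))
print(f"Streamed {num_samples} samples in {elapsed:.4f}s ({rate:.2f} samples/sec)")
```

Timing each epoch separately, as the script does, keeps first-epoch effects such as cold caches visible instead of averaging them away.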

---

These scripts are easy to use and work with both local and cloud datasets. For more details, see the script docstrings or run with `--help`.

## 5. Benchmark LitData vs FFCV

You can use already prepared datasets to quickly run your benchmarks. Simply copy the optimized datasets from S3 to your teamspace, then run the provided streaming or benchmarking scripts.

Example S3 structure:

```
s3://xxxxxxx/datasets/imagenet-1m-ffcv/
train_256_0.0_100.ffcv
train_256_100.0_90.ffcv
s3://xxxxxxx/datasets/imagenet-1m-litdata/
train_256_jpg_90/
train_256_raw_pil/
```

To extract the real S3 path for a dataset in your teamspace, use:

```sh
python3 -c "from litdata.streaming.resolver import _resolve_dir; path=_resolve_dir('/teamspace/datasets/imagenet-1m-litdata/'); print(path.url)"
```
You can also prepare the datasets yourself using the earlier steps if you prefer.
91 changes: 91 additions & 0 deletions benchmarks/ffcv/convert_imagenet.py
@@ -0,0 +1,91 @@
"""Convert an ImageNet dataset (with synset folders) to PyTorch ImageFolder format (class index folders).

Usage:
    python convert_imagenet.py --data_dir /path/to/imagenet-raw/train

This script helps you:
- Change the folder structure from synset names (like n01440764) to class index numbers (0, 1, 2, ...)
- Move all images into their new class index folders
- Clean up by removing the old synset folders

After running this script, your dataset will be ready to use with torchvision.datasets.ImageFolder or
any PyTorch dataloader expecting class index folders.
"""

import json
import os
from argparse import ArgumentParser
from functools import lru_cache
from glob import glob

import requests
from torchvision.datasets import ImageFolder
from tqdm import tqdm


@lru_cache(maxsize=1)
def load_imagenet_class_index():
    """Load the ImageNet mapping from synset IDs (folder names) to class indices."""
    class_index_url = "https://raw.githubusercontent.com/raghakot/keras-vis/master/resources/imagenet_class_index.json"
    try:
        response = requests.get(class_index_url, timeout=10)
        response.raise_for_status()
        class_index_data = response.json()
        # The JSON maps "index" -> [synset ID, label]; invert it to synset ID -> index.
        return {v[0]: int(k) for k, v in class_index_data.items()}
    except (requests.RequestException, json.JSONDecodeError) as e:
        raise RuntimeError(f"Failed to load ImageNet class index: {e}") from e


if __name__ == "__main__":
    parser = ArgumentParser(description="Convert ImageNet synset folders to PyTorch ImageFolder style.")
    parser.add_argument(
        "--data_dir",
        type=str,
        required=True,
        help="Path to the ImageNet dataset directory (containing synset folders).",
    )
    args = parser.parse_args()
    imagenet_dir = args.data_dir
    if not os.path.exists(imagenet_dir):
        raise FileNotFoundError(f"The specified directory does not exist: {imagenet_dir}")

    print("[INFO] Loading ImageNet class index mapping...")
    class_index_mapping = load_imagenet_class_index()

    # Create class index folders if not present
    print("[INFO] Creating class index folders...")
    for class_index in class_index_mapping.values():
        folder_path = os.path.join(imagenet_dir, str(class_index))
        os.makedirs(folder_path, exist_ok=True)

    # Move images from synset folders to class index folders
    print("[INFO] Moving images to class index folders...")
    for file_path in tqdm(glob(f"{imagenet_dir}/*/*.*"), desc="Moving files"):
        dirname = os.path.basename(os.path.dirname(file_path))
        if not dirname.startswith("n"):
            continue
        class_index = class_index_mapping[dirname]
        destination_path = os.path.join(imagenet_dir, str(class_index), os.path.basename(file_path))
        os.rename(file_path, destination_path)

    # Remove old synset folders
    print("[INFO] Removing old synset folders...")
    for folder in tqdm(glob(f"{imagenet_dir}/*"), desc="Removing old folders"):
        if os.path.basename(folder).startswith("n") and os.path.isdir(folder):
            try:
                os.rmdir(folder)
            except OSError:
                print(f"[WARNING] Could not remove folder (not empty?): {folder}")

    print("[SUCCESS] Conversion complete.")
    print("[INFO] All images are now organized in class index folders.")
    print("[INFO] You can now use the dataset with torchvision.datasets.ImageFolder.")

    # Show a sample from the resulting dataset
    dataset = ImageFolder(root=imagenet_dir, transform=None)
    print(f"[INFO] Number of classes: {len(dataset.classes)}")
    print(f"[INFO] Example class indices: {dataset.classes[:5]}")
    if len(dataset) > 0:
        print(f"[INFO] Sample image path: {dataset.samples[0][0]}, class index: {dataset.samples[0][1]}")
    else:
        print("[INFO] No images found in the converted dataset.")
16 changes: 16 additions & 0 deletions benchmarks/ffcv/install_ffcv.sh
@@ -0,0 +1,16 @@
echo "=== FFCV Installer ==="
OS="$(uname -s)"
echo "Detected OS: $OS"

# Install ffcv dependencies
conda install -y -c conda-forge libjpeg-turbo
conda install -y pkg-config compilers opencv -c conda-forge
pip uninstall -y opencv-python-headless numba
pip install opencv-python-headless numba
pip install --force-reinstall "numpy>=1.21,<2"
pip install ffcv

echo "Verifying FFCV installation..."
python3 -c "import ffcv; print('✅ FFCV installed successfully!')"

echo "=== Done ==="
108 changes: 108 additions & 0 deletions benchmarks/ffcv/stream_imagenet.py
@@ -0,0 +1,108 @@
"""Stream an FFCV ImageNet dataset for benchmarking.

Adapted from: https://github.com/libffcv/ffcv-imagenet/blob/main/train_imagenet.py

This script streams an FFCV ImageNet dataset and benchmarks the streaming speed.
It uses the FFCV library to load and process the dataset efficiently.

Example usage:
    python stream_imagenet.py --cfg.data_path=/path/to/train_256_0.0_100.ffcv
"""

import os
import time

import lightning as L
import numpy as np
import torch
import torchvision.transforms.v2 as T
from fastargs import Param, Section, get_current_config
from fastargs.decorators import param, section
from ffcv.fields.decoders import IntDecoder, RandomResizedCropRGBImageDecoder
from ffcv.loader import Loader, OrderOption
from ffcv.transforms import NormalizeImage, RandomHorizontalFlip, Squeeze, ToTensor, ToTorchImage
from tqdm import tqdm

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]) * 255
IMAGENET_STD = np.array([0.229, 0.224, 0.225]) * 255

Section("cfg", "arguments for streaming FFCV dataset").params(
    data_path=Param(str, "Path to the FFCV .ffcv file", required=True),
    batch_size=Param(int, "Batch size for streaming", default=256),
    num_workers=Param(int, "Number of workers for loader", default=os.cpu_count()),
    drop_last=Param(
        bool,
        "Drop the last incomplete batch (default: False)",
        default=False,
    ),
    epochs=Param(int, "Number of epochs to run benchmark", default=2),
    order=Param(str, "Order: SEQUENTIAL or RANDOM or QUASI_RANDOM", default="SEQUENTIAL"),
    os_cache=Param(bool, "Use OS cache if the dataset can fit in memory", default=False),
    normalize=Param(
        bool,
        "If True, applies normalization using ImageNet mean and std; if False, uses scaling via T.ToDtype.",
        default=False,
    ),
)


@section("cfg")
@param("data_path")
@param("batch_size")
@param("num_workers")
@param("drop_last")
@param("epochs")
@param("order")
@param("os_cache")
@param("normalize")
def main(data_path, batch_size, num_workers, drop_last, epochs, order, os_cache, normalize):
    """Stream and benchmark an FFCV ImageNet dataset."""
    L.seed_everything(42)

    # Set up FFCV pipelines
    image_pipeline = [
        RandomResizedCropRGBImageDecoder((224, 224)),
        RandomHorizontalFlip(),
        ToTensor(),
        ToTorchImage(),
        NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float32) if normalize else T.ToDtype(torch.float32, scale=True),
    ]

    label_pipeline = [IntDecoder(), ToTensor(), Squeeze()]
    pipelines = {"image": image_pipeline, "label": label_pipeline}

    order_option = getattr(OrderOption, order.upper())
    loader = Loader(
        data_path,
        batch_size=batch_size,
        num_workers=num_workers,
        order=order_option,
        pipelines=pipelines,
        os_cache=os_cache,
        drop_last=drop_last,
    )

    print("[INFO] Starting streaming benchmark...")
    for epoch in range(epochs):
        num_samples = 0
        t0 = time.perf_counter()
        for data in tqdm(loader, desc=f"Epoch {epoch + 1}/{epochs}", smoothing=0, mininterval=1):
            num_samples += data[0].shape[0]
        elapsed = time.perf_counter() - t0
        print(
            f"[RESULT] Epoch {epoch + 1}: Streamed {num_samples} samples in"
            f" {elapsed:.2f}s ({num_samples / elapsed:.2f} images/sec)"
        )
    print("[INFO] Finished streaming benchmark.")


if __name__ == "__main__":
    config = get_current_config()
    import argparse

    parser = argparse.ArgumentParser()
    config.augment_argparse(parser)
    config.collect_argparse_args(parser)
    config.validate(mode="stderr")
    config.summary()
    main()
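
In `stream_imagenet.py`, `getattr(OrderOption, order.upper())` raises a bare `AttributeError` when the order string is not recognized. One possible variation validates the name first and reports the accepted values. The sketch below uses a stand-in `Enum` carrying the three names from the `order` help text so that it runs without FFCV installed; the stand-in class and `parse_order` are hypothetical and not part of the PR.

```python
from enum import Enum, auto


class OrderOption(Enum):
    """Stand-in for ffcv.loader.OrderOption, using the names from the help text."""

    SEQUENTIAL = auto()
    RANDOM = auto()
    QUASI_RANDOM = auto()


def parse_order(order: str) -> OrderOption:
    """Convert a case-insensitive order name to an OrderOption, or raise ValueError."""
    try:
        return OrderOption[order.upper()]
    except KeyError:
        valid = ", ".join(member.name for member in OrderOption)
        raise ValueError(f"Unknown order {order!r}; expected one of: {valid}") from None


print(parse_order("quasi_random"))
```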