@GeorgePearse
Summary

Integrated PyTorch Lightning Fabric to enable scalable multi-GPU training and simplify device management across the codebase.

Key Changes

  • Dependencies: Added lightning>=2.0.0 to pyproject.toml.
  • Refactoring:
    • Migrated examples/train.py, examples/train_video.py, examples/train_elic_cifar10.py, and tinify/cli/train.py to use Fabric.
    • Removed manual device placement (.to(device)) and CustomDataParallel wrapper.
    • Updated training loops to use fabric.backward() and fabric.clip_gradients().
  • CLI: Added --accelerator, --devices, --strategy, and --precision arguments to training scripts for flexible hardware configuration.

Benefits

  • Multi-GPU training (DDP, FSDP, etc.) by switching the --strategy flag, with no code changes.
  • Simplified codebase by removing boilerplate device management code.
  • Improved mixed-precision training support via Fabric.

Testing

Runs pytest against a Python 3.9–3.12 matrix on:

  • Pull requests to main
  • Pushes to main

Excludes slow tests (pretrained model downloads) to keep CI fast.

  • Added tests/test_train_fabric.py
  • Added tests/test_cli_train.py
  • Updated tests/test_train.py to use the CPU explicitly
  • Added VideoRateDistortionLoss to tinify/losses for correct video training support in the CLI
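The slow-test exclusion typically uses a pytest marker, with CI invoking `pytest -m "not slow"`. A sketch of the pattern, assuming a `slow` marker is registered; the test names here are illustrative, not the actual contents of the new test files:

```python
import pytest
import torch


@pytest.mark.slow
def test_pretrained_download():
    """Hypothetical slow test: would download pretrained weights."""
    ...


def test_train_one_step_cpu():
    """Fast test pinned to the CPU, as tests/test_train.py now does."""
    x = torch.randn(4, 3, device="cpu")
    assert x.device.type == "cpu"
```

Registering the marker (e.g. under `[tool.pytest.ini_options]` with `markers = ["slow: ..."]`) avoids pytest's unknown-marker warning.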
@GeorgePearse GeorgePearse merged commit fdeaa81 into main Nov 21, 2025
2 of 4 checks passed
