Add batch interval support for learning rate schedulers #21500
Conversation
Pull request overview
This pull request combines multiple unrelated features into a single PR:
- Primary feature (issue #21491): adds support for scheduling learning rates on a per-batch basis by introducing a 'batch' interval option to learning rate schedulers
- Secondary feature (issue #21448): implements manual dataloader reloading during training via `trainer.reload_dataloaders()`
- Secondary feature (no issue referenced): adds an `adapt_checkpoint_hparams` hook to LightningCLI for customizing checkpoint hyperparameters
Changes:
- Added a 'batch' interval option to `LRSchedulerConfig` with validation of interval values
- Implemented batch-based LR scheduler updates using the global batch count (`total_batch_idx`) instead of the per-epoch batch index (see the sketch below)
- Added a `reload_dataloaders()` method to `Trainer` for manual dataloader reloading during training
- Added an `adapt_checkpoint_hparams()` hook to `LightningCLI` for checkpoint hyperparameter customization
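For illustration, a minimal sketch of how a user might opt into the new interval, assuming the PR reuses the existing `lr_scheduler` configuration dict returned from `configure_optimizers` and only adds `"batch"` as an accepted value (the model and scheduler choices below are arbitrary):

```python
import torch
from lightning.pytorch.demos.boring_classes import BoringModel


class BatchScheduledModel(BoringModel):
    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                # "batch" is the value added by this PR; "epoch" and "step" are the
                # previously supported intervals
                "interval": "batch",
                "frequency": 1,
            },
        }
```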
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| src/lightning/pytorch/utilities/types.py | Added validation to LRSchedulerConfig for interval values and updated comment to mention 'batch' interval |
| src/lightning/pytorch/loops/training_epoch_loop.py | Implemented batch-interval scheduler updates using total_batch_idx; updates occur after each batch regardless of gradient accumulation |
| tests/tests_pytorch/trainer/test_lr_scheduler_batch_interval.py | Comprehensive tests for batch interval LR scheduler functionality including gradient accumulation scenarios |
| src/lightning/pytorch/trainer/trainer.py | Added reload_dataloaders() method for manual dataloader reloading during training |
| tests/tests_pytorch/trainer/test_reload_dataloaders.py | Tests for manual dataloader reloading feature |
| src/lightning/pytorch/cli.py | Added adapt_checkpoint_hparams() hook for customizing checkpoint hyperparameters before model instantiation |
| tests/tests_pytorch/test_cli.py | Tests for adapt_checkpoint_hparams hook functionality |
| docs/source-pytorch/data/access.rst | Documentation for manual dataloader reloading feature |
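The new `Trainer.reload_dataloaders()` method listed above could, for example, be driven from a callback. A minimal sketch, where only the method name comes from this PR and the reload-every-five-epochs trigger is purely illustrative:

```python
from lightning.pytorch import Callback


class PeriodicReload(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # reload_dataloaders() is the new public method added in this PR; reloading
        # every fifth epoch is an arbitrary, illustrative condition
        if trainer.current_epoch % 5 == 4:
            trainer.reload_dataloaders()
```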
```python
from lightning.pytorch.demos.boring_model import BoringModel
from torch import optim
from torch.optim.lr_scheduler import StepLR
```

Copilot AI · Jan 20, 2026

The import `LightningModule` is not used in the test file and should be removed. Only `Trainer` is used from `lightning.pytorch`.
```python
from lightning.pytorch.demos.boring_model import BoringModel
```

Copilot AI · Jan 20, 2026

The imports `pytest` and `torch` are not used in the test file and should be removed. Only specific submodules from torch (`torch.optim`) are used.
```diff
-    # after epoch is over
+    # after epoch is over (valid values: "epoch", "step", or "batch")
     interval: str = "epoch"
     # every epoch/batch
```

Copilot AI · Jan 20, 2026

The comment on line 87 says "every epoch/batch" but should be updated to "every epoch/step/batch" to reflect all three interval types ("epoch", "step", "batch").

Suggested change:

```diff
-    # every epoch/batch
+    # every epoch/step/batch
```
```python
from torch import optim
from torch.optim.lr_scheduler import StepLR

from lightning.pytorch import Trainer
```

Copilot AI · Jan 20, 2026

The import path is incorrect. It should be `from lightning.pytorch.demos.boring_classes import BoringModel` instead of `from lightning.pytorch.demos.boring_model import BoringModel`. The module `boring_model` does not exist; `BoringModel` is defined in `boring_classes`.
```python
from lightning.pytorch.demos.boring_model import BoringModel
from torch import optim
from torch.optim.lr_scheduler import StepLR
```

Copilot AI · Jan 20, 2026

The imports `ExponentialLR` and `LambdaLR` from `torch.optim.lr_scheduler` are not used in the test file and should be removed.
Force-pushed from f97374e to 60cb303 (Compare)
…ameter loading

Fixes Lightning-AI#21255

This commit adds the `adapt_checkpoint_hparams()` public method to `LightningCLI`, allowing users to customize hyperparameters loaded from checkpoints before they are used to instantiate model classes. This is particularly useful when using checkpoints from a TrainingModule with a different InferenceModule class that has different `__init__` parameters.

Problem: When loading a checkpoint trained with `TrainingModule(lr=1e-3)` into an `InferenceModule()` that doesn't accept `lr` as a parameter, the CLI would fail during instantiation because it tries to pass all checkpoint hyperparameters to the new module class.

Solution: Added an `adapt_checkpoint_hparams()` hook that is called in `_parse_ckpt_path()` after loading checkpoint hyperparameters but before applying them. Users can override this method to:
- Remove training-specific hyperparameters (e.g., lr, weight_decay)
- Modify `_class_path` for subclass mode
- Transform hyperparameter names/values
- Completely disable checkpoint hyperparameters by returning {}

Example usage:

```python
class MyCLI(LightningCLI):
    def adapt_checkpoint_hparams(self, checkpoint_hparams):
        checkpoint_hparams.pop('lr', None)
        checkpoint_hparams.pop('weight_decay', None)
        return checkpoint_hparams
```

This approach is preferable to:
- Disabling checkpoint loading entirely (loses valuable hyperparameter info)
- Adding CLI arguments (deviates from the Trainer parameter pattern)
- Modifying private methods (breaks encapsulation)

The hook provides maximum flexibility while maintaining backward compatibility (the default implementation returns the hyperparameters unchanged).
for more information, see https://pre-commit.ci
…ook and add tests

- Update the `adapt_checkpoint_hparams` signature to include a `subcommand` parameter, allowing context-aware customization of checkpoint hyperparameters
- Change type annotations to use lowercase `dict` (Python 3.9+ style)
- Update the docstring with `subcommand` parameter documentation
- Add an example showing conditional logic based on the subcommand
- Add comprehensive unit tests:
  - `test_adapt_checkpoint_hparams_hook`: tests that the hook is called and modifications are applied
  - `test_adapt_checkpoint_hparams_hook_empty_dict`: tests disabling checkpoint hparams loading
  - Tests cover both regular and subclass modes
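A sketch of what the subcommand-aware hook might look like, based only on this commit message; the exact parameter names and order are assumptions:

```python
from lightning.pytorch.cli import LightningCLI


class InferenceCLI(LightningCLI):
    def adapt_checkpoint_hparams(self, checkpoint_hparams: dict, subcommand: str) -> dict:
        # Drop training-only arguments when running inference subcommands.
        # The hook signature here is inferred from the commit message and may
        # differ from the final implementation.
        if subcommand == "predict":
            checkpoint_hparams.pop("lr", None)
            checkpoint_hparams.pop("weight_decay", None)
        return checkpoint_hparams
```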
for more information, see https://pre-commit.ci
- Split the method signature across multiple lines to stay within the 120-character limit
- Improves code readability in the documentation example
… size mismatch in tests
for more information, see https://pre-commit.ci
Removed redundant method implementations since BoringModel provides them.
The test was asserting hidden_dim==3 but only passing out_dim=3. Since hidden_dim defaults to 16 and there's no argument linking, the assertion failed. Now we explicitly pass --model.hidden_dim=6.
This commit implements support for scheduling learning rates on a per-batch basis, addressing issue Lightning-AI#21491. Previously, PyTorch Lightning only supported epoch-based ('epoch') and optimizer-step-based ('step') intervals.

Key Changes:

1. `TrainEpochLoop._update_learning_rates()` now supports the 'batch' interval:
   - Uses the global batch count (`total_batch_idx`) for scheduling decisions
   - Updates learning rates on every training batch
   - Skips the gradient accumulation check that applies to the 'step' interval

2. `LRSchedulerConfig` now validates interval values:
   - Added a `__post_init__()` method to enforce valid intervals
   - Raises `ValueError` for invalid interval specifications
   - Supports 'epoch', 'step', or 'batch'

3. Comprehensive test suite added:
   - Tests batch interval basic functionality
   - Tests interaction with gradient accumulation
   - Tests frequency parameter handling
   - Tests mixed interval types
   - Tests initialization and edge cases
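An illustrative stand-in for the validation described in point 2, not the actual `LRSchedulerConfig` source (most fields are omitted):

```python
from dataclasses import dataclass


@dataclass
class LRSchedulerConfigSketch:
    # stand-in for the real LRSchedulerConfig; other fields are omitted
    interval: str = "epoch"
    frequency: int = 1

    def __post_init__(self) -> None:
        valid = ("epoch", "step", "batch")
        if self.interval not in valid:
            raise ValueError(f"Invalid interval {self.interval!r}; expected one of {valid}.")
```

Constructing the sketch with `interval="batch"` succeeds, while an unsupported value such as `"iteration"` raises a `ValueError`.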
for more information, see https://pre-commit.ci
Previously, batch-interval schedulers were updated after the optimizer step, which meant LR changes were not visible to the on_train_batch_start hook. This caused the LR tracking tests to fail because they recorded LR values that were from the previous batch's scheduler update. Now batch-interval schedulers are updated at the beginning of each batch, ensuring LR changes are visible to all hooks that occur during the batch.
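A sketch of how such a check could be expressed in a test callback; the attribute access is standard Lightning API, while the timing guarantee it relies on is the one described in this commit:

```python
from lightning.pytorch import Callback


class LRRecorder(Callback):
    def __init__(self):
        self.lrs = []

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        # With the fix, a "batch"-interval scheduler has already stepped for this
        # batch, so the value recorded here reflects the current batch's LR.
        self.lrs.append(trainer.optimizers[0].param_groups[0]["lr"])
```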
for more information, see https://pre-commit.ci
Force-pushed from bb1d77d to 2eeed9e (Compare)
Rebased on latest upstream/master (0a0f061) to fix the 54 failing CI checks. The branch was significantly behind master, which caused all the install-pkg and pl-cpu tests to fail. CI should now pass. ✅
Implement support for scheduling learning rates on a per-batch basis, resolving issue #21491.
Key Changes:
- 'batch' interval support in `_update_learning_rates()`, driven by the global batch count (`total_batch_idx`)
- Interval validation in `LRSchedulerConfig` (`__post_init__` raises `ValueError` for anything other than 'epoch', 'step', or 'batch')
- Tests covering basic functionality, gradient accumulation, frequency handling, and mixed interval types
Problem Addressed:
When gradient accumulation is used with varying factors, optimizer-step-based scheduling becomes inconsistent. Batch-based scheduling provides a cleaner alternative.
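For example, reusing the hypothetical `BatchScheduledModel` from the earlier sketch, the difference shows up as soon as gradient accumulation is enabled:

```python
from lightning.pytorch import Trainer

# With accumulate_grad_batches=4, a "step"-interval scheduler would advance only once
# every four batches (after each optimizer step), whereas the "batch" interval advances
# on every batch. The values below are illustrative.
trainer = Trainer(max_epochs=2, accumulate_grad_batches=4, limit_train_batches=8)
trainer.fit(BatchScheduledModel())
```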
Resolves #21491
📚 Documentation preview 📚: https://pytorch-lightning--21500.org.readthedocs.build/en/21500/