Conversation

@arrdel arrdel commented Jan 20, 2026

Implement support for scheduling learning rates on a per-batch basis, resolving issue #21491.

Key Changes:

  • Add 'batch' interval option to learning rate schedulers
  • Batch interval uses global batch count for consistency across epochs
  • Unlike 'step' interval, batch updates are not skipped during gradient accumulation
  • Add validation to LRSchedulerConfig that rejects invalid interval values
  • Comprehensive test suite included

Problem Addressed:
When gradient accumulation is used with varying factors, optimizer-step-based scheduling becomes inconsistent. Batch-based scheduling provides a cleaner alternative.
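
As an illustration of the intended usage (this sketch is not taken from the PR's diff; the model, optimizer, and StepLR settings are arbitrary placeholders, and only the "interval": "batch" key reflects the new feature):

    import torch
    from torch.optim.lr_scheduler import StepLR
    from lightning.pytorch import LightningModule

    class BatchScheduledModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            return self.layer(batch).sum()

        def configure_optimizers(self):
            optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
            scheduler = StepLR(optimizer, step_size=10, gamma=0.5)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    # "batch": step once per training batch (keyed on the global
                    # batch count), independent of gradient accumulation.
                    "interval": "batch",
                    "frequency": 1,
                },
            }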

Resolves #21491


📚 Documentation preview 📚: https://pytorch-lightning--21500.org.readthedocs.build/en/21500/

Copilot AI review requested due to automatic review settings January 20, 2026 22:28
@github-actions github-actions bot added docs Documentation related pl Generic label for PyTorch Lightning package labels Jan 20, 2026

Copilot AI left a comment

Pull request overview

This pull request combines multiple unrelated features into a single PR:

  1. Primary Feature (issue #21491): Adds support for scheduling learning rates on a per-batch basis by introducing a 'batch' interval option to learning rate schedulers
  2. Secondary Feature (issue #21448): Implements manual dataloader reloading during training via trainer.reload_dataloaders()
  3. Secondary Feature (no issue referenced): Adds adapt_checkpoint_hparams hook to LightningCLI for customizing checkpoint hyperparameters

Changes:

  • Added 'batch' interval option to LRSchedulerConfig with validation for valid interval values
  • Implemented batch-based LR scheduler updates using global batch count (total_batch_idx) instead of per-epoch batch index
  • Added reload_dataloaders() method to Trainer for manual dataloader reloading during training (see the usage sketch after this list)
  • Added adapt_checkpoint_hparams() hook to LightningCLI for checkpoint hyperparameter customization
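
A rough usage sketch of the reload_dataloaders() item above (hypothetical: the exact signature and semantics of the new Trainer method are not shown in this thread, and the callback and epoch cadence below are invented for illustration):

    from lightning.pytorch import Callback

    class ReloadEveryNEpochs(Callback):
        """Ask the trainer to rebuild its dataloaders every `n` training epochs."""

        def __init__(self, n=5):
            self.n = n

        def on_train_epoch_end(self, trainer, pl_module):
            # Assumed behavior: reloads the trainer's dataloaders in place mid-training.
            if (trainer.current_epoch + 1) % self.n == 0:
                trainer.reload_dataloaders()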

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 5 comments.

Summary per file:

  • src/lightning/pytorch/utilities/types.py: Added validation to LRSchedulerConfig for interval values and updated the comment to mention the 'batch' interval
  • src/lightning/pytorch/loops/training_epoch_loop.py: Implemented batch-interval scheduler updates using total_batch_idx; updates occur after each batch regardless of gradient accumulation
  • tests/tests_pytorch/trainer/test_lr_scheduler_batch_interval.py: Comprehensive tests for batch-interval LR scheduler functionality, including gradient accumulation scenarios
  • src/lightning/pytorch/trainer/trainer.py: Added reload_dataloaders() method for manual dataloader reloading during training
  • tests/tests_pytorch/trainer/test_reload_dataloaders.py: Tests for the manual dataloader reloading feature
  • src/lightning/pytorch/cli.py: Added adapt_checkpoint_hparams() hook for customizing checkpoint hyperparameters before model instantiation
  • tests/tests_pytorch/test_cli.py: Tests for adapt_checkpoint_hparams hook functionality
  • docs/source-pytorch/data/access.rst: Documentation for the manual dataloader reloading feature

from lightning.pytorch.demos.boring_model import BoringModel
from torch import optim
from torch.optim.lr_scheduler import StepLR

Copilot AI Jan 20, 2026

The import LightningModule is not used in the test file and should be removed. Only Trainer is used from lightning.pytorch.

Comment on lines 2 to 3

from lightning.pytorch.demos.boring_model import BoringModel
Copilot AI Jan 20, 2026

The imports pytest and torch are not used in the test file and should be removed. Only specific submodules from torch (torch.optim) are used.

# after epoch is over
# after epoch is over (valid values: "epoch", "step", or "batch")
interval: str = "epoch"
# every epoch/batch
Copilot AI Jan 20, 2026

The comment on line 87 says "every epoch/batch" but should be updated to "every epoch/step/batch" to reflect all three interval types ("epoch", "step", "batch").

Suggested change
# every epoch/batch
# every epoch/step/batch

from torch import optim
from torch.optim.lr_scheduler import StepLR

from lightning.pytorch import Trainer
Copilot AI Jan 20, 2026

The import path is incorrect. It should be from lightning.pytorch.demos.boring_classes import BoringModel instead of from lightning.pytorch.demos.boring_model import BoringModel. The module boring_model does not exist; BoringModel is defined in boring_classes.


from lightning.pytorch.demos.boring_model import BoringModel
from torch import optim
from torch.optim.lr_scheduler import StepLR
Copilot AI Jan 20, 2026

The imports ExponentialLR and LambdaLR from torch.optim.lr_scheduler are not used in the test file and should be removed.

@arrdel force-pushed the feat/issue-21491-batch-interval-scheduler branch from f97374e to 60cb303 on January 23, 2026 at 15:59
arrdel and others added 15 commits January 26, 2026 21:56
…ameter loading

Fixes Lightning-AI#21255

This commit adds the adapt_checkpoint_hparams() public method to LightningCLI,
allowing users to customize hyperparameters loaded from checkpoints before they
are used to instantiate model classes. This is particularly useful when loading
a checkpoint produced by a TrainingModule into a different InferenceModule class
whose __init__ accepts different parameters.

Problem:
When loading a checkpoint trained with TrainingModule(lr=1e-3) into an
InferenceModule() that doesn't accept 'lr' as a parameter, the CLI would fail
during instantiation because it tries to pass all checkpoint hyperparameters
to the new module class.

Solution:
Added adapt_checkpoint_hparams() hook that is called in _parse_ckpt_path()
after loading checkpoint hyperparameters but before applying them. Users can
override this method to:
- Remove training-specific hyperparameters (e.g., lr, weight_decay)
- Modify _class_path for subclass mode
- Transform hyperparameter names/values
- Completely disable checkpoint hyperparameters by returning {}

Example usage:
    class MyCLI(LightningCLI):
        def adapt_checkpoint_hparams(self, checkpoint_hparams):
            checkpoint_hparams.pop('lr', None)
            checkpoint_hparams.pop('weight_decay', None)
            return checkpoint_hparams

This approach is preferable to:
- Disabling checkpoint loading entirely (loses valuable hyperparameter info)
- Adding CLI arguments (deviates from Trainer parameter pattern)
- Modifying private methods (breaks encapsulation)

The hook provides maximum flexibility while maintaining backward compatibility
(default implementation returns hyperparameters unchanged).
…ook and add tests

- Update adapt_checkpoint_hparams signature to include a subcommand parameter,
  allowing context-aware customization of checkpoint hyperparameters (see the
  sketch after these commit notes)
- Change type annotations to use lowercase dict (Python 3.9+ style)
- Update docstring with subcommand parameter documentation
- Add example showing conditional logic based on subcommand
- Add comprehensive unit tests:
  - test_adapt_checkpoint_hparams_hook: Tests that hook is called and modifications applied
  - test_adapt_checkpoint_hparams_hook_empty_dict: Tests disabling checkpoint hparams loading
- Tests cover both regular and subclass modes
- Split method signature across multiple lines to stay within 120 char limit
- Improves code readability in documentation example
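
An illustrative override using the subcommand-aware signature described in the commit notes above (a sketch under the assumption that the hook receives the checkpoint hyperparameters dict plus the subcommand name; not copied from the diff):

    from lightning.pytorch.cli import LightningCLI

    class InferenceCLI(LightningCLI):
        def adapt_checkpoint_hparams(self, checkpoint_hparams, subcommand):
            # Hypothetical: strip training-only hyperparameters for predict runs only.
            if subcommand == "predict":
                checkpoint_hparams.pop("lr", None)
                checkpoint_hparams.pop("weight_decay", None)
            return checkpoint_hparams
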
Removed redundant method implementations since BoringModel provides them.
The test was asserting hidden_dim==3 but only passing out_dim=3.
Since hidden_dim defaults to 16 and there's no argument linking,
the assertion failed. Now we explicitly pass --model.hidden_dim=6.
This commit implements support for scheduling learning rates on a per-batch
basis, addressing issue Lightning-AI#21491. Previously, PyTorch Lightning only supported
epoch-based ('epoch') and optimizer-step-based ('step') intervals.

Key Changes:
1. TrainEpochLoop._update_learning_rates() now supports 'batch' interval:
   - Uses global batch count (total_batch_idx) for scheduling decisions
   - Updates learning rates on every training batch
   - Skips the gradient accumulation check that applies to 'step' interval

2. LRSchedulerConfig now validates interval values:
   - Added __post_init__() method to enforce valid intervals
   - Raises ValueError for invalid interval specifications
   - Supports 'epoch', 'step', or 'batch' (see the validation sketch after this list)

3. Comprehensive test suite added:
   - Tests batch interval basic functionality
   - Tests interaction with gradient accumulation
   - Tests frequency parameter handling
   - Tests mixed interval types
   - Tests initialization and edge cases
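
A rough sketch of the interval validation described in item 2 (a standalone dataclass for illustration only; the real LRSchedulerConfig has additional fields, and everything here other than the accepted interval values is an assumption):

    from dataclasses import dataclass

    @dataclass
    class LRSchedulerConfigSketch:
        interval: str = "epoch"   # when to step: "epoch", "step", or "batch"
        frequency: int = 1        # step every N intervals

        def __post_init__(self):
            valid = ("epoch", "step", "batch")
            if self.interval not in valid:
                raise ValueError(f"Invalid `interval` value {self.interval!r}. Expected one of {valid}.")
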
Previously, batch-interval schedulers were updated after the optimizer step,
which meant LR changes were not visible to the on_train_batch_start hook.

This caused the LR tracking tests to fail because they recorded LR values
that were from the previous batch's scheduler update.

Now batch-interval schedulers are updated at the beginning of each batch,
ensuring LR changes are visible to all hooks that occur during the batch.
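
As a hedged illustration of this ordering (not the PR's actual test code), a callback like the following could record the learning rate seen at batch start; with interval="batch", the value recorded here should already reflect the current batch's scheduler step:

    from lightning.pytorch import Callback

    class LRAtBatchStart(Callback):
        def __init__(self):
            self.lrs = []

        def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
            # Record the first optimizer's first param-group LR as visible at batch start.
            self.lrs.append(trainer.optimizers[0].param_groups[0]["lr"])
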
@arrdel force-pushed the feat/issue-21491-batch-interval-scheduler branch from bb1d77d to 2eeed9e on January 27, 2026 at 02:56

arrdel commented Jan 27, 2026

Rebased on latest upstream/master (0a0f061) to fix the 54 failing CI checks. The branch was significantly behind master, which caused all the install-pkg and pl-cpu tests to fail. CI should now pass. ✅

Successfully merging this pull request may close the linked issue: Add interval: "batch" option for learning rate schedulers.