Fix schedule reset logic in pytorch profiler (#10837)
rohitgr7 committed Dec 2, 2021
1 parent 7d534bd commit f26f637
Showing 6 changed files with 74 additions and 19 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
@@ -17,6 +17,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed uploading best model checkpoint in NeptuneLogger ([#10369](https://github.com/PyTorchLightning/pytorch-lightning/pull/10369))


- Fixed early schedule reset logic in PyTorch profiler that was causing data leak ([#10837](https://github.com/PyTorchLightning/pytorch-lightning/pull/10837))


-


-


## [1.5.4] - 2021-11-30

### Fixed
7 changes: 4 additions & 3 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, List, Optional, Sequence, Union
+from typing import Any, List, Sequence

from deprecate.utils import void
from torch.utils.data.dataloader import DataLoader
@@ -32,7 +32,8 @@ def __init__(self):
        self.epoch_loop = EvaluationEpochLoop()

        self._results = ResultCollection(training=False)
-        self._max_batches: Optional[Union[int, Sequence[int]]] = None
+        self._outputs: List[EPOCH_OUTPUT] = []
+        self._max_batches: List[int] = []
        self._has_run: bool = False

@property
@@ -147,7 +148,7 @@ def teardown(self) -> None:
self._results.cpu()
self.epoch_loop.teardown()

-    def _get_max_batches(self) -> List[Union[int, float]]:
+    def _get_max_batches(self) -> List[int]:
        """Returns the max number of batches for each dataloader."""
        if self.trainer.testing:
            max_batches = self.trainer.num_test_batches
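The type tightening above reflects that every evaluation dataloader gets its own integer batch cap, so the loop can hold a plain `List[int]` (one entry per dataloader) instead of an optional scalar-or-sequence. A minimal sketch of that per-dataloader bookkeeping; the helper name and signature are illustrative, not part of Lightning's API:

from typing import List, Union

def resolve_max_batches(dataloader_sizes: List[int], limit: Union[int, float]) -> List[int]:
    # Illustrative only: an int limit is an absolute per-dataloader cap, while a float
    # limit is a fraction of each dataloader's length, mirroring how the
    # `limit_*_batches` Trainer flags behave.
    if isinstance(limit, int):
        return [min(size, limit) for size in dataloader_sizes]
    return [int(size * limit) for size in dataloader_sizes]

# Two validation dataloaders with 100 and 40 batches and limit_val_batches=4:
resolve_max_batches([100, 40], 4)       # -> [4, 4]
sum(resolve_max_batches([100, 40], 4))  # -> 8 profiler steps for the validation stage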
5 changes: 1 addition & 4 deletions pytorch_lightning/loops/dataloader/prediction_loop.py
@@ -53,10 +53,7 @@ def num_dataloaders(self) -> int:
    @property
    def max_batches(self) -> List[int]:
        """The max number of batches this loop will run for each dataloader."""
-        max_batches = self.trainer.num_predict_batches
-        if isinstance(max_batches, int):
-            max_batches = [max_batches] * len(self.dataloaders)
-        return max_batches
+        return self.trainer.num_predict_batches

    @property
    def dataloaders(self) -> Sequence[DataLoader]:
25 changes: 20 additions & 5 deletions pytorch_lightning/profiler/pytorch.py
@@ -335,9 +335,24 @@ def _init_kineto(self, profiler_kwargs: Any) -> None:
with_stack = profiler_kwargs.get("with_stack", False) or self._export_to_flame_graph
self._profiler_kwargs["with_stack"] = with_stack

+    @property
+    def _total_steps(self) -> int:
+        trainer = self._lightning_module.trainer
+        if self._schedule.is_training:
+            return trainer.num_training_batches
+        if self._schedule._current_action == "validation_step":
+            return sum(trainer.num_val_batches) + sum(trainer.num_sanity_val_batches)
+        if self._schedule._current_action == "test_step":
+            return sum(trainer.num_test_batches)
+        if self._schedule._current_action == "predict_step":
+            return sum(trainer.num_predict_batches)
+
    def _should_override_schedule(self) -> bool:
-        return (self._lightning_module is not None and self._lightning_module.trainer.limit_train_batches < 5) and (
-            self._schedule is not None and self._schedule._schedule == self._default_schedule()
+        return (
+            self._lightning_module is not None
+            and self._schedule is not None
+            and self._total_steps < 5
+            and self._schedule._schedule == self._default_schedule()
        )

@staticmethod
@@ -410,6 +425,9 @@ def stop(self, action_name: str) -> None:
            action_name in self.STEP_FUNCTIONS or action_name.startswith(self.STEP_FUNCTION_PREFIX)
        ):

+            if self._schedule is not None:
+                self._schedule.pre_step(action_name)
+
            # the default schedule requires a minimum of 5 steps to properly work: `wait=1, warmup=1, active=3`.
            # otherwise, this will raise a `segmentation fault`.
            if self._should_override_schedule():
@@ -420,9 +438,6 @@
                self._schedule = None
                self.profiler.schedule = torch.profiler.profiler._default_schedule_fn

-            if self._schedule is not None:
-                self._schedule.pre_step(action_name)
-
            def on_trace_ready(profiler):
                if self.dirpath is not None:
                    if self._export_to_chrome:
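The inline comment above is the crux of the fix: Lightning's default Kineto schedule is `wait=1, warmup=1, active=3`, so a stage needs at least 5 profiled steps before anything lands in the trace, and the old check only looked at `limit_train_batches`. The new `_total_steps` property picks the batch counter for whichever stage is actually running, so short validation/test/predict runs are also caught. A standalone illustration of the 5-step requirement using `torch.profiler.schedule` directly, assuming a PyTorch build with Kineto (torch >= 1.8.1):

import torch

# The schedule maps a step index to a ProfilerAction; only "active" steps are
# recorded, and the last active step is the one that flushes the trace to disk.
schedule = torch.profiler.schedule(wait=1, warmup=1, active=3)

for step in range(5):
    print(step, schedule(step))
# 0 ProfilerAction.NONE             (wait)
# 1 ProfilerAction.WARMUP           (warmup)
# 2 ProfilerAction.RECORD           (active)
# 3 ProfilerAction.RECORD           (active)
# 4 ProfilerAction.RECORD_AND_SAVE  (last active step saves the trace)

With fewer than 5 steps in a stage, the run never reaches the saving step, which is why the profiler drops its schedule, falls back to the default schedule function, and warns instead.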
17 changes: 11 additions & 6 deletions pytorch_lightning/trainer/data_loading.py
@@ -52,13 +52,18 @@ class TrainerDataLoadingMixin(ABC):
    val_check_interval: float
    tpu_local_core_rank: int
    train_dataloader: DataLoader
-    num_training_batches: Union[int, float]
-    val_check_batch: float
-    val_dataloaders: Optional[List[DataLoader]]
-    num_val_batches: List[Union[int, float]]
-    test_dataloaders: Optional[List[DataLoader]]
-    num_test_batches: List[Union[int, float]]
    limit_train_batches: Union[int, float]
+    num_training_batches: int
+    val_check_batch: float
+    val_dataloaders: List[DataLoader]
+    limit_val_batches: Union[int, float]
+    num_val_batches: List[int]
+    test_dataloaders: List[DataLoader]
+    limit_test_batches: Union[int, float]
+    num_test_batches: List[int]
+    predict_dataloaders: List[DataLoader]
+    limit_predict_batches: Union[int, float]
+    num_predict_batches: List[int]
    log_every_n_steps: int
    overfit_batches: Union[int, float]
    distributed_sampler_kwargs: dict
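For reference, the `limit_*_batches` attributes annotated above are the user-facing knobs: an int caps the absolute number of batches per dataloader, a float takes a fraction of each dataloader. A small configuration example with arbitrary values, using standard Trainer arguments:

from pytorch_lightning import Trainer

trainer = Trainer(
    limit_train_batches=100,   # at most 100 training batches per epoch
    limit_val_batches=0.25,    # 25% of each validation dataloader
    limit_test_batches=4,      # 4 batches per test dataloader
    limit_predict_batches=4,   # 4 batches per predict dataloader
)

The corresponding `num_*_batches` attributes hold the resolved per-dataloader counts that the profiler's `_total_steps` property sums over.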
30 changes: 29 additions & 1 deletion tests/profiler/test_profiler.py
@@ -25,7 +25,7 @@
from pytorch_lightning.loggers.base import LoggerCollection
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler
-from pytorch_lightning.profiler.pytorch import RegisterRecordFunction
+from pytorch_lightning.profiler.pytorch import RegisterRecordFunction, warning_cache
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE
@@ -523,3 +523,31 @@ def test_trainer_profiler_incorrect_str_arg():
match=r"When passing string value for the `profiler` parameter of `Trainer`, it can only be one of.*",
):
Trainer(profiler="unknown_profiler")


@pytest.mark.skipif(not _KINETO_AVAILABLE, reason="Requires PyTorch Profiler Kineto")
@pytest.mark.parametrize(
["trainer_config", "trainer_fn"],
[
({"limit_train_batches": 4, "limit_val_batches": 7}, "fit"),
({"limit_train_batches": 7, "limit_val_batches": 4, "num_sanity_val_steps": 0}, "fit"),
(
{
"limit_train_batches": 7,
"limit_val_batches": 2,
},
"fit",
),
({"limit_val_batches": 4}, "validate"),
({"limit_test_batches": 4}, "test"),
({"limit_predict_batches": 4}, "predict"),
],
)
def test_pytorch_profiler_raises_warning_for_limited_steps(tmpdir, trainer_config, trainer_fn):
model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, profiler="pytorch", max_epochs=1, **trainer_config)
warning_cache.clear()
with pytest.warns(UserWarning, match="not enough steps to properly record traces"):
getattr(trainer, trainer_fn)(model)
assert trainer.profiler._schedule is None
warning_cache.clear()
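Outside the test suite, the new warning is easy to reproduce with any small run. A rough sketch, assuming a Lightning release that includes this fix and a Kineto-enabled PyTorch; `TinyModel` here is only a stand-in for `BoringModel`, which lives in Lightning's test helpers:

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer

class TinyModel(LightningModule):
    # Minimal module: a single linear layer and a dummy loss.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch[0]).sum()

    def validation_step(self, batch, batch_idx):
        return self.layer(batch[0]).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

def loader():
    return DataLoader(TensorDataset(torch.randn(8, 32)), batch_size=2)

# Fewer than 5 profiled steps in the train/val stages -> the default schedule is
# dropped and a "not enough steps to properly record traces" UserWarning is emitted.
trainer = Trainer(profiler="pytorch", max_epochs=1, limit_train_batches=2, limit_val_batches=2)
trainer.fit(TinyModel(), train_dataloaders=loader(), val_dataloaders=loader())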
