
Fix CSVLogger hyperparameters being logged at every write, which increases latency significantly #20594


Merged 10 commits on Feb 26, 2025
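Background: before this change, `ExperimentWriter.save()` re-wrote the full `hparams.yaml` file on every call, so every periodic metrics flush paid an extra YAML serialization and disk write on top of updating `metrics.csv`. A minimal sketch of the pattern that was slow, using only the public `CSVLogger` API (the directory name and hparams sizes below are illustrative, not from the PR):

```python
import time

from lightning.pytorch.loggers import CSVLogger

# Illustrative setup: a logger holding a moderately large hparams dict.
logger = CSVLogger(save_dir="logs", name="latency_demo")
logger.log_hyperparams({f"param_{i}": i for i in range(500)})

# The Trainer flushes metrics periodically via save(). Before this fix,
# every save() also re-serialized all hparams to hparams.yaml; after it,
# save() only flushes metrics.csv.
start = time.perf_counter()
for step in range(100):
    logger.log_metrics({"loss": 1.0 / (step + 1)}, step=step)
    logger.save()
print(f"100 flushes took {time.perf_counter() - start:.3f}s")
```

The fix below moves the YAML write out of this hot path, so hyperparameters are written once when they are logged rather than on every flush.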
src/lightning/pytorch/CHANGELOG.md (13 additions, 0 deletions)
```diff
@@ -4,6 +4,19 @@ All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
+## [unreleased] - YYYY-MM-DD
+
+### Added
+
+### Changed
+
+### Removed
+
+### Fixed
+
+- Fixed `CSVLogger` logging hyperparameters at every write, which increased latency ([#20594](https://github.com/Lightning-AI/pytorch-lightning/pull/20594))
+
+
 ## [2.5.0] - 2024-12-19
 
 ### Added
```
src/lightning/pytorch/loggers/csv_logs.py (2 additions, 7 deletions)
```diff
@@ -55,15 +55,10 @@ def __init__(self, log_dir: str) -> None:
         self.hparams: dict[str, Any] = {}
 
     def log_hparams(self, params: dict[str, Any]) -> None:
-        """Record hparams."""
+        """Record hparams and save into files."""
         self.hparams.update(params)
-
-    @override
-    def save(self) -> None:
-        """Save recorded hparams and metrics into files."""
         hparams_file = os.path.join(self.log_dir, self.NAME_HPARAMS_FILE)
         save_hparams_to_yaml(hparams_file, self.hparams)
-        return super().save()
 
 
 class CSVLogger(Logger, FabricCSVLogger):
@@ -144,7 +139,7 @@ def save_dir(self) -> str:
 
     @override
     @rank_zero_only
-    def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None:
+    def log_hyperparams(self, params: Optional[Union[dict[str, Any], Namespace]] = None) -> None:
         params = _convert_params(params)
         self.experiment.log_hparams(params)
```
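With the `save()` override removed, `hparams.yaml` is written exactly once per `log_hyperparams` call instead of on every write. A small sketch of the resulting behavior (the `logs`/`demo` paths are illustrative; the file name comes from `ExperimentWriter.NAME_HPARAMS_FILE`, which is `hparams.yaml`):

```python
import os

from lightning.pytorch.loggers import CSVLogger

logger = CSVLogger(save_dir="logs", name="demo")
logger.log_hyperparams({"lr": 1e-3, "batch_size": 32})

# The hparams file exists immediately after log_hyperparams; no explicit
# save() is required, which is why the tests below drop logger.save().
assert os.path.isfile(os.path.join(logger.log_dir, "hparams.yaml"))
```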
tests/tests_pytorch/loggers/test_csv.py (1 addition, 3 deletions)
```diff
@@ -75,7 +75,6 @@ def test_named_version(tmp_path):
 
     logger = CSVLogger(save_dir=tmp_path, name=exp_name, version=expected_version)
     logger.log_hyperparams({"a": 1, "b": 2})
-    logger.save()
     assert logger.version == expected_version
     assert os.listdir(tmp_path / exp_name) == [expected_version]
     assert os.listdir(tmp_path / exp_name / expected_version)
@@ -85,7 +84,7 @@
 def test_no_name(tmp_path, name):
     """Verify that None or empty name works."""
     logger = CSVLogger(save_dir=tmp_path, name=name)
-    logger.save()
+    logger.log_hyperparams()
     assert os.path.normpath(logger.root_dir) == str(tmp_path)  # use os.path.normpath to handle trailing /
     assert os.listdir(tmp_path / "version_0")
@@ -116,7 +115,6 @@ def test_log_hyperparams(tmp_path):
         "layer": torch.nn.BatchNorm1d,
     }
     logger.log_hyperparams(hparams)
-    logger.save()
 
     path_yaml = os.path.join(logger.log_dir, ExperimentWriter.NAME_HPARAMS_FILE)
     params = load_hparams_from_yaml(path_yaml)
```
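The updated `test_no_name` also exercises the new `Optional[...] = None` signature: `log_hyperparams()` may now be called with no arguments, which converts `None` to an empty dict before writing. A minimal sketch, assuming an illustrative `logs` directory:

```python
from lightning.pytorch.loggers import CSVLogger

# A bare call is now valid under the Optional signature; it writes an
# (empty) hparams.yaml and creates the version directory as a side
# effect, which is the behavior test_no_name relies on.
logger = CSVLogger(save_dir="logs")
logger.log_hyperparams()
print(logger.log_dir)  # e.g. logs/lightning_logs/version_0
```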