Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
958f95e
feat: introduce `logger_map` property.
GdoongMathew Oct 19, 2025
711fb4f
revert trainer.logger change.
GdoongMathew Oct 19, 2025
791dbdc
add tests.
GdoongMathew Oct 19, 2025
db51529
add test.
GdoongMathew Oct 19, 2025
7cb1382
add test.
GdoongMathew Oct 19, 2025
b1ea7d3
fix pylint
GdoongMathew Oct 19, 2025
6bbc98d
fix pylint
GdoongMathew Oct 19, 2025
56ea5e8
add test.
GdoongMathew Oct 19, 2025
d031985
refactor loggers setter.
GdoongMathew Oct 19, 2025
e8d3695
fix pylint.
GdoongMathew Oct 19, 2025
01aaa41
_ListMap integration.
GdoongMathew Oct 21, 2025
e5a38ed
Merge branch 'master' into feat/logger_dict
GdoongMathew Oct 22, 2025
022fa92
fix: fix unittests.
GdoongMathew Oct 22, 2025
c309642
fix pylint.
GdoongMathew Oct 22, 2025
67b888f
add reverse impl.
GdoongMathew Oct 22, 2025
b3a3a70
Merge remote-tracking branch 'origin/feat/logger_dict' into feat/logg…
GdoongMathew Oct 22, 2025
3e9e398
implement list methods.
GdoongMathew Oct 22, 2025
fa83ab2
implement get method.
GdoongMathew Oct 22, 2025
085f167
adding notes.
GdoongMathew Oct 22, 2025
0d8f725
refactor
GdoongMathew Oct 22, 2025
a393ae3
Merge branch 'master' into feat/logger_dict
GdoongMathew Oct 22, 2025
0e14e09
docs
GdoongMathew Oct 23, 2025
2d9f419
test: add additional unittests.
GdoongMathew Oct 23, 2025
b78daea
fix: fix delete implementation.
GdoongMathew Oct 23, 2025
c371b20
docs: fix doctest.
GdoongMathew Oct 23, 2025
41f4311
add unittest case.
GdoongMathew Oct 23, 2025
172ceb3
fix: fix mypy
GdoongMathew Oct 23, 2025
fffe03b
test
GdoongMathew Oct 23, 2025
c45fa9b
fix mypy
GdoongMathew Oct 24, 2025
ec39fe5
fix mypy
GdoongMathew Oct 24, 2025
69a2ef3
Merge remote-tracking branch 'origin/feat/logger_dict' into feat/logg…
GdoongMathew Oct 24, 2025
a2709c2
ref: refactor __delitem__
GdoongMathew Oct 24, 2025
9d0d39d
fix: mypy
GdoongMathew Oct 24, 2025
01b9247
Merge branch 'master' into feat/logger_dict
GdoongMathew Oct 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
244 changes: 243 additions & 1 deletion src/lightning/pytorch/loggers/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,19 @@
# limitations under the License.
"""Utilities for loggers."""

from collections.abc import ItemsView, Iterable, KeysView, Mapping, ValuesView
from pathlib import Path
from typing import Any, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, SupportsIndex, TypeVar, Union

from torch import Tensor
from typing_extensions import Self, overload

import lightning.pytorch as pl
from lightning.pytorch.callbacks import Checkpoint

if TYPE_CHECKING:
from _typeshed import SupportsRichComparison


def _version(loggers: list[Any], separator: str = "_") -> Union[int, str]:
if len(loggers) == 1:
Expand Down Expand Up @@ -100,3 +105,240 @@ def _log_hyperparams(trainer: "pl.Trainer") -> None:
logger.log_hyperparams(hparams_initial)
logger.log_graph(pl_module)
logger.save()


_T = TypeVar("_T")
_PT = TypeVar("_PT")


class _ListMap(list[_T]):
"""A hybrid container allowing both index and name access.

This class extends the built-in list to provide dictionary-like access to its elements
using string keys. It maintains an internal mapping of string keys to list indices,
allowing users to retrieve, set, and delete elements by their associated names.

Args:
__iterable (Union[Iterable[_T], Mapping[str, _T]], optional): An iterable of objects or a mapping
of string keys to __iterable to initialize the container.

Raises:
TypeError: If a Mapping is provided and any of its keys are not of type str.

Example:
>>> listmap = _ListMap({'obj1': 1, 'obj2': 2})
>>> listmap['obj1'] # Access by name
1
>>> listmap[0] # Access by index
1
>>> listmap['obj2'] = 3 # Set by name
>>> listmap[1] # Now returns obj3
3
>>> listmap.append(4) # Append by index
>>> listmap[2]
4

"""

def __init__(self, __iterable: Optional[Union[Mapping[str, _T], Iterable[_T]]] = None):
_dict: dict[str, int]
if isinstance(__iterable, Mapping):
# super inits list with values
if any(not isinstance(x, str) for x in __iterable):
raise TypeError("When providing a Mapping, all keys must be of type str.")
super().__init__(__iterable.values())
_dict = dict(zip(__iterable.keys(), range(len(__iterable))))
else:
default_dict = {}
if isinstance(__iterable, _ListMap):
default_dict = __iterable._dict.copy()
super().__init__(() if __iterable is None else __iterable)
_dict: dict = default_dict
self._dict = _dict

def __eq__(self, other: Any) -> bool:
list_eq = list.__eq__(self, other)
if isinstance(other, _ListMap):
return list_eq and self._dict == other._dict
return list_eq

def copy(self) -> "_ListMap":
new_listmap = _ListMap(self)
new_listmap._dict = self._dict.copy()
return new_listmap

def extend(self, __iterable: Iterable[_T]) -> None:
if isinstance(__iterable, _ListMap):
offset = len(self)
for key, idx in __iterable._dict.items():
self._dict[key] = idx + offset
super().extend(__iterable)

@overload
def pop(self, key: SupportsIndex = -1, /) -> _T: ...

@overload
def pop(self, key: Union[str, SupportsIndex], default: _T, /) -> _T: ...

@overload
def pop(self, key: str, default: _PT, /) -> Union[_T, _PT]: ...

def pop(self, key=-1, default=None):
if isinstance(key, int):
ret = super().pop(key)
for str_key, idx in list(self._dict.items()):
if idx == key:
self._dict.pop(str_key)
elif idx > key:
self._dict[str_key] = idx - 1
return ret
if isinstance(key, str):
if key not in self._dict:
return default
return self.pop(self._dict[key])
raise TypeError("Key must be int or str")

def insert(self, index: SupportsIndex, __object: _T) -> None:
for key, idx in self._dict.items():
if idx >= index:
self._dict[key] = idx + 1
super().insert(index, __object)

def remove(self, __object: _T) -> None:
idx = self.index(__object)
name = None
for key, val in self._dict.items():
if val == idx:
name = key
elif val > idx:
self._dict[key] = val - 1
if name:
self._dict.pop(name, None)
super().remove(__object)

def sort(
self,
*,
key: Optional[Callable[[_T], "SupportsRichComparison"]] = None,
reverse: bool = False,
) -> None:
# Create a mapping from item to its name(s)
item_to_names: dict[_T, list[str]] = {}
for name, idx in self._dict.items():
item = self[idx]
item_to_names.setdefault(item, []).append(name)
# Sort the list
super().sort(key=key, reverse=reverse)
# Update _dict with new indices
new_dict: dict[str, int] = {}
for idx, item in enumerate(self):
if item in item_to_names:
for name in item_to_names[item]:
new_dict[name] = idx
self._dict = new_dict

@overload
def __getitem__(self, key: Union[SupportsIndex, str], /) -> _T: ...

@overload
def __getitem__(self, key: slice, /) -> list[_T]: ...

def __getitem__(self, key):
if isinstance(key, str):
return self[self._dict[key]]
return list.__getitem__(self, key)

def __add__(self, other: Union[list[_T], "_ListMap[_T]"]) -> "_ListMap[_T]":
new_listmap = self.copy()
new_listmap += other
return new_listmap

def __iadd__(self, other: Union[list[_T], Self]) -> Self:
if isinstance(other, _ListMap):
offset = len(self)
for key, idx in other._dict.items():
# notes: if there are duplicate keys, the ones from other will overwrite self
self._dict[key] = idx + offset

return super().__iadd__(other)

@overload
def __setitem__(self, key: Union[SupportsIndex, str], value: _T, /) -> None: ...

@overload
def __setitem__(self, key: slice, value: Iterable[_T], /) -> None: ...

def __setitem__(self, key, value):
if isinstance(key, (int, slice)):
# replace element by index
return super().__setitem__(key, value)
if isinstance(key, str):
# replace or insert by name
if key in self._dict:
super().__setitem__(self._dict[key], value)
else:
self.append(value)
self._dict[key] = len(self) - 1
return None
raise TypeError("Key must be int or str")

def __contains__(self, item: Union[object, str]) -> bool:
if isinstance(item, str):
return item in self._dict
return super().__contains__(item)

# --- Dict-like interface ---

def __delitem__(self, key: Union[SupportsIndex, slice, str]) -> None:
index: Union[SupportsIndex, slice]
if isinstance(key, str):
if key not in self._dict:
raise KeyError(f"Key '{key}' not found.")
index = self._dict[key]
else:
index = key

if isinstance(index, (int, slice)):
super().__delitem__(index)
for _key in index.indices(len(self)) if isinstance(index, slice) else [index]:
# update indices in the dict
for str_key, idx in list(self._dict.items()):
if idx == _key:
self._dict.pop(str_key)
elif idx > _key:
self._dict[str_key] = idx - 1
else:
raise TypeError("Key must be int or str")

def keys(self) -> KeysView[str]:
return self._dict.keys()

def values(self) -> ValuesView[_T]:
return {k: self[v] for k, v in self._dict.items()}.values()

def items(self) -> ItemsView[str, _T]:
return {k: self[v] for k, v in self._dict.items()}.items()

@overload
def get(self, __key: str) -> Optional[_T]: ...

@overload
def get(self, __key: str, default: _PT) -> Union[_T, _PT]: ...

def get(self, __key, default=None):
if __key in self._dict:
return self[self._dict[__key]]
return default

def __repr__(self) -> str:
ret = super().__repr__()
return f"_ListMap({ret}, keys={list(self._dict.keys())})"

def reverse(self) -> None:
for key, idx in self._dict.items():
self._dict[key] = len(self) - 1 - idx
list.reverse(self)

def clear(self) -> None:
self._dict.clear()
list.clear(self)
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterable
from collections.abc import Iterable, Mapping
from typing import Any, Optional, Union

from lightning_utilities.core.apply_func import apply_to_collection
Expand Down Expand Up @@ -82,6 +82,8 @@ def configure_logger(self, logger: Union[bool, Logger, Iterable[Logger]]) -> Non
)
logger_ = CSVLogger(save_dir=self.trainer.default_root_dir) # type: ignore[assignment]
self.trainer.loggers = [logger_]
elif isinstance(logger, Mapping):
self.trainer.loggers = logger
elif isinstance(logger, Iterable):
self.trainer.loggers = list(logger)
else:
Expand Down
12 changes: 6 additions & 6 deletions src/lightning/pytorch/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import logging
import math
import os
from collections.abc import Generator, Iterable
from collections.abc import Generator, Iterable, Mapping
from contextlib import contextmanager
from datetime import timedelta
from typing import Any, Optional, Union
Expand All @@ -43,7 +43,7 @@
from lightning.pytorch.loggers import Logger
from lightning.pytorch.loggers.csv_logs import CSVLogger
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
from lightning.pytorch.loggers.utilities import _log_hyperparams
from lightning.pytorch.loggers.utilities import _ListMap, _log_hyperparams
from lightning.pytorch.loops import _PredictionLoop, _TrainingEpochLoop
from lightning.pytorch.loops.evaluation_loop import _EvaluationLoop
from lightning.pytorch.loops.fit_loop import _FitLoop
Expand Down Expand Up @@ -494,7 +494,7 @@ def __init__(
setup._init_profiler(self, profiler)

# init logger flags
self._loggers: list[Logger]
self._loggers: _ListMap[Logger]
self._logger_connector.on_trainer_init(logger, log_every_n_steps)

# init debugging flags
Expand Down Expand Up @@ -1631,7 +1631,7 @@ def logger(self, logger: Optional[Logger]) -> None:
self.loggers = [logger]

@property
def loggers(self) -> list[Logger]:
def loggers(self) -> _ListMap[Logger]:
"""The list of :class:`~lightning.pytorch.loggers.logger.Logger` used.

.. code-block:: python
Expand All @@ -1643,8 +1643,8 @@ def loggers(self) -> list[Logger]:
return self._loggers

@loggers.setter
def loggers(self, loggers: Optional[list[Logger]]) -> None:
self._loggers = loggers if loggers else []
def loggers(self, loggers: Optional[Union[list[Logger], Mapping[str, Logger], _ListMap[Logger]]]) -> None:
self._loggers = _ListMap(loggers)

@property
def callback_metrics(self) -> _OUT_DICT:
Expand Down
Loading
Loading