Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Fix mypy errors attributed to pytorch_lightning.core.datamodule #13693

Merged
merged 43 commits into master from codeq/datamodule
Aug 26, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
43 commits
Select commit. Hold shift + click to select a range.
43c2ad2
remove module from pyproject.toml for ci code-checks
jxtngx Jul 17, 2022
9686043
update
jxtngx Jul 17, 2022
bd6fb0f
update return type
jxtngx Jul 17, 2022
5db82b4
update for LightningDataModule codeq
jxtngx Jul 17, 2022
cba2910
Merge branch 'Lightning-AI:master' into codeq/datamodule
jxtngx Jul 23, 2022
de0f36e
Merge branch 'master' into codeq/datamodule
jxtngx Jul 23, 2022
16a279f
Merge branch 'master' into codeq/datamodule
jxtngx Jul 25, 2022
90aa6c9
Merge branch 'master' into codeq/datamodule
jxtngx Jul 26, 2022
cbae81b
Merge branch 'master' into codeq/datamodule
jxtngx Jul 28, 2022
089981a
update
jxtngx Jul 28, 2022
6087294
Merge branch 'master' into codeq/datamodule
jxtngx Jul 28, 2022
049d1c1
update
jxtngx Jul 29, 2022
2f6e162
update
jxtngx Jul 29, 2022
ec292c3
Merge branch 'master' into codeq/datamodule
jxtngx Aug 1, 2022
b189f22
Merge branch 'master' into codeq/datamodule
jxtngx Aug 1, 2022
2d8adcc
Merge branch 'master' into codeq/datamodule
jxtngx Aug 1, 2022
517576d
update
jxtngx Aug 2, 2022
ae32bc1
Merge branch 'master' into codeq/datamodule
jxtngx Aug 2, 2022
020872d
commit suggestion
jxtngx Aug 3, 2022
850263e
update
jxtngx Aug 3, 2022
c9edb9d
Merge branch 'master' into codeq/datamodule
jxtngx Aug 3, 2022
6a45961
Merge branch 'master' into codeq/datamodule
jxtngx Aug 5, 2022
f3775d8
Merge branch 'master' into codeq/datamodule
jxtngx Aug 5, 2022
eb8ecb9
resolve merge conflicts
jxtngx Aug 8, 2022
3ae9e36
update
jxtngx Aug 8, 2022
6e2062e
Merge branch 'master' into codeq/datamodule
jxtngx Aug 9, 2022
92197a0
Merge branch 'master' into codeq/datamodule
jxtngx Aug 9, 2022
2585d0a
self review
rohitgr7 Aug 9, 2022
ab2cf24
fix
rohitgr7 Aug 9, 2022
8af819b
Merge branch 'master' into codeq/datamodule
jxtngx Aug 10, 2022
cc55ab9
update
jxtngx Aug 10, 2022
a08f574
Merge branch 'master' into codeq/datamodule
jxtngx Aug 12, 2022
7d56668
Merge branch 'master' into codeq/datamodule
carmocca Aug 22, 2022
28fae30
fix introduced mypy error
Aug 22, 2022
e0d25ec
fix docstring
Aug 22, 2022
402d025
Merge branch 'master' into codeq/datamodule
otaj Aug 23, 2022
e7cef8a
Merge branch 'master' into codeq/datamodule
otaj Aug 24, 2022
d53275b
Merge branch 'master' into codeq/datamodule
otaj Aug 25, 2022
5e5629c
Merge branch 'master' into codeq/datamodule
rohitgr7 Aug 25, 2022
4053b7e
Merge branch 'master' into codeq/datamodule
otaj Aug 26, 2022
3ee9dde
merge master
Aug 26, 2022
16528aa
merge master
Aug 26, 2022
cc11363
Merge branch 'master' into codeq/datamodule
Borda Aug 26, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ module = [
"pytorch_lightning.callbacks.progress.rich_progress",
"pytorch_lightning.callbacks.quantization",
"pytorch_lightning.callbacks.stochastic_weight_avg",
jxtngx marked this conversation as resolved.
Show resolved Hide resolved
"pytorch_lightning.core.datamodule",
"pytorch_lightning.core.decorators",
"pytorch_lightning.core.mixins.device_dtype_mixin",
"pytorch_lightning.core.module",
Expand Down
34 changes: 18 additions & 16 deletions src/pytorch_lightning/core/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""LightningDataModule for loading DataLoaders with ease."""
from argparse import ArgumentParser, Namespace
from argparse import _ArgumentGroup, ArgumentParser, Namespace
jxtngx marked this conversation as resolved.
Show resolved Hide resolved
from typing import Any, Dict, IO, List, Mapping, Optional, Sequence, Tuple, Union

from torch.utils.data import DataLoader, Dataset, IterableDataset
Expand All @@ -23,6 +23,8 @@
from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types
from pytorch_lightning.utilities.types import _PATH

ADD_ARGPARSE_RETURN = Union[ArgumentParser, Union[_ArgumentGroup, ArgumentParser]]
otaj marked this conversation as resolved.
Show resolved Hide resolved


class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin):
"""A DataModule standardizes the training, val, test splits, data preparation and transforms. The main
Expand Down Expand Up @@ -53,7 +55,7 @@ def teardown(self):
# called on every process in DDP
"""

name: str = ...
name: Union[str, Any] = ...
otaj marked this conversation as resolved.
Show resolved Hide resolved
CHECKPOINT_HYPER_PARAMS_KEY = "datamodule_hyper_parameters"
CHECKPOINT_HYPER_PARAMS_NAME = "datamodule_hparams_name"
CHECKPOINT_HYPER_PARAMS_TYPE = "datamodule_hparams_type"
Expand All @@ -64,7 +66,7 @@ def __init__(self) -> None:
self.trainer = None

@classmethod
def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser:
def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs: Any) -> ADD_ARGPARSE_RETURN:
"""Extends existing argparse by default `LightningDataModule` attributes.

Example::
Expand All @@ -75,7 +77,7 @@ def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentP
return add_argparse_args(cls, parent_parser, **kwargs)

@classmethod
def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs: Any) -> ArgumentParser:
jxtngx marked this conversation as resolved.
Show resolved Hide resolved
"""Create an instance from CLI arguments.

Args:
Expand Down Expand Up @@ -109,7 +111,7 @@ def from_datasets(
predict_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None,
batch_size: int = 1,
num_workers: int = 0,
):
) -> "LightningDataModule":
r"""
Create an instance from torch.utils.data.Dataset.

Expand All @@ -124,41 +126,41 @@ def from_datasets(

"""

def dataloader(ds: Dataset, shuffle: bool = False) -> DataLoader:
def dataloader(ds: Any, shuffle: bool = False) -> DataLoader:
otaj marked this conversation as resolved.
Show resolved Hide resolved
shuffle &= not isinstance(ds, IterableDataset)
return DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True)

def train_dataloader():
def train_dataloader() -> Union[Dict, List, DataLoader]:
otaj marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(train_dataset, Mapping):
return {key: dataloader(ds, shuffle=True) for key, ds in train_dataset.items()}
if isinstance(train_dataset, Sequence):
return [dataloader(ds, shuffle=True) for ds in train_dataset]
return dataloader(train_dataset, shuffle=True)

def val_dataloader():
def val_dataloader() -> Union[List, DataLoader]:
otaj marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(val_dataset, Sequence):
return [dataloader(ds) for ds in val_dataset]
return dataloader(val_dataset)

def test_dataloader():
def test_dataloader() -> Union[List, DataLoader]:
otaj marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(test_dataset, Sequence):
return [dataloader(ds) for ds in test_dataset]
return dataloader(test_dataset)

def predict_dataloader():
def predict_dataloader() -> Union[List, DataLoader]:
otaj marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(predict_dataset, Sequence):
return [dataloader(ds) for ds in predict_dataset]
return dataloader(predict_dataset)

datamodule = cls()
if train_dataset is not None:
datamodule.train_dataloader = train_dataloader
datamodule.train_dataloader = train_dataloader # type: ignore[assignment]
otaj marked this conversation as resolved.
Show resolved Hide resolved
otaj marked this conversation as resolved.
Show resolved Hide resolved
if val_dataset is not None:
datamodule.val_dataloader = val_dataloader
datamodule.val_dataloader = val_dataloader # type: ignore[assignment]
otaj marked this conversation as resolved.
Show resolved Hide resolved
otaj marked this conversation as resolved.
Show resolved Hide resolved
if test_dataset is not None:
datamodule.test_dataloader = test_dataloader
datamodule.test_dataloader = test_dataloader # type: ignore[assignment]
otaj marked this conversation as resolved.
Show resolved Hide resolved
otaj marked this conversation as resolved.
Show resolved Hide resolved
if predict_dataset is not None:
datamodule.predict_dataloader = predict_dataloader
datamodule.predict_dataloader = predict_dataloader # type: ignore[assignment]
otaj marked this conversation as resolved.
Show resolved Hide resolved
jxtngx marked this conversation as resolved.
Show resolved Hide resolved
return datamodule

def state_dict(self) -> Dict[str, Any]:
Expand All @@ -182,8 +184,8 @@ def load_from_checkpoint(
cls,
checkpoint_path: Union[_PATH, IO],
hparams_file: Optional[_PATH] = None,
**kwargs,
):
**kwargs: Any,
) -> Any:
otaj marked this conversation as resolved.
Show resolved Hide resolved
r"""
Primary way of loading a datamodule from a checkpoint. When Lightning saves a checkpoint
it stores the arguments passed to ``__init__`` in the checkpoint under ``"datamodule_hyper_parameters"``.
Expand Down
7 changes: 5 additions & 2 deletions src/pytorch_lightning/utilities/argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,10 @@ def _get_abbrev_qualified_cls_name(cls: Any) -> str:


def add_argparse_args(
cls: Type["pl.Trainer"], parent_parser: ArgumentParser, *, use_argument_group: bool = True
cls: Type[Union["pl.Trainer", "pl.LightningDataModule"]],
jxtngx marked this conversation as resolved.
Show resolved Hide resolved
parent_parser: ArgumentParser,
*,
use_argument_group: bool = True,
) -> Union[_ArgumentGroup, ArgumentParser]:
r"""Extends existing argparse by default attributes for ``cls``.

Expand Down Expand Up @@ -216,7 +219,7 @@ def add_argparse_args(

ignore_arg_names = ["self", "args", "kwargs"]
if hasattr(cls, "get_deprecated_arg_names"):
ignore_arg_names += cls.get_deprecated_arg_names()
ignore_arg_names += cls.get_deprecated_arg_names() # type: ignore[union-attr]
otaj marked this conversation as resolved.
Show resolved Hide resolved

allowed_types = (str, int, float, bool)

Expand Down