Skip to content

Commit

Permalink
Refactor BoringFabric in tests (Lightning-AI#19364)
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli authored Jan 30, 2024
1 parent 28b3806 commit 01f8531
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 141 deletions.
27 changes: 27 additions & 0 deletions tests/tests_fabric/helpers/datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import Iterator

import torch
from torch import Tensor
from torch.utils.data import Dataset, IterableDataset


class RandomDataset(Dataset):
def __init__(self, size: int, length: int) -> None:
self.len = length
self.data = torch.randn(length, size)

def __getitem__(self, index: int) -> Tensor:
return self.data[index]

def __len__(self) -> int:
return self.len


class RandomIterableDataset(IterableDataset):
def __init__(self, size: int, count: int) -> None:
self.count = count
self.size = size

def __iter__(self) -> Iterator[Tensor]:
for _ in range(self.count):
yield torch.randn(self.size)
76 changes: 0 additions & 76 deletions tests/tests_fabric/helpers/models.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from lightning.fabric.strategies import DeepSpeedStrategy
from torch.utils.data import DataLoader

from tests_fabric.helpers.models import RandomDataset, RandomIterableDataset
from tests_fabric.helpers.datasets import RandomDataset, RandomIterableDataset
from tests_fabric.helpers.runif import RunIf
from tests_fabric.test_fabric import BoringModel

Expand Down
Loading

0 comments on commit 01f8531

Please sign in to comment.