Commit c4177af

Cherry-picking changes in main to the release branch (#1446)
* Fix end of epoch StatefulDataLoader restart (#1439)
  * add test for end of epoch state dict check
  * run precommit
  * update stateful_dataloader
  * run precommit
  * local changes
  * update test to test the order of batches
  * update test
  * update tests
  * revert changes in SDL
  * revert changes in SDL
  * update tests
  * run precommit
  * update sampler
  * run precommit
  * remove unnecessary comment
  * add test for state dict before and after end of epoch
  * run precommit
  * check if _sampler_iter is exhausted
  * run precommit
  * remove commented lines
  * remove default values
  * only exhaust sampler_iter if present in state dict
  * update _StatefulRandomSamplerIterator: update state dict if the iterator has finished
  * add comment about why we're updating the state dict
  * run precommit
  * update RandomSampler iterator state_dict fully
  * run precommit
  * fork torch.utils.data RandomSampler
  * reverse changes to sdl.py
  * generator to iterator
  * run precommit
  * update generator usage
  * update class name
  * run precommit
  * add a method to generate permutations
  * update return type
  * update next logic
  * add comment
  * update tests to include non-stateful samplers
  * add comments

* Using system generated seed in RandomSampler (#1441)
  * add new sampler tests
  * update seed generation in sampler
  * run precommit
  * update seed generation
  * change variable name
  * update comment
  * add seed to tests
  * run precommit
1 parent 89a1c71 · commit c4177af
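The headline change here (#1439) concerns what a StatefulDataLoader state_dict captured at the end of an epoch means when restored. Below is a minimal sketch of the intended semantics, mirroring the assertions in the new TestEndOfEpochBehavior_shard0 test further down; the TensorDataset setup is illustrative, not part of the commit:

    import torch
    from torch.utils.data import TensorDataset
    from torchdata.stateful_dataloader import StatefulDataLoader

    dataset = TensorDataset(torch.arange(100))
    dl = StatefulDataLoader(dataset, batch_size=1, shuffle=True)

    for _ in dl:
        sd_in = dl.state_dict()   # last capture lands on the final batch, still mid-epoch
    sd_out = dl.state_dict()      # captured after the iterator is exhausted

    # Restoring the mid-epoch snapshot resumes the old epoch, which has nothing left to yield.
    resumed = StatefulDataLoader(dataset, batch_size=1, shuffle=True)
    resumed.load_state_dict(sd_in)
    assert sum(1 for _ in resumed) == 0

    # Restoring the post-epoch snapshot starts a fresh epoch with all items.
    fresh = StatefulDataLoader(dataset, batch_size=1, shuffle=True)
    fresh.load_state_dict(sd_out)
    assert sum(1 for _ in fresh) == len(dataset)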

4 files changed: 412 additions, 85 deletions

test/stateful_dataloader/test_sampler.py

Lines changed: 56 additions & 8 deletions
@@ -14,7 +14,7 @@
 from torch.utils.data import Dataset
 
 from torchdata.stateful_dataloader import StatefulDataLoader
-from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
+from torchdata.stateful_dataloader.sampler import RandomSampler, StatefulDistributedSampler
 
 
 class MockDataset(Dataset):
@@ -34,7 +34,10 @@ def __getitem__(self, idx):
     "Fails with TSAN with the following error: starting new threads after multi-threaded "
     "fork is not supported. Dying (set die_after_fork=0 to override)",
 )
-@unittest.skipIf(TEST_WITH_ASAN, "DataLoader tests hang in ASAN, see: https://github.com/pytorch/pytorch/issues/66223")
+@unittest.skipIf(
+    TEST_WITH_ASAN,
+    "DataLoader tests hang in ASAN, see: https://github.com/pytorch/pytorch/issues/66223",
+)
 class TestDataLoader(TestCase):
     def setUp(self):
         super().setUp()
@@ -44,7 +47,12 @@ def setUp(self):
     def test_initialization_StatefulDistributedSampler(self):
 
         sampler = StatefulDistributedSampler(
-            self.dataset, num_replicas=10, rank=0, shuffle=False, seed=42, drop_last=False
+            self.dataset,
+            num_replicas=10,
+            rank=0,
+            shuffle=False,
+            seed=42,
+            drop_last=False,
         )
         self.assertEqual(sampler.dataset, self.dataset)
         self.assertEqual(sampler.num_replicas, 10)
@@ -139,7 +147,8 @@ def test_drop_last_effect(self):
         )
 
         self.assertTrue(
-            len(indices_with_drop) <= len(indices_without_drop), "Drop last should result in fewer or equal indices"
+            len(indices_with_drop) <= len(indices_without_drop),
+            "Drop last should result in fewer or equal indices",
         )
 
     def test_data_order_with_shuffle(self):
@@ -153,7 +162,11 @@ def test_data_order_with_shuffle(self):
         for batch in dataloader:
             data_loaded.extend(batch)
         self.assertEqual(len(data_loaded), len(self.dataset), "All data should be loaded")
-        self.assertEqual(data_loaded, data_sampled, "Data loaded by DataLoader should match data sampled by sampler")
+        self.assertEqual(
+            data_loaded,
+            data_sampled,
+            "Data loaded by DataLoader should match data sampled by sampler",
+        )
 
     def test_data_order_without_shuffle(self):
         sampler = StatefulDistributedSampler(self.dataset, num_replicas=1, rank=0, shuffle=False)
@@ -167,8 +180,16 @@ def test_data_order_without_shuffle(self):
         for batch in dataloader:
             data_loaded.extend(batch)
         self.assertEqual(len(data_loaded), len(self.dataset), "All data should be loaded")
-        self.assertEqual(data_loaded, data_sampled, "Data loaded by DataLoader should match data sampled by sampler")
-        self.assertEqual(data_loaded, list(range(100)), "Data loaded by DataLoader should be in original order")
+        self.assertEqual(
+            data_loaded,
+            data_sampled,
+            "Data loaded by DataLoader should match data sampled by sampler",
+        )
+        self.assertEqual(
+            data_loaded,
+            list(range(100)),
+            "Data loaded by DataLoader should be in original order",
+        )
 
     def test_data_distribution_across_replicas(self):
         num_replicas = 5
@@ -181,9 +202,36 @@ def test_data_distribution_across_replicas(self):
                 data_loaded.extend([int(x.item()) for x in batch])
             all_data.extend(data_loaded)
         self.assertEqual(
-            sorted(all_data), list(range(100)), "All data points should be covered exactly once across all replicas"
+            sorted(all_data),
+            list(range(100)),
+            "All data points should be covered exactly once across all replicas",
         )
 
+    def test_seed_replicability(self):
+        # Test that the same seed will result in the same data order
+        # We first pick a random number as seed, then use it to initialize two dataloaders
+        min_seed, max_seed = 0, 1000  # [min_seed, max_seed)
+        seed = torch.randint(min_seed, max_seed, (1,), dtype=torch.int64).item()
+        torch.manual_seed(seed)
+
+        dataloader1 = StatefulDataLoader(self.dataset, batch_size=1, shuffle=True)
+        results1 = list(dataloader1)
+
+        # Repeat the same process with the same seed
+        torch.manual_seed(seed)
+        dataloader2 = StatefulDataLoader(self.dataset, batch_size=1, shuffle=True)
+        results2 = list(dataloader2)
+
+        # Repeat the same process with a different seed, making sure that the seed is different
+        min_seed, max_seed = 1000, 2000  # [min_seed, max_seed)
+        seed = torch.randint(min_seed, max_seed, (1,), dtype=torch.int64).item()
+        torch.manual_seed(seed)
+        dataloader3 = StatefulDataLoader(self.dataset, batch_size=1, shuffle=True)
+        results3 = list(dataloader3)
+
+        self.assertEqual(results1, results2, "Data should be replicable with same seed")
+        self.assertNotEqual(results1, results3, "Data should not be replicable with different seed")
+
 
 if __name__ == "__main__":
     run_tests()
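A note on the seed handling that test_seed_replicability exercises (#1441): when no generator is supplied, the shuffle seed is generated from torch's global RNG state, so fixing the global seed before constructing the loader reproduces the epoch order. A small sketch of that contract, with an illustrative dataset; this is inferred from the test's behavior, not a statement about the sampler's internals:

    import torch
    from torch.utils.data import TensorDataset
    from torchdata.stateful_dataloader import StatefulDataLoader

    dataset = TensorDataset(torch.arange(100))

    def epoch_order(seed):
        # Fix the global RNG before constructing the loader; the shuffle
        # seed is drawn from this state when no generator is passed.
        torch.manual_seed(seed)
        dl = StatefulDataLoader(dataset, batch_size=1, shuffle=True)
        return [int(batch[0]) for batch in dl]

    assert epoch_order(42) == epoch_order(42)    # same seed, same order
    assert epoch_order(42) != epoch_order(1337)  # different seed, almost surely different order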

test/stateful_dataloader/test_state_dict.py

Lines changed: 211 additions & 2 deletions
@@ -15,6 +15,8 @@
 
 import torch
 import torch.utils.data
+
+from parameterized import parameterized
 from torch.testing._internal.common_utils import IS_MACOS, TEST_CUDA, TestCase
 from torchdata.stateful_dataloader import Stateful, StatefulDataLoader
 
@@ -1314,7 +1316,7 @@ def test(self):
             dataset=dataset,
             num_workers=num_workers,
             collate_fn=identity,
-            multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
+            multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
         )
         it = iter(dl)
         # Fetch at least one batch from each worker
@@ -1325,7 +1327,10 @@ def test(self):
         if num_workers > 0:
             for i in range(num_workers):
                 # Ensure worker state is stored only once if the dataset is also the iterator
-                self.assertEqual(state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"], None)
+                self.assertEqual(
+                    state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"],
+                    None,
+                )
                 self.assertTrue(
                     state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["fetcher_state"][
                         "dataset_iter_state"
@@ -1441,6 +1446,210 @@ def test_fast_state_dict_request_skip_steps(self) -> None:
         self._run_test(17, 19)
 
 
+class TestMultiEpochSDL_shard0(TestCase):
+    def get_map_dl(self, data_size, num_workers, batch_size, shuffle):
+        dataset = DummyMapDataset(data_size, shuffle=False)
+        return StatefulDataLoader(
+            dataset=dataset,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+        )
+
+    def _run(self, data_size, num_workers, batch_size, shuffle):
+        # For reproducibility of testing, fix the seed
+        torch.manual_seed(0)
+        dataloader1 = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+        )
+        # Run through the dataloader for 2 epochs and count the number of items yielded
+        num_items_yielded = 0
+        dataloader1_items = []
+        for _ in range(2):
+            for batch in dataloader1:
+                dataloader1_items.append(batch)
+                num_items_yielded += 1
+        # Save the state dict
+        state_dict = dataloader1.state_dict()
+        # Create a new StatefulDataLoader instance and load the state dict
+        new_dataloader1 = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+        )
+        new_dataloader1.load_state_dict(state_dict)
+        # Run through the new dataloader for another 2 epochs and count the number of items yielded
+        additional_num_items_yielded = 0
+        for i in range(2):
+            epoch_num_items_yielded = 0
+            for batch in new_dataloader1:
+                dataloader1_items.append(batch)
+                epoch_num_items_yielded += 1
+            additional_num_items_yielded += epoch_num_items_yielded
+        # Check that the total number of items yielded is correct
+        self.assertEqual(num_items_yielded + additional_num_items_yielded, data_size * 4)
+
+        # Now run a second dataloader for 4 epochs and check that the order is the same.
+        # We need to fix the seed again to restore the same initial conditions as when the first dataloader was instantiated.
+        torch.manual_seed(0)
+        dataloader2 = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+        )
+        dataloader2_items = []
+        for _ in range(4):
+            for batch in dataloader2:
+                dataloader2_items.append(batch)
+
+        self.assertEqual(dataloader1_items, dataloader2_items)
+
+    @parameterized.expand(itertools.product([100], [0, 2], [1], [False, True]))
+    def test_multi_epoch_sdl(self, datasize, num_workers, batch_size, shuffle):
+        self._run(datasize, num_workers, batch_size, shuffle)
+
+
+class TestEndOfEpochBehavior_shard0(TestCase):
+    def get_map_dl(self, data_size, num_workers, batch_size, shuffle):
+        dataset = DummyMapDataset(data_size, shuffle=False)
+        return StatefulDataLoader(
+            dataset=dataset,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+        )
+
+    def _count_items_yielded(self, data_loader: StatefulDataLoader) -> int:
+        num_items_yielded = 0
+        for batch in data_loader:
+            num_items_yielded += 1
+        return num_items_yielded
+
+    def _run(self, data_size, num_workers, batch_size, shuffle):
+        dataloader = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+        )
+        # Run through the dataloader for 1 epoch and count the number of items yielded
+        num_items_yielded = 0
+
+        for batch in dataloader:
+            num_items_yielded += 1
+            sd_in = dataloader.state_dict()
+        sd_out = dataloader.state_dict()
+
+        self.assertEqual(num_items_yielded, data_size)
+
+        # Create a new StatefulDataLoader instance and load the state dict saved before the end of epoch
+        dataloader_sd_in = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+        )
+        dataloader_sd_in.load_state_dict(sd_in)
+
+        # Run through the new dataloader for 1 epoch and count the number of items yielded
+        # num_items_yielded should be 0 since the state dict was saved before the end of epoch
+        num_items_yielded = self._count_items_yielded(dataloader_sd_in)
+        self.assertEqual(num_items_yielded, 0)
+
+        # Create a new StatefulDataLoader instance and load the state dict saved after the end of epoch
+        dataloader_sd_out = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+        )
+        dataloader_sd_out.load_state_dict(sd_out)
+
+        # Run through the new dataloader for 1 epoch and count the number of items yielded
+        # num_items_yielded should be data_size since the state dict was saved after the end of epoch
+        num_items_yielded = self._count_items_yielded(dataloader_sd_out)
+        self.assertEqual(num_items_yielded, data_size)
+
+    @parameterized.expand(itertools.product([100], [0, 2], [1], [False, True]))
+    def test_end_of_epoch_behavior(self, datasize, num_workers, batch_size, shuffle):
+        self._run(datasize, num_workers, batch_size, shuffle)
+
+
+class TestNotStatefulSamplerSDL_shard0(TestCase):
+    def get_map_dl(self, data_size, num_workers, batch_size, sampler_cls):
+        dataset = DummyMapDataset(data_size, shuffle=False)
+        sampler = sampler_cls(dataset)
+        return StatefulDataLoader(
+            dataset=dataset,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            sampler=sampler,
+            multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+        )
+
+    def _run(self, data_size, num_workers, batch_size, interrupt, sampler_cls):
+        torch.manual_seed(0)  # Fixing seed for deterministic results
+        dataloader1 = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            sampler_cls=sampler_cls,
+        )
+        # Interrupt the dataloader after `interrupt` batches and save the state dict
+        results_dataloader1 = []
+        for i, batch in enumerate(dataloader1):
+            results_dataloader1.append(batch)
+            if i == interrupt:
+                break
+        state_dict = dataloader1.state_dict()
+
+        # We need to fix the seed again so that, before fast-forwarding, the generator is in the same state as before
+        torch.manual_seed(0)
+        resumed_dataloader1 = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            sampler_cls=sampler_cls,
+        )
+        resumed_dataloader1.load_state_dict(state_dict)
+
+        for batch in resumed_dataloader1:
+            results_dataloader1.append(batch)
+
+        # Now start a completely new dataloader and get all the batches
+        torch.manual_seed(0)
+        dataloader2 = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            sampler_cls=sampler_cls,
+        )
+        results_dataloader2 = []
+        for batch in dataloader2:
+            results_dataloader2.append(batch)
+        self.assertEqual(results_dataloader1, results_dataloader2)
+
+    @parameterized.expand(
+        itertools.product(
+            [100],
+            [0, 2],
+            [1],
+            [10, 50, 80],
+            [torch.utils.data.RandomSampler, torch.utils.data.SequentialSampler],
+        )
+    )
+    def test_notstatefulSDL(self, data_size, num_workers, batch_size, interrupt, sampler_cls):
+        self._run(data_size, num_workers, batch_size, interrupt, sampler_cls)
+
+
 class TestMultiEpochState_shard0(TestCase):
     def get_iterable_dl(self, pw, num_workers):
         data_size = [25, 50, 100, 75]