
Commit f15fd3a

Fix end of epoch StatefulDataLoader restart (#1439)
* add test for end of epoch state dict check
* run precommit
  update stateful_dataloader
  run precommit
  local changes
  update test to test the order of batches
  update test
  update tests
  revert changes in SDL
  revert changes in SDL
  update tests
  run precommit
* update sampler
* run precommit
* remove unnecessary comment
* add test for statedict before and after endofepoch
* run precommit
* check if _sampler_iter is exhausted
* run precommit
* remove commented lines
* remove default values
* only exhaust sampler_iter if present in sd
* update _StatefulRandomSamplerIterator
  update state dict if the iterator has finished
  add comment about why we're updating state dict
  run precommit
* update randomsampleriter state_dict fully
* run precommit
* fork torch.utils.data RandomSampler
  reverse changes to sdl.py
  generator to iterator
  run precommit
  update generator usage
* update class name
* run precommit
* add a method to generate permutations
* update return type
* update next logic
* add comment
* update tests to include non stateful samplers
* add comments
1 parent fe6b405 commit f15fd3a
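
The change pins down what a StatefulDataLoader should do when its state dict is captured at an epoch boundary: a snapshot taken just before the sampler is exhausted resumes with nothing left to yield in that epoch, while a snapshot taken after exhaustion starts a fresh, full epoch. Below is a minimal sketch of that contract, using the state_dict()/load_state_dict() API exercised in the diff; the toy list dataset is an assumption for illustration and is not part of the commit.

# Sketch of the end-of-epoch resume contract this commit fixes.
# StatefulDataLoader, state_dict() and load_state_dict() are the real
# torchdata APIs (used in the diff below); the list dataset is a stand-in.
from torchdata.stateful_dataloader import StatefulDataLoader

dataset = list(range(100))  # any map-style dataset works here
loader = StatefulDataLoader(dataset, batch_size=1, shuffle=True)
for _ in loader:
    pass  # drain exactly one epoch
sd_after_epoch = loader.state_dict()  # captured *after* the epoch ended

resumed = StatefulDataLoader(dataset, batch_size=1, shuffle=True)
resumed.load_state_dict(sd_after_epoch)
# A loader restored from an end-of-epoch snapshot should begin a fresh,
# full epoch rather than a partial or empty one.
assert sum(1 for _ in resumed) == len(dataset)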

File tree

3 files changed (+349 additions, -75 deletions)


test/stateful_dataloader/test_state_dict.py

Lines changed: 207 additions & 2 deletions
@@ -15,6 +15,8 @@

 import torch
 import torch.utils.data
+
+from parameterized import parameterized
 from torch.testing._internal.common_utils import IS_MACOS, TEST_CUDA, TestCase
 from torchdata.stateful_dataloader import Stateful, StatefulDataLoader

@@ -1314,7 +1316,7 @@ def test(self):
             dataset=dataset,
             num_workers=num_workers,
             collate_fn=identity,
-            multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
+            multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
         )
         it = iter(dl)
         # Fetch at least one batch from each worker
@@ -1325,7 +1327,10 @@ def test(self):
         if num_workers > 0:
             for i in range(num_workers):
                 # Ensure worker state is stored only once if the dataset is also the iterator
-                self.assertEqual(state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"], None)
+                self.assertEqual(
+                    state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"],
+                    None,
+                )
                 self.assertTrue(
                     state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["fetcher_state"][
                         "dataset_iter_state"
@@ -1441,6 +1446,206 @@ def test_fast_state_dict_request_skip_steps(self) -> None:
         self._run_test(17, 19)


+class TestMultiEpochSDL_shard0(TestCase):
+    def get_map_dl(self, data_size, num_workers, batch_size, shuffle):
+        dataset = DummyMapDataset(data_size, shuffle=False)
+        return StatefulDataLoader(
+            dataset=dataset,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+        )
+
+    def _run(self, data_size, num_workers, batch_size, shuffle):
+        dataloader1 = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+        )
+        # Run through the dataloader for 2 epochs and count the number of items yielded
+        num_items_yielded = 0
+        dataloader1_items = []
+        for _ in range(2):
+            for batch in dataloader1:
+                dataloader1_items.append(batch)
+                num_items_yielded += 1
+        # Save the state dict
+        state_dict = dataloader1.state_dict()
+        # Create a new StatefulDataLoader instance and load the state dict
+        new_dataloader1 = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+        )
+        new_dataloader1.load_state_dict(state_dict)
+        # Run through the new dataloader for another 2 epochs and count the number of items yielded
+        additional_num_items_yielded = 0
+        for i in range(2):
+            epoch_num_items_yielded = 0
+            for batch in new_dataloader1:
+                dataloader1_items.append(batch)
+                epoch_num_items_yielded += 1
+            additional_num_items_yielded += epoch_num_items_yielded
+        # Check that the total number of items yielded is correct
+        self.assertEqual(num_items_yielded + additional_num_items_yielded, data_size * 4)
+
+        # Now run a second dataloader for 4 epochs and check that the order is the same
+        dataloader2 = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+        )
+        dataloader2_items = []
+        for _ in range(4):
+            for batch in dataloader2:
+                dataloader2_items.append(batch)
+
+        self.assertEqual(dataloader1_items, dataloader2_items)
+
+    @parameterized.expand(itertools.product([100], [0, 2], [1], [False, True]))
+    def test_multi_epoch_sdl(self, datasize, num_workers, batch_size, shuffle):
+        self._run(datasize, num_workers, batch_size, shuffle)
+
+
+class TestEndOfEpochBehavior_shard0(TestCase):
+    def get_map_dl(self, data_size, num_workers, batch_size, shuffle):
+        dataset = DummyMapDataset(data_size, shuffle=False)
+        return StatefulDataLoader(
+            dataset=dataset,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+        )
+
+    def _count_items_yielded(self, data_loader: StatefulDataLoader) -> int:
+        num_items_yielded = 0
+        for batch in data_loader:
+            num_items_yielded += 1
+        return num_items_yielded
+
+    def _run(self, data_size, num_workers, batch_size, shuffle):
+        dataloader = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+        )
+        # Run through the dataloader for 1 epoch and count the number of items yielded
+        num_items_yielded = 0
+
+        for batch in dataloader:
+            num_items_yielded += 1
+            sd_in = dataloader.state_dict()
+        sd_out = dataloader.state_dict()
+
+        self.assertEqual(num_items_yielded, data_size)
+
+        # Create a new StatefulDataLoader instance and load the state dict saved before the end of epoch
+        dataloader_sd_in = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+        )
+        dataloader_sd_in.load_state_dict(sd_in)
+
+        # Run through the new dataloader for 1 epoch and count the number of items yielded
+        # num_items_yielded should be 0 since the state dict was saved before the end of epoch
+        num_items_yielded = self._count_items_yielded(dataloader_sd_in)
+        self.assertEqual(num_items_yielded, 0)
+
+        # Create a new StatefulDataLoader instance and load the state dict saved after the end of epoch
+        dataloader_sd_out = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+        )
+        dataloader_sd_out.load_state_dict(sd_out)
+
+        # Run through the new dataloader for 1 epoch and count the number of items yielded
+        # num_items_yielded should be data_size since the state dict was saved after the end of epoch
+        num_items_yielded = self._count_items_yielded(dataloader_sd_out)
+        self.assertEqual(num_items_yielded, data_size)
+
+    @parameterized.expand(itertools.product([100], [0, 2], [1], [False, True]))
+    def test_end_of_epoch_behavior(self, datasize, num_workers, batch_size, shuffle):
+        self._run(datasize, num_workers, batch_size, shuffle)
+
+
+class TestNotStatefulSamplerSDL_shard0(TestCase):
+    def get_map_dl(self, data_size, num_workers, batch_size, sampler_cls):
+        dataset = DummyMapDataset(data_size, shuffle=False)
+        sampler = sampler_cls(dataset)
+        return StatefulDataLoader(
+            dataset=dataset,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            sampler=sampler,
+            multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+        )
+
+    def _run(self, data_size, num_workers, batch_size, interrupt, sampler_cls):
+        torch.manual_seed(0)  # Fix the seed for deterministic results
+        dataloader1 = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            sampler_cls=sampler_cls,
+        )
+        # Interrupt the dataloader after `interrupt` batches and save the state dict
+        results_dataloader1 = []
+        for i, batch in enumerate(dataloader1):
+            results_dataloader1.append(batch)
+            if i == interrupt:
+                break
+        state_dict = dataloader1.state_dict()
+
+        # Fix the seed again so that, before fast-forwarding, the generator is in the same state as before
+        torch.manual_seed(0)
+        resumed_dataloader1 = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            sampler_cls=sampler_cls,
+        )
+        resumed_dataloader1.load_state_dict(state_dict)
+
+        for batch in resumed_dataloader1:
+            results_dataloader1.append(batch)
+
+        # Now start a completely new dataloader and get all the batches
+        torch.manual_seed(0)
+        dataloader2 = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            sampler_cls=sampler_cls,
+        )
+        results_dataloader2 = []
+        for batch in dataloader2:
+            results_dataloader2.append(batch)
+        self.assertEqual(results_dataloader1, results_dataloader2)
+
+    @parameterized.expand(
+        itertools.product(
+            [100],
+            [0, 2],
+            [1],
+            [10, 50, 80],
+            [torch.utils.data.RandomSampler, torch.utils.data.SequentialSampler],
+        )
+    )
+    def test_notstatefulSDL(self, data_size, num_workers, batch_size, interrupt, sampler_cls):
+        self._run(100, 0, 1, interrupt, sampler_cls)
+
+
 class TestMultiEpochState_shard0(TestCase):
     def get_iterable_dl(self, pw, num_workers):
         data_size = [25, 50, 100, 75]

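For context on the sampler bullets in the commit message ("fork torch.utils.data RandomSampler", "add a method to generate permutations"): making random sampling resumable means the sampler's iterator must expose how far into its permutation it is, so an end-of-epoch snapshot can be told apart from a mid-epoch one. The sketch below is an illustrative reduction under that assumption, not the actual torchdata implementation; the class and method names are hypothetical.

# Illustrative only -- not torchdata's _StatefulRandomSamplerIterator.
# The idea: derive the permutation from a recorded seed and track how many
# indices have already been yielded, so a restored iterator can fast-forward.
import torch

class ToyResumableRandomSamplerIter:
    def __init__(self, n: int, seed: int = 0, yielded: int = 0):
        self.n = n
        self.seed = seed
        self.yielded = yielded  # fast-forward position on restore
        g = torch.Generator()
        g.manual_seed(seed)
        self.perm = torch.randperm(n, generator=g).tolist()

    def __iter__(self):
        return self

    def __next__(self) -> int:
        if self.yielded >= self.n:
            raise StopIteration  # epoch boundary; the caller rolls over to a new epoch
        idx = self.perm[self.yielded]
        self.yielded += 1
        return idx

    def state_dict(self) -> dict:
        return {"seed": self.seed, "yielded": self.yielded}

    @classmethod
    def from_state_dict(cls, n: int, sd: dict) -> "ToyResumableRandomSamplerIter":
        return cls(n, seed=sd["seed"], yielded=sd["yielded"])

Restoring with yielded equal to n reproduces the end-of-epoch case the new tests pin down: the restored iterator raises StopIteration immediately, and the dataloader can then roll over to a fresh permutation for the next epoch.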