 import torch
 import torch.utils.data
+
+from parameterized import parameterized
 from torch.testing._internal.common_utils import IS_MACOS, TEST_CUDA, TestCase
 from torchdata.stateful_dataloader import Stateful, StatefulDataLoader
@@ -1314,7 +1316,7 @@ def test(self):
             dataset=dataset,
             num_workers=num_workers,
             collate_fn=identity,
-            multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
+            multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
         )
         it = iter(dl)
         # Fetch at least one batch from each worker
@@ -1325,7 +1327,10 @@ def test(self):
         if num_workers > 0:
             for i in range(num_workers):
                 # Ensure worker state is stored only once if the dataset is also the iterator
-                self.assertEqual(state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"], None)
+                self.assertEqual(
+                    state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"],
+                    None,
+                )
                 self.assertTrue(
                     state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["fetcher_state"][
                         "dataset_iter_state"
@@ -1441,6 +1446,210 @@ def test_fast_state_dict_request_skip_steps(self) -> None:
         self._run_test(17, 19)


+class TestMultiEpochSDL_shard0(TestCase):
+    def get_map_dl(self, data_size, num_workers, batch_size, shuffle):
+        dataset = DummyMapDataset(data_size, shuffle=False)
+        return StatefulDataLoader(
+            dataset=dataset,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+        )
+
+    def _run(self, data_size, num_workers, batch_size, shuffle):
+        # Fix the seed for reproducibility of the test
+        torch.manual_seed(0)
+        dataloader1 = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+        )
+        # Run through the dataloader for 2 epochs and count the number of items yielded
+        num_items_yielded = 0
+        dataloader1_items = []
+        for _ in range(2):
+            for batch in dataloader1:
+                dataloader1_items.append(batch)
+                num_items_yielded += 1
+        # Save the state dict
+        state_dict = dataloader1.state_dict()
+        # Create a new StatefulDataLoader instance and load the state dict
+        new_dataloader1 = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+        )
+        new_dataloader1.load_state_dict(state_dict)
+        # Run through the new dataloader for another 2 epochs and count the number of items yielded
+        additional_num_items_yielded = 0
+        for _ in range(2):
+            epoch_num_items_yielded = 0
+            for batch in new_dataloader1:
+                dataloader1_items.append(batch)
+                epoch_num_items_yielded += 1
+            additional_num_items_yielded += epoch_num_items_yielded
+        # Check that the total number of items yielded is correct
+        self.assertEqual(num_items_yielded + additional_num_items_yielded, data_size * 4)
+
+        # Now run a second dataloader for 4 epochs and check that the order is the same.
+        # The seed must be fixed again to bring the generator back to the state it was in
+        # when the first dataloader was instantiated.
+        torch.manual_seed(0)
+        dataloader2 = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+        )
+        dataloader2_items = []
+        for _ in range(4):
+            for batch in dataloader2:
+                dataloader2_items.append(batch)
+
+        self.assertEqual(dataloader1_items, dataloader2_items)
+
+    @parameterized.expand(itertools.product([100], [0, 2], [1], [False, True]))
+    def test_multi_epoch_sdl(self, datasize, num_workers, batch_size, shuffle):
+        self._run(datasize, num_workers, batch_size, shuffle)
+
+
+class TestEndOfEpochBehavior_shard0(TestCase):
+    def get_map_dl(self, data_size, num_workers, batch_size, shuffle):
+        dataset = DummyMapDataset(data_size, shuffle=False)
+        return StatefulDataLoader(
+            dataset=dataset,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+        )
+
+    def _count_items_yielded(self, data_loader: StatefulDataLoader) -> int:
+        num_items_yielded = 0
+        for batch in data_loader:
+            num_items_yielded += 1
+        return num_items_yielded
+
+    def _run(self, data_size, num_workers, batch_size, shuffle):
+        dataloader = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+        )
+        # Run through the dataloader for 1 epoch and count the number of items yielded.
+        # sd_in is captured on the last iteration, before the epoch ends; sd_out is
+        # captured after the loop, once the epoch has ended.
+        num_items_yielded = 0
+
+        for batch in dataloader:
+            num_items_yielded += 1
+            sd_in = dataloader.state_dict()
+        sd_out = dataloader.state_dict()
+
+        self.assertEqual(num_items_yielded, data_size)
+
+        # Create a new StatefulDataLoader instance and load the state dict saved before the end of epoch
+        dataloader_sd_in = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+        )
+        dataloader_sd_in.load_state_dict(sd_in)
+
+        # Run through the new dataloader for 1 epoch and count the number of items yielded.
+        # num_items_yielded should be 0 since the state dict was saved before the end of epoch
+        num_items_yielded = self._count_items_yielded(dataloader_sd_in)
+        self.assertEqual(num_items_yielded, 0)
+
+        # Create a new StatefulDataLoader instance and load the state dict saved after the end of epoch
+        dataloader_sd_out = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+        )
+        dataloader_sd_out.load_state_dict(sd_out)
+
+        # Run through the new dataloader for 1 epoch and count the number of items yielded.
+        # num_items_yielded should be data_size since the state dict was saved after the end of epoch
+        num_items_yielded = self._count_items_yielded(dataloader_sd_out)
+        self.assertEqual(num_items_yielded, data_size)
+
+    @parameterized.expand(itertools.product([100], [0, 2], [1], [False, True]))
+    def test_end_of_epoch_behavior(self, datasize, num_workers, batch_size, shuffle):
+        self._run(datasize, num_workers, batch_size, shuffle)
+
+
+class TestNotStatefulSamplerSDL_shard0(TestCase):
+    def get_map_dl(self, data_size, num_workers, batch_size, sampler_cls):
+        dataset = DummyMapDataset(data_size, shuffle=False)
+        sampler = sampler_cls(dataset)
+        return StatefulDataLoader(
+            dataset=dataset,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            sampler=sampler,
+            multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+        )
+
+    def _run(self, data_size, num_workers, batch_size, interrupt, sampler_cls):
+        # Fix the seed for deterministic results
+        torch.manual_seed(0)
+        dataloader1 = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            sampler_cls=sampler_cls,
+        )
+        # Interrupt the dataloader after `interrupt` batches and save the state dict
+        results_dataloader1 = []
+        for i, batch in enumerate(dataloader1):
+            results_dataloader1.append(batch)
+            if i == interrupt:
+                break
+        state_dict = dataloader1.state_dict()
+
+        # Fix the seed again so that, before fast-forwarding, the generator is in the
+        # same state as when dataloader1 was created
+        torch.manual_seed(0)
+        resumed_dataloader1 = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            sampler_cls=sampler_cls,
+        )
+        resumed_dataloader1.load_state_dict(state_dict)
+
+        for batch in resumed_dataloader1:
+            results_dataloader1.append(batch)
+
+        # Now start a completely new dataloader and get all the batches
+        torch.manual_seed(0)
+        dataloader2 = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            sampler_cls=sampler_cls,
+        )
+        results_dataloader2 = []
+        for batch in dataloader2:
+            results_dataloader2.append(batch)
+        self.assertEqual(results_dataloader1, results_dataloader2)
+
+    @parameterized.expand(
+        itertools.product(
+            [100],
+            [0, 2],
+            [1],
+            [10, 50, 80],
+            [torch.utils.data.RandomSampler, torch.utils.data.SequentialSampler],
+        )
+    )
+    def test_notstatefulSDL(self, data_size, num_workers, batch_size, interrupt, sampler_cls):
+        self._run(data_size, num_workers, batch_size, interrupt, sampler_cls)
+
+
 class TestMultiEpochState_shard0(TestCase):
     def get_iterable_dl(self, pw, num_workers):
         data_size = [25, 50, 100, 75]
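
Note: the three new test classes all exercise the same save/resume contract of StatefulDataLoader. Below is a minimal sketch of that pattern outside the test harness. It assumes only the public state_dict()/load_state_dict() API used in the diff, and substitutes a plain list for the test-only DummyMapDataset so it stays self-contained; the expected resume point is the behavior the tests above assert.

    from torchdata.stateful_dataloader import StatefulDataLoader

    def make_loader():
        # Any map-style dataset works here; a plain list keeps the sketch self-contained.
        return StatefulDataLoader(dataset=list(range(100)), batch_size=1, num_workers=0)

    loader = make_loader()
    it = iter(loader)
    for _ in range(10):  # consume the first 10 batches
        next(it)
    state = loader.state_dict()  # snapshot taken mid-epoch

    # A freshly constructed loader resumes from where the snapshot was taken,
    # provided load_state_dict() is called before iteration starts.
    resumed = make_loader()
    resumed.load_state_dict(state)
    first = next(iter(resumed))
    assert first.item() == 10  # batch 10 is the next batch after the snapshot

The parameterized.expand(itertools.product(...)) decorators generate one test method per argument combination; for example, itertools.product([100], [0, 2], [1], [False, True]) expands to four (data_size, num_workers, batch_size, shuffle) cases.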
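
For reference, the assertions in the @@ -1325 hunk index a specific nesting inside the saved state. The sketch below reconstructs just the fields those assertions touch; it is inferred only from the keys visible in this diff, and the real _snapshot structure is internal and carries additional fields.

    # Only the keys exercised by the assertions above; everything else is omitted.
    state_dict = {
        "_snapshot": {
            "_worker_snapshots": {
                # one entry per worker: "worker_0", "worker_1", ...
                "worker_0": {
                    # None when the dataset is also its own iterator, so that
                    # its state is stored only once ...
                    "dataset_state": None,
                    # ... namely here, as part of the fetcher's iterator state
                    "fetcher_state": {
                        "dataset_iter_state": ...,
                    },
                },
            },
        },
    }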