@@ -16,5 +16,7 @@
 import torch
 import torch.utils.data
+
+from parameterized import parameterized
 from torch.testing._internal.common_utils import IS_MACOS, TEST_CUDA, TestCase
 from torchdata.stateful_dataloader import Stateful, StatefulDataLoader
 
@@ -1314,7 +1316,7 @@ def test(self):
             dataset=dataset,
             num_workers=num_workers,
             collate_fn=identity,
-            multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
+            multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
         )
         it = iter(dl)
         # Fetch at least one batch from each worker
@@ -1325,7 +1327,10 @@ def test(self):
         if num_workers > 0:
             for i in range(num_workers):
                 # Ensure worker state is stored only once if the dataset is also the iterator
-                self.assertEqual(state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"], None)
+                self.assertEqual(
+                    state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"],
+                    None,
+                )
                 self.assertTrue(
                     state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["fetcher_state"][
                         "dataset_iter_state"
@@ -1441,6 +1446,206 @@ def test_fast_state_dict_request_skip_steps(self) -> None:
         self._run_test(17, 19)
 
 
+class TestMultiEpochSDL_shard0(TestCase):
+    def get_map_dl(self, data_size, num_workers, batch_size, shuffle):
+        dataset = DummyMapDataset(data_size, shuffle=False)
+        return StatefulDataLoader(
+            dataset=dataset,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+        )
+
+    def _run(self, data_size, num_workers, batch_size, shuffle):
+        dataloader1 = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+        )
+        # Run through the dataloader for 2 epochs and count the number of items yielded
+        num_items_yielded = 0
+        dataloader1_items = []
+        for _ in range(2):
+            for batch in dataloader1:
+                dataloader1_items.append(batch)
+                num_items_yielded += 1
+        # Save the state dict
+        state_dict = dataloader1.state_dict()
+        # Create a new StatefulDataLoader instance and load the state dict
+        new_dataloader1 = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+        )
+        new_dataloader1.load_state_dict(state_dict)
+        # Run through the new dataloader for another 2 epochs and count the number of items yielded
+        additional_num_items_yielded = 0
+        for _ in range(2):
+            epoch_num_items_yielded = 0
+            for batch in new_dataloader1:
+                dataloader1_items.append(batch)
+                epoch_num_items_yielded += 1
+            additional_num_items_yielded += epoch_num_items_yielded
+        # Check that the total number of items yielded is correct
+        self.assertEqual(num_items_yielded + additional_num_items_yielded, data_size * 4)
+
+        # Now run a second dataloader for 4 epochs and check that the order is the same
+        dataloader2 = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+        )
+        dataloader2_items = []
+        for _ in range(4):
+            for batch in dataloader2:
+                dataloader2_items.append(batch)
+
+        self.assertEqual(dataloader1_items, dataloader2_items)
+
+    @parameterized.expand(itertools.product([100], [0, 2], [1], [False, True]))
+    def test_multi_epoch_sdl(self, datasize, num_workers, batch_size, shuffle):
+        self._run(datasize, num_workers, batch_size, shuffle)
+
+
+class TestEndOfEpochBehavior_shard0(TestCase):
+    def get_map_dl(self, data_size, num_workers, batch_size, shuffle):
+        dataset = DummyMapDataset(data_size, shuffle=False)
+        return StatefulDataLoader(
+            dataset=dataset,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+        )
+
+    def _count_items_yielded(self, data_loader: StatefulDataLoader) -> int:
+        num_items_yielded = 0
+        for _ in data_loader:
+            num_items_yielded += 1
+        return num_items_yielded
+
+    def _run(self, data_size, num_workers, batch_size, shuffle):
+        dataloader = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+        )
+        # Run through the dataloader for 1 epoch and count the number of items yielded
+        num_items_yielded = 0
+
+        for batch in dataloader:
+            num_items_yielded += 1
+            sd_in = dataloader.state_dict()  # overwritten each step: state just before the epoch ends
+        sd_out = dataloader.state_dict()  # state after the epoch has ended
+
+        self.assertEqual(num_items_yielded, data_size)
+
+        # Create a new StatefulDataLoader instance and load the state dict saved before the end of epoch
+        dataloader_sd_in = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+        )
+        dataloader_sd_in.load_state_dict(sd_in)
+
+        # Run through the new dataloader for 1 epoch and count the number of items yielded
+        # num_items_yielded should be 0 since the state dict was saved before the end of epoch
+        num_items_yielded = self._count_items_yielded(dataloader_sd_in)
+        self.assertEqual(num_items_yielded, 0)
+
+        # Create a new StatefulDataLoader instance and load the state dict saved after the end of epoch
+        dataloader_sd_out = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+        )
+        dataloader_sd_out.load_state_dict(sd_out)
+
+        # Run through the new dataloader for 1 epoch and count the number of items yielded
+        # num_items_yielded should be data_size since the state dict was saved after the end of epoch
+        num_items_yielded = self._count_items_yielded(dataloader_sd_out)
+        self.assertEqual(num_items_yielded, data_size)
+
+    @parameterized.expand(itertools.product([100], [0, 2], [1], [False, True]))
+    def test_end_of_epoch_behavior(self, datasize, num_workers, batch_size, shuffle):
+        self._run(datasize, num_workers, batch_size, shuffle)
+
+
+class TestNotStatefulSamplerSDL_shard0(TestCase):
+    def get_map_dl(self, data_size, num_workers, batch_size, sampler_cls):
+        dataset = DummyMapDataset(data_size, shuffle=False)
+        sampler = sampler_cls(dataset)
+        return StatefulDataLoader(
+            dataset=dataset,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            sampler=sampler,
+            multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+        )
+
+    def _run(self, data_size, num_workers, batch_size, interrupt, sampler_cls):
+        torch.manual_seed(0)  # Fix the seed for deterministic results
+        dataloader1 = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            sampler_cls=sampler_cls,
+        )
+        # Interrupt the dataloader after `interrupt` batches and save the state dict
+        results_dataloader1 = []
+        for i, batch in enumerate(dataloader1):
+            results_dataloader1.append(batch)
+            if i == interrupt:
+                break
+        state_dict = dataloader1.state_dict()
+
+        torch.manual_seed(
+            0
+        )  # Fix the seed again so the generator is in the same state as before the fast-forward
+        resumed_dataloader1 = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            sampler_cls=sampler_cls,
+        )
+        resumed_dataloader1.load_state_dict(state_dict)
+
+        for batch in resumed_dataloader1:
+            results_dataloader1.append(batch)
+
+        # Now start a completely new dataloader and collect all the batches
+        torch.manual_seed(0)
+        dataloader2 = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            sampler_cls=sampler_cls,
+        )
+        results_dataloader2 = []
+        for batch in dataloader2:
+            results_dataloader2.append(batch)
+        self.assertEqual(results_dataloader1, results_dataloader2)
+
+    @parameterized.expand(
+        itertools.product(
+            [100],
+            [0, 2],
+            [1],
+            [10, 50, 80],
+            [torch.utils.data.RandomSampler, torch.utils.data.SequentialSampler],
+        )
+    )
+    def test_notstatefulSDL(self, data_size, num_workers, batch_size, interrupt, sampler_cls):
+        self._run(data_size, num_workers, batch_size, interrupt, sampler_cls)
+
+
 class TestMultiEpochState_shard0(TestCase):
     def get_iterable_dl(self, pw, num_workers):
         data_size = [25, 50, 100, 75]
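For readers skimming the diff: the checkpoint/resume contract these tests pin down looks roughly like this in user code. This is a minimal sketch, not part of the PR; the list dataset and the batch counts are illustrative assumptions standing in for the suite's `DummyMapDataset`.

```python
from torchdata.stateful_dataloader import StatefulDataLoader

# A plain map-style dataset (hypothetical stand-in for DummyMapDataset).
dataset = list(range(100))

dl = StatefulDataLoader(dataset, batch_size=1, num_workers=0)
it = iter(dl)
for _ in range(10):
    next(it)  # consume the first 10 batches

ckpt = dl.state_dict()  # mid-epoch snapshot, as in the interrupt test above

# Rebuild an identically configured loader (possibly in a fresh process),
# restore the snapshot, and iteration resumes after batch 10.
dl_resumed = StatefulDataLoader(dataset, batch_size=1, num_workers=0)
dl_resumed.load_state_dict(ckpt)
assert len(list(dl_resumed)) == 90  # only the remaining batches are yielded
```

A snapshot taken mid-epoch resumes the same epoch; the end-of-epoch tests additionally check that a snapshot taken after iteration is exhausted starts a fresh epoch of `data_size` items.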