Skip to content

Commit 1647fa4

Browse files
author
Vincent Moens
authored
[BugFix] Fix flaky rb tests (#1901)
1 parent 1bd5ec6 commit 1647fa4

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

test/test_rb.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -672,6 +672,8 @@ def test_storage_state_dict(self, storage_in, storage_out, init_out, backend):
672672
def test_storage_dumps_loads(
673673
self, device_data, storage_type, data_type, isinit, tmpdir
674674
):
675+
torch.manual_seed(0)
676+
675677
dir_rb = tmpdir / "rb"
676678
dir_save = tmpdir / "save"
677679
dir_rb.mkdir()
@@ -716,25 +718,30 @@ class TC:
716718
)
717719
else:
718720
raise NotImplementedError
721+
719722
if storage_type in (LazyMemmapStorage,):
720723
storage = storage_type(max_size=10, scratch_dir=dir_rb)
721724
else:
722725
storage = storage_type(max_size=10)
726+
723727
# We cast the device to CPU as CUDA isn't automatically cast to CPU when using range() index
724728
if data_type == "pytree":
725729
storage.set(range(3), tree_map(lambda x: x.cpu(), data))
726730
else:
727731
storage.set(range(3), data.cpu())
732+
728733
storage.dumps(dir_save)
729734
# check we can dump twice
730735
storage.dumps(dir_save)
731736

732737
storage_recover = storage_type(max_size=10)
733738
if isinit:
734739
if data_type == "pytree":
735-
storage_recover.set(range(3), tree_map(lambda x: x.cpu().zero_(), data))
740+
storage_recover.set(
741+
range(3), tree_map(lambda x: x.cpu().clone().zero_(), data)
742+
)
736743
else:
737-
storage_recover.set(range(3), data.cpu().zero_())
744+
storage_recover.set(range(3), data.cpu().clone().zero_())
738745

739746
if data_type in ("tensor", "pytree") and not isinit:
740747
with pytest.raises(

0 commit comments

Comments
 (0)