@@ -1927,12 +1927,13 @@ def test_sampler_without_rep_state_dict(self, backend):
1927
1927
s = new_replay_buffer .sample (batch_size = 1 )
1928
1928
assert (s .exclude ("index" ) == 0 ).all ()
1929
1929
1930
- def test_sampler_without_replacement_cap_prefetch (self ):
1930
+ @pytest .mark .parametrize ("drop_last" , [False , True ])
1931
+ def test_sampler_without_replacement_cap_prefetch (self , drop_last ):
1931
1932
torch .manual_seed (0 )
1932
- data = TensorDict ({"a" : torch .arange (10 )}, batch_size = [10 ])
1933
+ data = TensorDict ({"a" : torch .arange (11 )}, batch_size = [11 ])
1933
1934
rb = ReplayBuffer (
1934
- storage = LazyTensorStorage (10 ),
1935
- sampler = SamplerWithoutReplacement (),
1935
+ storage = LazyTensorStorage (11 ),
1936
+ sampler = SamplerWithoutReplacement (drop_last = drop_last ),
1936
1937
batch_size = 2 ,
1937
1938
prefetch = 3 ,
1938
1939
)
@@ -1941,10 +1942,13 @@ def test_sampler_without_replacement_cap_prefetch(self):
1941
1942
for _ in range (100 ):
1942
1943
s = set ()
1943
1944
for i , d in enumerate (rb ):
1944
- assert i <= 4
1945
+ assert i <= ( 4 + int ( not drop_last )), i
1945
1946
s = s .union (set (d ["a" ].tolist ()))
1946
- assert i == 4
1947
- assert s == set (range (10 ))
1947
+ assert i == (4 + int (not drop_last )), i
1948
+ if drop_last :
1949
+ assert s != set (range (11 ))
1950
+ else :
1951
+ assert s == set (range (11 ))
1948
1952
1949
1953
@pytest .mark .parametrize (
1950
1954
"batch_size,num_slices,slice_len,prioritized" ,
0 commit comments