
Commit 012cf74

Vincent Moens committed
[Feature] replay_buffer_chunk
ghstack-source-id: e5f82a7
Pull Request resolved: #2388
1 parent a0c12cd commit 012cf74

File tree: 6 files changed, +141 −23 lines changed

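The commit adds a replay_buffer_chunk argument to the multiprocessed collectors and a write_count property to ReplayBuffer. Below is a minimal usage sketch modelled on the updated tests; the environment name, the sizes and the policy=None shortcut are illustrative assumptions, not part of the commit.

from torchrl.collectors import MultiSyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs import GymEnv, StepCounter

def make_env():
    # illustrative environment factory; any TorchRL env factory works here
    return GymEnv("CartPole-v1").append_transform(StepCounter())

rb = ReplayBuffer(storage=LazyTensorStorage(256), batch_size=5)
collector = MultiSyncDataCollector(
    [make_env, make_env],
    policy=None,  # falls back to a random policy over the env's action spec
    replay_buffer=rb,
    total_frames=256,
    frames_per_batch=32,
    replay_buffer_chunk=True,  # new in this commit
)
for batch in collector:
    assert batch is None  # with a replay buffer attached, iteration yields None
collector.shutdown()
print(len(rb), rb.write_count)  # write_count is also new in this commit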

test/test_collector.py

Lines changed: 71 additions & 18 deletions
@@ -69,6 +69,7 @@
 from torchrl.collectors.utils import split_trajectories
 from torchrl.data import (
     Composite,
+    LazyMemmapStorage,
     LazyTensorStorage,
     NonTensor,
     ReplayBuffer,
@@ -2806,45 +2807,87 @@ def test_collector_rb_sync(self):
         assert assert_allclose_td(rbdata0, rbdata1)

     @pytest.mark.skipif(not _has_gym, reason="requires gym.")
-    def test_collector_rb_multisync(self):
-        env = GymEnv(CARTPOLE_VERSIONED())
-        env.set_seed(0)
+    @pytest.mark.parametrize("replay_buffer_chunk", [False, True])
+    @pytest.mark.parametrize("env_creator", [False, True])
+    @pytest.mark.parametrize("storagetype", [LazyTensorStorage, LazyMemmapStorage])
+    def test_collector_rb_multisync(
+        self, replay_buffer_chunk, env_creator, storagetype, tmpdir
+    ):
+        if not env_creator:
+            env = GymEnv(CARTPOLE_VERSIONED()).append_transform(StepCounter())
+            env.set_seed(0)
+            action_spec = env.action_spec
+            env = lambda env=env: env
+        else:
+            env = EnvCreator(
+                lambda cp=CARTPOLE_VERSIONED(): GymEnv(cp).append_transform(
+                    StepCounter()
+                )
+            )
+            action_spec = env.meta_data.specs["input_spec", "full_action_spec"]

-        rb = ReplayBuffer(storage=LazyTensorStorage(256), batch_size=5)
-        rb.add(env.rand_step(env.reset()))
-        rb.empty()
+        if storagetype == LazyMemmapStorage:
+            storagetype = functools.partial(LazyMemmapStorage, scratch_dir=tmpdir)
+        rb = ReplayBuffer(storage=storagetype(256), batch_size=5)

         collector = MultiSyncDataCollector(
-            [lambda: env, lambda: env],
-            RandomPolicy(env.action_spec),
+            [env, env],
+            RandomPolicy(action_spec),
             replay_buffer=rb,
             total_frames=256,
-            frames_per_batch=16,
+            frames_per_batch=32,
+            replay_buffer_chunk=replay_buffer_chunk,
         )
         torch.manual_seed(0)
         pred_len = 0
         for c in collector:
-            pred_len += 16
+            pred_len += 32
             assert c is None
             assert len(rb) == pred_len
         collector.shutdown()
         assert len(rb) == 256
+        if not replay_buffer_chunk:
+            steps_counts = rb["step_count"].squeeze().split(16)
+            collector_ids = rb["collector", "traj_ids"].squeeze().split(16)
+            for step_count, ids in zip(steps_counts, collector_ids):
+                step_countdiff = step_count.diff()
+                idsdiff = ids.diff()
+                assert (
+                    (step_countdiff == 1) | (step_countdiff < 0)
+                ).all(), steps_counts
+                assert (idsdiff >= 0).all()

     @pytest.mark.skipif(not _has_gym, reason="requires gym.")
-    def test_collector_rb_multiasync(self):
-        env = GymEnv(CARTPOLE_VERSIONED())
-        env.set_seed(0)
+    @pytest.mark.parametrize("replay_buffer_chunk", [False, True])
+    @pytest.mark.parametrize("env_creator", [False, True])
+    @pytest.mark.parametrize("storagetype", [LazyTensorStorage, LazyMemmapStorage])
+    def test_collector_rb_multiasync(
+        self, replay_buffer_chunk, env_creator, storagetype, tmpdir
+    ):
+        if not env_creator:
+            env = GymEnv(CARTPOLE_VERSIONED()).append_transform(StepCounter())
+            env.set_seed(0)
+            action_spec = env.action_spec
+            env = lambda env=env: env
+        else:
+            env = EnvCreator(
+                lambda cp=CARTPOLE_VERSIONED(): GymEnv(cp).append_transform(
+                    StepCounter()
+                )
+            )
+            action_spec = env.meta_data.specs["input_spec", "full_action_spec"]

-        rb = ReplayBuffer(storage=LazyTensorStorage(256), batch_size=5)
-        rb.add(env.rand_step(env.reset()))
-        rb.empty()
+        if storagetype == LazyMemmapStorage:
+            storagetype = functools.partial(LazyMemmapStorage, scratch_dir=tmpdir)
+        rb = ReplayBuffer(storage=storagetype(256), batch_size=5)

         collector = MultiaSyncDataCollector(
-            [lambda: env, lambda: env],
-            RandomPolicy(env.action_spec),
+            [env, env],
+            RandomPolicy(action_spec),
             replay_buffer=rb,
             total_frames=256,
             frames_per_batch=16,
+            replay_buffer_chunk=replay_buffer_chunk,
         )
         torch.manual_seed(0)
         pred_len = 0
@@ -2854,6 +2897,16 @@ def test_collector_rb_multiasync(self):
             assert len(rb) >= pred_len
         collector.shutdown()
         assert len(rb) == 256
+        if not replay_buffer_chunk:
+            steps_counts = rb["step_count"].squeeze().split(16)
+            collector_ids = rb["collector", "traj_ids"].squeeze().split(16)
+            for step_count, ids in zip(steps_counts, collector_ids):
+                step_countdiff = step_count.diff()
+                idsdiff = ids.diff()
+                assert (
+                    (step_countdiff == 1) | (step_countdiff < 0)
+                ).all(), steps_counts
+                assert (idsdiff >= 0).all()


 if __name__ == "__main__":

torchrl/collectors/collectors.py

Lines changed: 27 additions & 2 deletions
@@ -54,6 +54,7 @@
 from torchrl.data.tensor_specs import TensorSpec
 from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING
 from torchrl.envs.common import _do_nothing, EnvBase
+from torchrl.envs.env_creator import EnvCreator
 from torchrl.envs.transforms import StepCounter, TransformedEnv
 from torchrl.envs.utils import (
     _aggregate_end_of_traj,
@@ -1470,6 +1471,7 @@ def __init__(
         set_truncated: bool = False,
         use_buffers: bool | None = None,
         replay_buffer: ReplayBuffer | None = None,
+        replay_buffer_chunk: bool = True,
     ):
         exploration_type = _convert_exploration_type(
             exploration_mode=exploration_mode, exploration_type=exploration_type
@@ -1514,6 +1516,8 @@ def __init__(

         self._use_buffers = use_buffers
         self.replay_buffer = replay_buffer
+        self._check_replay_buffer_init()
+        self.replay_buffer_chunk = replay_buffer_chunk
         if (
             replay_buffer is not None
             and hasattr(replay_buffer, "shared")
@@ -1660,6 +1664,21 @@ def _get_weight_fn(weights=policy_weights):
         )
         self.cat_results = cat_results

+    def _check_replay_buffer_init(self):
+        try:
+            if not self.replay_buffer._storage.initialized:
+                if isinstance(self.create_env_fn, EnvCreator):
+                    fake_td = self.create_env_fn.tensordict
+                else:
+                    fake_td = self.create_env_fn[0](
+                        **self.create_env_kwargs[0]
+                    ).fake_tensordict()
+                fake_td["collector", "traj_ids"] = torch.zeros((), dtype=torch.long)
+
+                self.replay_buffer._storage._init(fake_td)
+        except AttributeError:
+            pass
+
     @classmethod
     def _total_workers_from_env(cls, env_creators):
         if isinstance(env_creators, (tuple, list)):
@@ -1795,6 +1814,7 @@ def _run_processes(self) -> None:
                 "set_truncated": self.set_truncated,
                 "use_buffers": self._use_buffers,
                 "replay_buffer": self.replay_buffer,
+                "replay_buffer_chunk": self.replay_buffer_chunk,
                 "traj_pool": self._traj_pool,
             }
             proc = _ProcessNoWarn(
@@ -2804,6 +2824,7 @@ def _main_async_collector(
     set_truncated: bool = False,
     use_buffers: bool | None = None,
     replay_buffer: ReplayBuffer | None = None,
+    replay_buffer_chunk: bool = True,
     traj_pool: _TrajectoryPool = None,
 ) -> None:
     pipe_parent.close()
@@ -2825,11 +2846,11 @@ def _main_async_collector(
         env_device=env_device,
         exploration_type=exploration_type,
         reset_when_done=reset_when_done,
-        return_same_td=True,
+        return_same_td=replay_buffer is None,
         interruptor=interruptor,
         set_truncated=set_truncated,
         use_buffers=use_buffers,
-        replay_buffer=replay_buffer,
+        replay_buffer=replay_buffer if replay_buffer_chunk else None,
         traj_pool=traj_pool,
     )
     use_buffers = inner_collector._use_buffers
@@ -2895,6 +2916,10 @@ def _main_async_collector(
                 continue

             if replay_buffer is not None:
+                if not replay_buffer_chunk:
+                    next_data.names = None
+                    replay_buffer.extend(next_data)
+
                 try:
                     queue_out.put((idx, j), timeout=_TIMEOUT)
                     if verbose:
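In the worker loop above, the two write paths differ only in who calls extend: with replay_buffer_chunk=True the buffer is handed to the inner collector, which writes each contiguous rollout as it is collected, while with replay_buffer_chunk=False the worker keeps the buffer to itself, clears the batch's dimension names and extends it in one call. The sketch below imitates the two write patterns on a toy buffer; the tensordict layout is illustrative, not actual collector output.

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, ReplayBuffer

rb = ReplayBuffer(storage=LazyTensorStorage(64))
# a fake batch gathered from 2 envs x 8 steps
batch = TensorDict({"step_count": torch.arange(16).view(2, 8, 1)}, batch_size=[2, 8])

# chunked path: one extend per rollout keeps each environment's frames contiguous
for rollout in batch.unbind(0):
    rb.extend(rollout)

# non-chunked path: clear the dim names and extend the flattened batch at once,
# mirroring the next_data.names = None; replay_buffer.extend(next_data) hunk above
batch.names = None
rb.extend(batch.reshape(-1))
print(len(rb), rb.write_count)  # 32 items stored, 32 writes counted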

torchrl/data/replay_buffers/replay_buffers.py

Lines changed: 5 additions & 0 deletions
@@ -364,6 +364,11 @@ def __len__(self) -> int:
         with self._replay_lock:
             return len(self._storage)

+    @property
+    def write_count(self):
+        """The total number of items written so far in the buffer through add and extend."""
+        return self._writer._write_count
+
     def __repr__(self) -> str:
         from torchrl.envs.transforms import Compose
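The new write_count property exposes the writer's running total: unlike len(rb), which is capped at the storage capacity, it keeps counting once the circular buffer starts overwriting old entries. A small sketch with illustrative sizes:

import torch
from torchrl.data import LazyTensorStorage, ReplayBuffer

rb = ReplayBuffer(storage=LazyTensorStorage(10))
for _ in range(5):
    rb.extend(torch.arange(5))

print(len(rb))         # 10: bounded by the storage capacity
print(rb.write_count)  # 25: every add/extend is counted, overwrites included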

torchrl/data/replay_buffers/storages.py

Lines changed: 1 addition & 1 deletion
@@ -562,7 +562,7 @@ def assert_is_sharable(tensor):
                raise RuntimeError(STORAGE_ERR)

        if is_tensor_collection(storage):
-            storage.apply(assert_is_sharable)
+            storage.apply(assert_is_sharable, filter_empty=True)
        else:
            tree_map(storage, assert_is_sharable)
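For context, passing filter_empty=True lets TensorDict.apply run a validation-only visitor over the tensor leaves without keeping empty intermediate results, which is how the sharability check above uses it. A rough sketch of that call pattern, with a simplified stand-in for the check:

import torch
from tensordict import TensorDict

def assert_is_sharable(tensor):
    # simplified stand-in for the storage check: validate only, return nothing
    assert isinstance(tensor, torch.Tensor)

td = TensorDict({"a": torch.zeros(3), "nested": {"b": torch.ones(2)}}, batch_size=[])
td.apply(assert_is_sharable, filter_empty=True)  # visits every leaf without building an output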

torchrl/data/replay_buffers/writers.py

Lines changed: 36 additions & 1 deletion
@@ -163,6 +163,7 @@ def add(self, data: Any) -> int | torch.Tensor:
         self._cursor = (self._cursor + 1) % self._storage._max_size_along_dim0(
             single_data=data
         )
+        self._write_count += 1
         # Replicate index requires the shape of the storage to be known
         # Other than that, a "flat" (1d) index is ok to write the data
         self._storage.set(_cursor, data)
@@ -191,6 +192,7 @@ def extend(self, data: Sequence) -> torch.Tensor:
         )
         # we need to update the cursor first to avoid race conditions between workers
         self._cursor = (batch_size + cur_size) % max_size_along0
+        self._write_count += batch_size
         # Replicate index requires the shape of the storage to be known
         # Other than that, a "flat" (1d) index is ok to write the data
         self._storage.set(index, data)
@@ -222,6 +224,20 @@ def _cursor(self, value):
         _cursor_value = self._cursor_value = mp.Value("i", 0)
         _cursor_value.value = value

+    @property
+    def _write_count(self):
+        _write_count = self.__dict__.get("_write_count_value", None)
+        if _write_count is None:
+            _write_count = self._write_count_value = mp.Value("i", 0)
+        return _write_count.value
+
+    @_write_count.setter
+    def _write_count(self, value):
+        _write_count = self.__dict__.get("_write_count_value", None)
+        if _write_count is None:
+            _write_count = self._write_count_value = mp.Value("i", 0)
+        _write_count.value = value
+
     def __getstate__(self):
         state = super().__getstate__()
         if get_spawning_popen() is None:
@@ -249,6 +265,7 @@ def add(self, data: Any) -> int | torch.Tensor:
         # we need to update the cursor first to avoid race conditions between workers
         max_size_along_dim0 = self._storage._max_size_along_dim0(single_data=data)
         self._cursor = (index + 1) % max_size_along_dim0
+        self._write_count += 1
         if not is_tensorclass(data):
             data.set(
                 "index",
@@ -275,6 +292,7 @@ def extend(self, data: Sequence) -> torch.Tensor:
         )
         # we need to update the cursor first to avoid race conditions between workers
         self._cursor = (batch_size + cur_size) % max_size_along_dim0
+        self._write_count += batch_size
         # storage must convert the data to the appropriate format if needed
         if not is_tensorclass(data):
             data.set(
@@ -460,6 +478,20 @@ def get_insert_index(self, data: Any) -> int:

         return ret

+    @property
+    def _write_count(self):
+        _write_count = self.__dict__.get("_write_count_value", None)
+        if _write_count is None:
+            _write_count = self._write_count_value = mp.Value("i", 0)
+        return _write_count.value
+
+    @_write_count.setter
+    def _write_count(self, value):
+        _write_count = self.__dict__.get("_write_count_value", None)
+        if _write_count is None:
+            _write_count = self._write_count_value = mp.Value("i", 0)
+        _write_count.value = value
+
     def add(self, data: Any) -> int | torch.Tensor:
         """Inserts a single element of data at an appropriate index, and returns that index.
@@ -469,6 +501,7 @@ def add(self, data: Any) -> int | torch.Tensor:
         index = self.get_insert_index(data)
         if index is not None:
             data.set("index", index)
+            self._write_count += 1
         # Replicate index requires the shape of the storage to be known
         # Other than that, a "flat" (1d) index is ok to write the data
         self._storage.set(index, data)
@@ -488,6 +521,7 @@ def extend(self, data: TensorDictBase) -> None:
         for data_idx, sample in enumerate(data):
             storage_idx = self.get_insert_index(sample)
             if storage_idx is not None:
+                self._write_count += 1
                 data_to_replace[storage_idx] = data_idx

         # -1 will be interpreted as invalid by prioritized buffers
@@ -517,7 +551,8 @@ def _empty(self) -> None:
     def __getstate__(self):
         if get_spawning_popen() is not None:
             raise RuntimeError(
-                f"Writers of type {type(self)} cannot be shared between processes."
+                f"Writers of type {type(self)} cannot be shared between processes. "
+                f"Please submit an issue at https://github.com/pytorch/rl if this feature is needed."
             )
         state = super().__getstate__()
         return state
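The _write_count bookkeeping mirrors the existing _cursor machinery: the counter lives in a lazily created multiprocessing.Value so that a writer shared with collector workers keeps a single, process-visible total. A standalone sketch of that pattern; the class and attribute names are illustrative and not torchrl API:

import multiprocessing as mp

class SharedCounter:
    """Lazily created, process-shareable integer counter (illustrative)."""

    @property
    def count(self):
        value = self.__dict__.get("_count_value", None)
        if value is None:
            # "i" -> C int; the Value can be shared with child processes
            value = self._count_value = mp.Value("i", 0)
        return value.value

    @count.setter
    def count(self, new_value):
        value = self.__dict__.get("_count_value", None)
        if value is None:
            value = self._count_value = mp.Value("i", 0)
        value.value = new_value

c = SharedCounter()
c.count += 5  # read-modify-write through the property pair, as the writers do
print(c.count)  # 5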

torchrl/envs/env_creator.py

Lines changed: 1 addition & 1 deletion
@@ -109,7 +109,7 @@ def share_memory(self, state_dict: OrderedDict) -> None:
             del state_dict[key]

     @property
-    def meta_data(self):
+    def meta_data(self) -> EnvMetaData:
         if self._meta_data is None:
             raise RuntimeError(
                 "meta_data is None in EnvCreator. " "Make sure init_() has been called."

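The annotation above documents that EnvCreator.meta_data returns an EnvMetaData object; the new tests rely on it to read the action spec without instantiating a second environment in the main process. A minimal sketch, assuming gym and CartPole-v1 are available (the environment name is illustrative):

from torchrl.envs import GymEnv
from torchrl.envs.env_creator import EnvCreator

env_fn = EnvCreator(lambda: GymEnv("CartPole-v1"))  # builds the env once to capture metadata
action_spec = env_fn.meta_data.specs["input_spec", "full_action_spec"]
print(action_spec)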