
[Refactor] Refactor multi-sync collectors #1505


Draft · wants to merge 4 commits into base: main
34 changes: 26 additions & 8 deletions test/test_collector.py
@@ -545,6 +545,7 @@ def env_fn():
frames_per_batch=frames_per_batch,
max_frames_per_traj=1000,
total_frames=frames_per_batch * 100,
cat_dim=-1,
)
ccollector.set_seed(seed)
for i, b in enumerate(ccollector):
@@ -800,7 +801,10 @@ def test_collector_vecnorm_envcreator(static_seed):
policy = RandomPolicy(env_make.action_spec)
num_data_collectors = 2
c = MultiSyncDataCollector(
[env_make] * num_data_collectors, policy=policy, total_frames=int(1e6)
[env_make] * num_data_collectors,
policy=policy,
total_frames=int(1e6),
cat_dim=-1,
)

init_seed = 0
@@ -856,11 +860,13 @@ def create_env():
collector_class = (
MultiSyncDataCollector if not use_async else MultiaSyncDataCollector
)
kwargs = {"cat_dim": -1} if not use_async else {}
collector = collector_class(
[create_env] * 3,
policy=policy,
devices=[torch.device("cuda:0")] * 3,
storing_devices=[torch.device("cuda:0")] * 3,
**kwargs,
)
# collect state_dict
state_dict = collector.state_dict()
@@ -933,6 +939,8 @@ def make_env():
collector_kwargs["create_env_fn"] = [
collector_kwargs["create_env_fn"] for _ in range(3)
]
if collector_class is MultiSyncDataCollector:
collector_kwargs["cat_dim"] = -1

collector = collector_class(**collector_kwargs)
collector._exclude_private_keys = exclude
@@ -1016,6 +1024,8 @@ def test_collector_output_keys(
collector_kwargs["create_env_fn"] = [
collector_kwargs["create_env_fn"] for _ in range(num_envs)
]
if collector_class is MultiSyncDataCollector:
collector_kwargs["cat_dim"] = -1

collector = collector_class(**collector_kwargs)

@@ -1093,6 +1103,7 @@ def env_fn(seed):
storing_devices=[
storing_device,
],
cat_dim=-1,
)
batch = next(collector.iterator())
assert batch.device == torch.device(storing_device)
@@ -1151,6 +1162,8 @@ def _create_collector_kwargs(self, env_maker, collector_class, policy):
collector_kwargs["create_env_fn"] = [
collector_kwargs["create_env_fn"] for _ in range(self.num_envs)
]
if collector_class is MultiSyncDataCollector:
collector_kwargs["cat_dim"] = -1

return collector_kwargs

@@ -1324,12 +1337,14 @@ def env_fn(seed):
storing_devices="cpu",
split_trajs=False,
preemptive_threshold=0.0, # stop after one iteration
cat_dim=-1,
)

for batch in collector:
trajectory_ids = batch["collector"]["traj_ids"]
trajectory_ids_mask = trajectory_ids != -1 # valid frames mask
assert trajectory_ids[trajectory_ids_mask].numel() < frames_per_batch
assert trajectory_ids_mask.all()
assert trajectory_ids.numel() < frames_per_batch


def test_maxframes_error():
@@ -1398,6 +1413,7 @@ def test_multi_collector_nested_env_consistency(self, seed=1):
frames_per_batch=20,
total_frames=100,
device="cpu",
cat_dim=-1,
)
for i, d in enumerate(ccollector):
if i == 0:
@@ -1411,8 +1427,8 @@ def test_multi_collector_nested_env_consistency(self, seed=1):
assert_allclose_td(d1, d2)
ccollector.shutdown()

assert_allclose_td(c1, d1)
assert_allclose_td(c2, d2)
assert_allclose_td(c1, d1.select(*c1.keys(True, True)))
assert_allclose_td(c2, d2.select(*c1.keys(True, True)))

@pytest.mark.parametrize("nested_obs_action", [True, False])
@pytest.mark.parametrize("nested_done", [True, False])
@@ -1544,6 +1560,7 @@ def test_multi_collector_het_env_consistency(
frames_per_batch=frames_per_batch,
total_frames=100,
device="cpu",
cat_dim=-1,
)
for i, d in enumerate(ccollector):
if i == 0:
@@ -1557,8 +1574,8 @@ def test_multi_collector_het_env_consistency(
assert_allclose_td(d1, d2)
ccollector.shutdown()

assert_allclose_td(c1, d1)
assert_allclose_td(c2, d2)
assert_allclose_td(c1, d1.select(*c1.keys(True, True)))
assert_allclose_td(c2, d2.select(*c1.keys(True, True)))


class TestMultiKeyEnvsCollector:
@@ -1619,6 +1636,7 @@ def test_multi_collector_consistency(
frames_per_batch=frames_per_batch,
total_frames=100,
device="cpu",
cat_dim=-1,
)
for i, d in enumerate(ccollector):
if i == 0:
@@ -1632,8 +1650,8 @@ def test_multi_collector_consistency(
assert_allclose_td(d1, d2)
ccollector.shutdown()

assert_allclose_td(c1, d1)
assert_allclose_td(c2, d2)
assert_allclose_td(c1, d1.select(*c1.keys(True, True)))
assert_allclose_td(c2, d2.select(*c1.keys(True, True)))


@pytest.mark.skipif(not torch.cuda.device_count(), reason="No casting if no cuda")
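For orientation, here is a minimal sketch of the collector construction that these tests now exercise: passing cat_dim=-1 makes the multi-sync collector concatenate worker buffers along the time dimension, so the delivered batch contains only valid frames (no -1 padding in traj_ids). The GymEnv environment, the frame counts, and the use of policy=None below are illustrative choices, not taken from this diff:

from torchrl.collectors import MultiSyncDataCollector
from torchrl.envs.libs.gym import GymEnv

def make_env():
    # any EnvBase constructor works here; CartPole is only an example
    return GymEnv("CartPole-v1")

collector = MultiSyncDataCollector(
    [make_env] * 2,          # two workers
    policy=None,             # None is expected to fall back to a random policy
    frames_per_batch=20,
    total_frames=200,
    cat_dim=-1,              # new argument exercised by the tests above
)
for batch in collector:
    # with cat_dim=-1, every frame in the batch should be valid
    assert (batch["collector", "traj_ids"] != -1).all()
collector.shutdown()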
193 changes: 182 additions & 11 deletions torchrl/collectors/collectors.py
@@ -349,6 +349,117 @@ def __repr__(self) -> str:
return string


def DataCollector(
create_env_fn: Sequence[Callable[[], EnvBase]],
policy: Optional[
Union[
TensorDictModule,
Callable[[TensorDictBase], TensorDictBase],
]
],
*,
num_workers: int = None,
sync: bool = True,
frames_per_batch: int = 200,
total_frames: Optional[int] = -1,
device: DEVICE_TYPING = None,
storing_device: Optional[Union[DEVICE_TYPING, Sequence[DEVICE_TYPING]]] = None,
create_env_kwargs: Optional[Sequence[dict]] = None,
max_frames_per_traj: int = -1,
init_random_frames: int = -1,
reset_at_each_iter: bool = False,
postproc: Optional[Callable[[TensorDictBase], TensorDictBase]] = None,
split_trajs: Optional[bool] = None,
exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE,
exploration_mode=None,
reset_when_done: bool = True,
preemptive_threshold: float = None,
update_at_each_batch: bool = False,
):
if not isinstance(create_env_fn, EnvBase) and not callable(create_env_fn):
if num_workers is not None and num_workers != len(create_env_fn):
raise TypeError(
"The number of workers provided does not match the number of environment constructors."
)
else:
num_workers = len(create_env_fn)
elif num_workers is not None and num_workers > 0:
create_env_fn = [create_env_fn] * num_workers
from torchrl.envs import EnvCreator

if num_workers and any(
not isinstance(func, (EnvCreator, EnvBase)) for func in create_env_fn
):
create_env_fn = [
func if isinstance(func, (EnvCreator, EnvBase)) else EnvCreator(func)
for func in create_env_fn
]
if num_workers:
if sync:
return MultiSyncDataCollector(
create_env_fn=create_env_fn,
policy=policy,
num_workers=num_workers,
sync=sync,
frames_per_batch=frames_per_batch,
total_frames=total_frames,
device=device,
storing_device=storing_device,
create_env_kwargs=create_env_kwargs,
max_frames_per_traj=max_frames_per_traj,
init_random_frames=init_random_frames,
reset_at_each_iter=reset_at_each_iter,
postproc=postproc,
split_trajs=split_trajs,
exploration_type=exploration_type,
exploration_mode=exploration_mode,
reset_when_done=reset_when_done,
preemptive_threshold=preemptive_threshold,
update_at_each_batch=update_at_each_batch,
cat_dim=-1,
)
else:
return MultiaSyncDataCollector(
create_env_fn=create_env_fn,
policy=policy,
num_workers=num_workers,
sync=sync,
frames_per_batch=frames_per_batch,
total_frames=total_frames,
device=device,
storing_device=storing_device,
create_env_kwargs=create_env_kwargs,
max_frames_per_traj=max_frames_per_traj,
init_random_frames=init_random_frames,
reset_at_each_iter=reset_at_each_iter,
postproc=postproc,
split_trajs=split_trajs,
exploration_type=exploration_type,
exploration_mode=exploration_mode,
reset_when_done=reset_when_done,
preemptive_threshold=preemptive_threshold,
update_at_each_batch=update_at_each_batch,
)
else:
return SyncDataCollector(
create_env_fn=create_env_fn,
policy=policy,
frames_per_batch=frames_per_batch,
total_frames=total_frames,
device=device,
storing_device=storing_device,
create_env_kwargs=create_env_kwargs,
max_frames_per_traj=max_frames_per_traj,
init_random_frames=init_random_frames,
reset_at_each_iter=reset_at_each_iter,
postproc=postproc,
split_trajs=split_trajs,
exploration_type=exploration_type,
exploration_mode=exploration_mode,
reset_when_done=reset_when_done,
)
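A possible call pattern for the DataCollector dispatcher defined above, sketched under the assumption that the arguments are forwarded as written in this diff; make_env and policy are placeholders rather than names from the codebase:

from torchrl.collectors.collectors import DataCollector

# a single constructor dispatches to SyncDataCollector
single = DataCollector(make_env, policy=policy, frames_per_batch=200)

# a list of constructors dispatches to MultiSyncDataCollector (with cat_dim=-1)
# when sync=True, or to MultiaSyncDataCollector when sync=False
multi = DataCollector(
    [make_env, make_env],
    policy=policy,
    sync=True,
    frames_per_batch=200,
    total_frames=10_000,
)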


@accept_remote_rref_udf_invocation
class SyncDataCollector(DataCollectorBase):
"""Generic data collector for RL problems. Requires an environment constructor and a policy.
@@ -480,6 +591,9 @@ class SyncDataCollector(DataCollectorBase):

"""

def __new__(cls, *args, **kwargs):
return DataCollectorBase.__new__(cls)

def __init__(
self,
create_env_fn: Union[
@@ -1098,6 +1212,9 @@ class _MultiDataCollector(DataCollectorBase):
that will be allowed to finish collecting their rollout before the rest are forced to end early.
"""

def __new__(cls, *args, **kwargs):
return DataCollectorBase.__new__(cls)

def __init__(
self,
create_env_fn: Sequence[Callable[[], EnvBase]],
@@ -1590,6 +1707,20 @@ class MultiSyncDataCollector(_MultiDataCollector):

__doc__ += _MultiDataCollector.__doc__

def __init__(self, *args, cat_dim=None, **kwargs):
if cat_dim is None:
warnings.warn(
"The cat_dim argument wasn't passed to the sync data collector. "
"The previous default value was 0 and will stay like "
"this until v0.3. From v0.3, the default will be -1 (the time dimension). In v0.4, "
"this argument will likely be deprecated. "
"To remove this warning, set cat_dim to -1.",
category=DeprecationWarning,
)
cat_dim = 0
self._cat_dim = cat_dim
super().__init__(*args, **kwargs)

# for RPC
def next(self):
return super().next()
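A small sketch of the fallback behaviour added in __init__ above: leaving cat_dim unset keeps the previous concatenation dim (0) and emits a DeprecationWarning, while passing cat_dim=-1 keeps the collector silent. make_env and policy are placeholders:

import warnings
from torchrl.collectors import MultiSyncDataCollector

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy = MultiSyncDataCollector(
        [make_env, make_env],
        policy=policy,
        frames_per_batch=20,
        total_frames=100,
        # cat_dim not passed: falls back to 0 and warns
    )
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

explicit = MultiSyncDataCollector(
    [make_env, make_env],
    policy=policy,
    frames_per_batch=20,
    total_frames=100,
    cat_dim=-1,  # explicit value, no warning
)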
@@ -1672,7 +1803,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
while self.queue_out.qsize() < int(self.num_workers):
continue

for _ in range(self.num_workers):
for idx in range(self.num_workers):
new_data, j = self.queue_out.get()
if j == 0:
data, idx = new_data
@@ -1683,14 +1814,16 @@

if workers_frames[idx] >= self.total_frames:
dones[idx] = True
# we have to correct the traj_ids to make sure that they don't overlap
for idx in range(self.num_workers):

traj_ids = self.buffers[idx].get(("collector", "traj_ids"))
if max_traj_idx is not None:
traj_ids[traj_ids != -1] += max_traj_idx
# out_tensordicts_shared[idx].set("traj_ids", traj_ids)
max_traj_idx = traj_ids.max().item() + 1
# out = out_tensordicts_shared[idx]
preempt_mask = traj_ids != -1
if preempt_mask.all():
if max_traj_idx is not None:
traj_ids += max_traj_idx
else:
if max_traj_idx is not None:
traj_ids[traj_ids != -1] += max_traj_idx
max_traj_idx = traj_ids.max().item() + 1
if same_device is None:
prev_device = None
same_device = True
@@ -1702,20 +1835,58 @@

if same_device:
self.out_buffer = torch.cat(
list(self.buffers.values()), 0, out=self.out_buffer
list(self.buffers.values()), self._cat_dim, out=self.out_buffer
)
else:
self.out_buffer = torch.cat(
[item.cpu() for item in self.buffers.values()],
0,
self._cat_dim,
out=self.out_buffer,
)

if self.split_trajs:
out = split_trajectories(self.out_buffer, prefix="collector")
frames += out.get(("collector", "mask")).sum().item()
else:
out = self.out_buffer.clone()
traj_ids = self.out_buffer.get(("collector", "traj_ids"))
cat_dim = self._cat_dim
if cat_dim < 0:
cat_dim = self.out_buffer.ndim + cat_dim
truncated = None
if cat_dim == self.out_buffer.ndim - 1:
idx = (slice(None),) * (self.out_buffer.ndim - 1)
truncated = (
traj_ids[idx + (slice(None, -1),)]
!= traj_ids[idx + (slice(1, None),)]
)
truncated = torch.cat(
[
truncated,
torch.ones_like(truncated[idx + (slice(-1, None),)]),
],
self._cat_dim,
)
valid_mask = traj_ids != -1
shape = tuple(
s if i != cat_dim else -1
for i, s in enumerate(self.out_buffer.shape)
)
if not valid_mask.all():
out = self.out_buffer[valid_mask]
if truncated is not None:
out.set(
("next", "truncated"), truncated[valid_mask].unsqueeze(-1)
)
out = out.reshape(shape)
out.names = self.out_buffer.names
else:
out = self.out_buffer.clone()
if truncated is not None:
out.set(
("next", "truncated"),
truncated[valid_mask].reshape(*shape, 1),
)

frames += prod(out.shape)
if self.postprocs:
self.postprocs = self.postprocs.to(out.device)
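To make the new ("next", "truncated") computation above concrete, here is the boundary logic in isolation for a one-dimensional time axis; the traj_ids values are made up for illustration:

import torch

# trajectory ids along the time dimension, after worker buffers were concatenated
traj_ids = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2])

# a step is truncated when the next step belongs to a different trajectory;
# the last step of the batch is always marked truncated
truncated = traj_ids[:-1] != traj_ids[1:]
truncated = torch.cat([truncated, torch.ones_like(truncated[-1:])])

assert truncated.tolist() == [False, False, True, False, True, False, False, True]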