[Deprecation] Remove InPlaceSampler #2750

Merged · 3 commits · Feb 4, 2025

29 changes: 4 additions & 25 deletions torchrl/data/replay_buffers/replay_buffers.py
@@ -38,7 +38,7 @@
 from torch import Tensor
 from torch.utils._pytree import tree_map
 
-from torchrl._utils import _make_ordinal_device, accept_remote_rref_udf_invocation
+from torchrl._utils import accept_remote_rref_udf_invocation
 from torchrl.data.replay_buffers.samplers import (
     PrioritizedSampler,
     RandomSampler,
@@ -1574,33 +1574,12 @@ def update_tensordict_priority(self, data: TensorDictBase) -> None:
 
 
 class InPlaceSampler:
-    """A sampler to write tensordicts in-place.
-
-    .. warning:: This class is deprecated and will be removed in v0.7.
-
-    To be used cautiously as this may lead to unexpected behavior (i.e. tensordicts
-    overwritten during execution).
-
-    """
+    """[Deprecated] A sampler to write tensordicts in-place."""
 
     def __init__(self, device: DEVICE_TYPING | None = None):
-        warnings.warn(
-            "InPlaceSampler has been deprecated and will be removed in v0.7.",
-            category=DeprecationWarning,
+        raise RuntimeError(
+            "This class has been removed without replacement. In-place sampling should be avoided."
         )
-        self.out = None
-        if device is None:
-            device = "cpu"
-        self.device = _make_ordinal_device(torch.device(device))
-
-    def __call__(self, list_of_tds):
-        if self.out is None:
-            self.out = torch.stack(list_of_tds, 0).contiguous()
-            if self.device is not None:
-                self.out = self.out.to(self.device)
-        else:
-            torch.stack(list_of_tds, 0, out=self.out)
-        return self.out
 
 
 def stack_tensors(list_of_tensor_iterators: List) -> Tuple[torch.Tensor]:
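
Since the class is removed without a replacement, users who depended on it can keep the old behavior in their own code. The sketch below mirrors the removed InPlaceSampler.__call__; the helper name stack_into is hypothetical (not a torchrl API), and it assumes torch and tensordict are installed:

import torch
from tensordict import TensorDict


def stack_into(list_of_tds, out=None, device=None):
    # Mirrors the removed InPlaceSampler.__call__: on the first call,
    # allocate a contiguous stacked tensordict (optionally moved to
    # `device`); on later calls, write into the existing buffer.
    if out is None:
        out = torch.stack(list_of_tds, 0).contiguous()
        if device is not None:
            out = out.to(device)
    else:
        # In-place write: anything still holding a reference to `out`
        # sees its data overwritten, the footgun that motivated removal.
        torch.stack(list_of_tds, 0, out=out)
    return out


# Usage: the second call reuses (and overwrites) the buffer from the first.
tds = [TensorDict({"obs": torch.randn(3)}, batch_size=[]) for _ in range(4)]
buf = stack_into(tds)
buf = stack_into(tds, out=buf)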