Open
Description
Describe the bug
When the batched env device is `cuda`, the step count reported by the batched env is completely off from what it should be.
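When the batched env device is `mps`, the process crashes with memory corruption (malloc reports an incorrect checksum for a freed object and aborts; see the output below).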
I wonder whether only the step count is corrupted, or whether other data (including the observations) is affected as well.
To Reproduce
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.envs import (
EnvCreator,
ExplorationType,
StepCounter,
TransformedEnv, SerialEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import OneHotCategorical, ProbabilisticActor
# Repro configuration: each sub-env's StepCounter uses `max_steps=max_step`,
# and `n_env` copies are batched in a SerialEnv placed on `device`.
max_step, n_env = 10, 4
env_id = "CartPole-v1"
device = "mps"  # the issue reports wrong step counts with "cuda" as well
def build_cpu_single_env():
    """Create a single CPU gym env (``env_id``) with a StepCounter attached.

    The counter is configured with ``max_steps=max_step`` and writes its
    count / truncation flag under the custom keys
    ``"single_env_step_count"`` and ``"single_env_truncated"``.
    """
    base_env = GymEnv(env_id, device="cpu")
    wrapped = TransformedEnv(base_env)
    counter = StepCounter(
        max_steps=max_step,
        step_count_key="single_env_step_count",
        truncated_key="single_env_truncated",
    )
    wrapped.append_transform(counter)
    return wrapped
def build_actor(env):
    """Return a OneHotCategorical actor driven by a lazy linear logits head.

    A ``LazyLinear`` with ``env.action_spec.space.n`` outputs reads
    ``"observation"`` and writes ``"logits"``; the actor samples from a
    ``OneHotCategorical`` over those logits, using
    ``ExplorationType.RANDOM`` as its default interaction type.
    """
    n_actions = env.action_spec.space.n
    logits_net = TensorDictModule(
        nn.LazyLinear(n_actions),
        in_keys=["observation"],
        out_keys=["logits"],
    )
    return ProbabilisticActor(
        module=logits_net,
        spec=env.action_spec,
        distribution_class=OneHotCategorical,
        in_keys=["logits"],
        default_interaction_type=ExplorationType.RANDOM,
    )
if __name__ == "__main__":
    # Batch `n_env` CPU sub-envs inside a SerialEnv living on `device`.
    # EnvCreator accepts the factory directly; no lambda wrapper is needed.
    env = SerialEnv(n_env, EnvCreator(build_cpu_single_env), device=device)
    try:
        policy_module = build_actor(env)
        policy_module.to(device)
        # One forward pass on a reset tensordict materialises the LazyLinear
        # weights before any rollout is collected.
        policy_module(env.reset())
        for _ in range(10):
            # Roll out for more steps than the StepCounter limit, without
            # stopping at the first done, so the limit is actually exercised.
            batches = env.rollout(
                max_step + 3, policy=policy_module, break_when_any_done=False
            )
            max_step_count = batches["next", "single_env_step_count"].max().item()
            # The per-env counter is configured with max_steps=max_step, so a
            # larger reported value means the batched step counts are wrong.
            if max_step_count > max_step:
                print("Problem!")
                print(max_step_count)
                break
        else:
            # Only reached when no rollout exposed the problem (no `break`).
            print("No problem!")
    finally:
        # Release the batched env and its sub-envs even if a rollout fails.
        env.close()
On CUDA
Problem!
1065353217
On MPS
python(57380,0x1dd5e5000) malloc: Incorrect checksum for freed object 0x11767f308: probably modified after being freed.
Corrupt value: 0xbd414ea83cfeb221
python(57380,0x1dd5e5000) malloc: *** set a breakpoint in malloc_error_break to debug
[1] 57380 abort python tests/issue_env_device.py
System info
import sys

import numpy
import tensordict
import torch
import torchrl

# Report library / interpreter versions and platform, space-separated.
versions = (
    torch.__version__,
    tensordict.__version__,
    torchrl.__version__,
    numpy.__version__,
    sys.version,
    sys.platform,
)
print(*versions)
2.2.0a0+81ea7a4 0.4.0+eaef29e 0.4.0+01a2216 1.24.4 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] linux
2.2.0 0.4.0+eaef29e 0.4.0+01a2216 1.26.3 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:35:25) [Clang 16.0.6 ] darwin