
[BUG] Problems with BatchedEnv on accelerated device with single envs on cpu #1864


Description

@skandermoalla

Describe the bug

When the batched env device is cuda, the step count reported by the batched env is completely off from what it should be.
When the batched env device is mps, there is a segmentation fault.

I wonder whether only the step count is corrupted, or whether other data, including the observation, is affected as well.

To Reproduce

from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.envs import (
    EnvCreator,
    ExplorationType,
    StepCounter,
    TransformedEnv, SerialEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import OneHotCategorical, ProbabilisticActor

max_step = 10
n_env = 4
env_id = "CartPole-v1"
device = "mps"  # the batched env device; the step-count issue also reproduces with "cuda"


def build_cpu_single_env():
    # Single envs live on cpu; StepCounter truncates episodes at max_step,
    # so the reported step count should never exceed max_step.
    env = GymEnv(env_id, device="cpu")
    env = TransformedEnv(env)
    env.append_transform(StepCounter(max_steps=max_step, step_count_key="single_env_step_count", truncated_key="single_env_truncated"))
    return env

def build_actor(env):
    return ProbabilisticActor(
        module=TensorDictModule(
            nn.LazyLinear(env.action_spec.space.n),
            in_keys=["observation"],
            out_keys=["logits"],
        ),
        spec=env.action_spec,
        distribution_class=OneHotCategorical,
        in_keys=["logits"],
        default_interaction_type=ExplorationType.RANDOM,
    )

if __name__ == "__main__":
    # Batched env on the accelerated device, wrapping single envs that live on cpu.
    env = SerialEnv(n_env, EnvCreator(lambda: build_cpu_single_env()), device=device)
    policy_module = build_actor(env)
    policy_module.to(device)
    policy_module(env.reset())  # initialise the LazyLinear input dimension

    for i in range(10):
        batches = env.rollout(max_step + 3, policy=policy_module, break_when_any_done=False)
        # With truncation at max_step, the step count should never exceed max_step.
        max_step_count = batches["next", "single_env_step_count"].max().item()
        if max_step_count > max_step:
            print("Problem!")
            print(max_step_count)
            break
    else:
        print("No problem!")

On CUDA

Problem!
1065353217
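
For what it's worth, 1065353217 is 0x3f800001, i.e. the int32 bit pattern of a float just above 1.0, which hints that the step-count buffer is being read as (or overwritten by) floating-point data rather than simply miscounted. A quick sanity check in plain Python, unrelated to torchrl itself:

import struct

# Reinterpret the corrupted step count as the bytes of a 32-bit float.
corrupted = 1065353217
as_float = struct.unpack("<f", struct.pack("<i", corrupted))[0]
print(hex(corrupted), as_float)  # 0x3f800001 1.0000001192092896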

On MPS

python(57380,0x1dd5e5000) malloc: Incorrect checksum for freed object 0x11767f308: probably modified after being freed.
Corrupt value: 0xbd414ea83cfeb221
python(57380,0x1dd5e5000) malloc: *** set a breakpoint in malloc_error_break to debug
[1]    57380 abort      python tests/issue_env_device.py
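
A possible workaround while this is investigated (my own untested sketch, not something from the torchrl docs): keep the batched env and the policy on cpu during collection, matching the single envs, and only move the collected batch to the accelerated device afterwards.

# Untested sketch: collect entirely on cpu, then move the batch to `device`.
env = SerialEnv(n_env, EnvCreator(build_cpu_single_env))  # no device= override, stays on cpu
policy_module = build_actor(env)      # leave the policy on cpu for collection
policy_module(env.reset())            # initialise the LazyLinear input dimension
batches = env.rollout(max_step + 3, policy=policy_module, break_when_any_done=False)
batches = batches.to(device)          # TensorDict.to() moves every entry to the accelerator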

System info

import torchrl, tensordict, torch, numpy, sys
print(torch.__version__, tensordict.__version__, torchrl.__version__, numpy.__version__, sys.version, sys.platform)

Linux (CUDA):

2.2.0a0+81ea7a4 0.4.0+eaef29e 0.4.0+01a2216 1.24.4 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] linux

macOS (MPS):

2.2.0 0.4.0+eaef29e 0.4.0+01a2216 1.26.3 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:35:25) [Clang 16.0.6 ] darwin
