Implement unittests for MemmapTensor. #231


Merged: 9 commits, Jun 29, 2022
Changes from 6 commits
68 changes: 62 additions & 6 deletions test/test_memmap.py
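
The new tests import a get_available_devices helper from torchrl's internal test utilities. For orientation, here is a minimal sketch of what such a helper typically looks like (an assumption for illustration; the actual torchrl implementation may differ):

    import torch

    def get_available_devices():
        # CPU is always available; append any visible CUDA devices.
        devices = [torch.device("cpu")]
        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                devices.append(torch.device(f"cuda:{i}"))
        return devices
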
@@ -10,6 +10,7 @@
import numpy as np
import pytest
import torch
+from _utils_internal import get_available_devices
from torchrl.data.tensordict.memmap import MemmapTensor


@@ -35,7 +36,18 @@ def test_grad():
        MemmapTensor(t + 1)


-@pytest.mark.parametrize("dtype", [torch.float, torch.int, torch.double, torch.bool])
+@pytest.mark.parametrize(
+    "dtype",
+    [
+        torch.half,
+        torch.float,
+        torch.double,
+        torch.int,
+        torch.uint8,
+        torch.long,
+        torch.bool,
+    ],
+)
@pytest.mark.parametrize(
    "shape",
    [
@@ -45,8 +57,9 @@ def test_grad():
        [1, 2],
    ],
)
-def test_memmap_metadata(dtype, shape):
-    t = torch.tensor([1, 0]).reshape(shape)
+def test_memmap_data_type(dtype, shape):
+    """Test that MemmapTensor can be created with a given data type and shape."""
+    t = torch.tensor([1, 0], dtype=dtype).reshape(shape)
    m = MemmapTensor(t)
    assert m.dtype == t.dtype
    assert (m == t).all()
@@ -137,9 +150,52 @@ def test_memmap_clone():
    assert m2c == m1


-def test_memmap_tensor():
-    t = torch.tensor([[1, 2, 3], [4, 5, 6]])
-    assert (torch.tensor(t) == t).all()
+@pytest.mark.parametrize("device", get_available_devices())
+def test_memmap_same_device_as_tensor(device):
+    """
+    Created MemmapTensor should be on the same device as the input tensor.
+    Check if device is correct when .to(device) is called.
+    """
+    t = torch.tensor([1], device=device)
+    m = MemmapTensor(t)
+    assert m.device == torch.device(device)
+    for other_device in get_available_devices():
+        if other_device != device:
+            with pytest.raises(
+                RuntimeError,
+                match="Expected all tensors to be on the same device, "
+                + "but found at least two devices",
+            ):
+                assert torch.all(m + torch.ones([3, 4], device=other_device) == 1)
+        m.to(other_device)
+        assert m.device == torch.device(other_device)


+@pytest.mark.parametrize("device", get_available_devices())
+def test_memmap_create_on_same_device(device):
+    """Test if the device arg for MemmapTensor init is respected."""
+    m = MemmapTensor([3, 4], device=device)
+    assert m.device == torch.device(device)


+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize(
+    "value", [torch.zeros([3, 4]), MemmapTensor(torch.zeros([3, 4]))]
+)
+def test_memmap_zero_value(device, value):
+    """
+    Test if all entries are zeros when MemmapTensor is created with size.
+    """
+    value.to(device)
+    expected_memmap_tensor = MemmapTensor(value)
+    m1 = MemmapTensor([3, 4])
Collaborator:
I guess this guy should be created on device too:

    m1 = MemmapTensor([3, 4], device=device)

+    assert m1.shape == (3, 4)
+    assert torch.all(m1 == expected_memmap_tensor)
+    assert torch.all(m1 + torch.ones([3, 4], device=device) == 1)
+    m2 = MemmapTensor(3, 4)
Collaborator:
Same here. By the way, we could perhaps nest those two tests using parameters:

    @pytest.mark.parametrize("shape", [[[3, 4]], [3, 4]])
    def test_foo(shape, device):
        m1 = MemmapTensor(*shape, device=device)  # first case is a list, second case is a sequence of integers
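
For illustration, a fuller sketch of the nested parametrization suggested above, applied to test_memmap_zero_value (hypothetical code, assuming MemmapTensor zero-initializes its storage and accepts either a shape list or unpacked integers; not necessarily what was merged):

    @pytest.mark.parametrize("device", get_available_devices())
    @pytest.mark.parametrize("shape", [[[3, 4]], [3, 4]])
    def test_memmap_zero_value_nested(device, shape):
        # First case unpacks to a single shape list, second to two ints.
        m = MemmapTensor(*shape, device=device)
        assert m.shape == (3, 4)
        # Mirrors the original assertions: a fresh MemmapTensor reads as zeros.
        assert torch.all(m == 0)
        assert torch.all(m + torch.ones([3, 4], device=device) == 1)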

Contributor (author):
Thanks @vmoens! I have made the corresponding changes. I was setting up the WSL environment locally on a Windows machine to test CUDA/CPU. Sorry for the delayed responses.

+    assert m2.shape == (3, 4)
+    assert torch.all(m2 == expected_memmap_tensor)
+    assert torch.all(m2 + torch.ones([3, 4], device=device) == 1)


if __name__ == "__main__":