Skip to content

Implement unittests for MemmapTensor. #231

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jun 29, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 59 additions & 6 deletions test/test_memmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import numpy as np
import pytest
import torch
from _utils_internal import get_available_devices
from torchrl.data.tensordict.memmap import MemmapTensor


Expand All @@ -35,7 +36,18 @@ def test_grad():
MemmapTensor(t + 1)


@pytest.mark.parametrize("dtype", [torch.float, torch.int, torch.double, torch.bool])
@pytest.mark.parametrize(
"dtype",
[
torch.half,
torch.float,
torch.double,
torch.int,
torch.uint8,
torch.long,
torch.bool,
],
)
@pytest.mark.parametrize(
"shape",
[
Expand All @@ -45,8 +57,9 @@ def test_grad():
[1, 2],
],
)
def test_memmap_metadata(dtype, shape):
t = torch.tensor([1, 0]).reshape(shape)
def test_memmap_data_type(dtype, shape):
"""Test that MemmapTensor can be created with a given data type and shape."""
t = torch.tensor([1, 0], dtype=dtype).reshape(shape)
m = MemmapTensor(t)
assert m.dtype == t.dtype
assert (m == t).all()
Expand Down Expand Up @@ -137,9 +150,49 @@ def test_memmap_clone():
assert m2c == m1


def test_memmap_tensor():
    """Round-trip a plain tensor through ``torch.tensor`` and verify that
    every element of the copy equals the source."""
    source = torch.tensor([[1, 2, 3], [4, 5, 6]])
    duplicate = torch.tensor(source)
    assert bool((duplicate == source).all())
@pytest.mark.parametrize("device", get_available_devices())
def test_memmap_same_device_as_tensor(device):
"""
Created MemmapTensor should be on the same device as the input tensor.
Check if device is correct when .to(device) is called.
"""
t = torch.tensor([1], device=device)
m = MemmapTensor(t)
assert m.device == torch.device(device)
for other_device in get_available_devices():
if other_device != device:
with pytest.raises(
RuntimeError,
match="Expected all tensors to be on the same device, "
+ "but found at least two devices",
):
assert torch.all(m + torch.ones([3, 4], device=other_device) == 1)
m = m.to(other_device)
assert m.device == torch.device(other_device)


@pytest.mark.parametrize("device", get_available_devices())
def test_memmap_create_on_same_device(device):
"""Test if the device arg for MemmapTensor init is respected."""
m = MemmapTensor([3, 4], device=device)
assert m.device == torch.device(device)


@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize(
"value", [torch.zeros([3, 4]), MemmapTensor(torch.zeros([3, 4]))]
)
@pytest.mark.parametrize("shape", [[3, 4], [[3, 4]]])
def test_memmap_zero_value(device, value, shape):
"""
Test if all entries are zeros when MemmapTensor is created with size.
"""
value = value.to(device)
expected_memmap_tensor = MemmapTensor(value)
m = MemmapTensor(*shape, device=device)
assert m.shape == (3, 4)
assert torch.all(m == expected_memmap_tensor)
assert torch.all(m + torch.ones([3, 4], device=device) == 1)


if __name__ == "__main__":
Expand Down