Implement unittests for MemmapTensor. #231
Changes from 6 commits
@@ -10,6 +10,7 @@
 import numpy as np
 import pytest
 import torch
+from _utils_internal import get_available_devices
 from torchrl.data.tensordict.memmap import MemmapTensor
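Note on the new import: get_available_devices comes from the test suite's internal _utils_internal module and is not shown in this diff. As a rough sketch of what such a helper typically returns (an assumption, not the repository's actual code), it lists the CPU plus any visible CUDA devices:

    import torch

    def get_available_devices():
        # CPU is always usable; append every visible CUDA device when present.
        devices = [torch.device("cpu")]
        if torch.cuda.is_available():
            devices += [torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())]
        return devices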
@@ -35,7 +36,18 @@ def test_grad():
     MemmapTensor(t + 1)


-@pytest.mark.parametrize("dtype", [torch.float, torch.int, torch.double, torch.bool])
+@pytest.mark.parametrize(
+    "dtype",
+    [
+        torch.half,
+        torch.float,
+        torch.double,
+        torch.int,
+        torch.uint8,
+        torch.long,
+        torch.bool,
+    ],
+)
 @pytest.mark.parametrize(
     "shape",
     [
@@ -45,8 +57,9 @@ def test_grad():
         [1, 2],
     ],
 )
-def test_memmap_metadata(dtype, shape):
-    t = torch.tensor([1, 0]).reshape(shape)
+def test_memmap_data_type(dtype, shape):
+    """Test that MemmapTensor can be created with a given data type and shape."""
+    t = torch.tensor([1, 0], dtype=dtype).reshape(shape)
     m = MemmapTensor(t)
     assert m.dtype == t.dtype
     assert (m == t).all()
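Stacking the two parametrize decorators makes pytest run test_memmap_data_type once per (dtype, shape) combination. A self-contained illustration of that pattern, with toy values that are not part of the diff:

    import pytest

    @pytest.mark.parametrize("dtype", ["float", "int"])
    @pytest.mark.parametrize("shape", [[2], [1, 2]])
    def test_cross_product(dtype, shape):
        # pytest generates 2 x 2 = 4 test cases, one per (dtype, shape) pair.
        assert len(shape) in (1, 2)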
@@ -137,9 +150,52 @@ def test_memmap_clone():
     assert m2c == m1


-def test_memmap_tensor():
-    t = torch.tensor([[1, 2, 3], [4, 5, 6]])
-    assert (torch.tensor(t) == t).all()
+@pytest.mark.parametrize("device", get_available_devices())
+def test_memmap_same_device_as_tensor(device):
+    """
+    Created MemmapTensor should be on the same device as the input tensor.
+    Check if device is correct when .to(device) is called.
+    """
+    t = torch.tensor([1], device=device)
+    m = MemmapTensor(t)
+    assert m.device == torch.device(device)
+    for other_device in get_available_devices():
+        if other_device != device:
+            with pytest.raises(
+                RuntimeError,
+                match="Expected all tensors to be on the same device, "
+                + "but found at least two devices",
+            ):
+                assert torch.all(m + torch.ones([3, 4], device=other_device) == 1)
+            m.to(other_device)
+            assert m.device == torch.device(other_device)
+
+
+@pytest.mark.parametrize("device", get_available_devices())
+def test_memmap_create_on_same_device(device):
+    """Test if the device arg for MemmapTensor init is respected."""
+    m = MemmapTensor([3, 4], device=device)
+    assert m.device == torch.device(device)
+
+
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize(
+    "value", [torch.zeros([3, 4]), MemmapTensor(torch.zeros([3, 4]))]
+)
+def test_memmap_zero_value(device, value):
+    """
+    Test if all entries are zeros when MemmapTensor is created with size.
+    """
+    value.to(device)
+    expected_memmap_tensor = MemmapTensor(value)
+    m1 = MemmapTensor([3, 4])
Review comment: I guess this guy should be created on device too:

    m1 = MemmapTensor([3, 4], device=device)
+    assert m1.shape == (3, 4)
+    assert torch.all(m1 == expected_memmap_tensor)
+    assert torch.all(m1 + torch.ones([3, 4], device=device) == 1)
+    m2 = MemmapTensor(3, 4)
Review comment: same here

    @pytest.mark.parametrize("shape", [[[3, 4]], [3, 4]])
    def test_foo(shape, device):
        # first case is a list, second case is a sequence of integers
        m1 = MemmapTensor(*shape, device=device)

Reply: Thanks @vmoens!
+    assert m2.shape == (3, 4)
+    assert torch.all(m2 == expected_memmap_tensor)
+    assert torch.all(m2 + torch.ones([3, 4], device=device) == 1)


 if __name__ == "__main__":
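Combining the two review suggestions above (create the memmap tensors on device and parametrize over the two construction styles), the zero-value test could be consolidated roughly as below. This is only a sketch of the reviewer's idea, not the code that was ultimately merged:

    import pytest
    import torch
    from _utils_internal import get_available_devices
    from torchrl.data.tensordict.memmap import MemmapTensor

    @pytest.mark.parametrize("device", get_available_devices())
    @pytest.mark.parametrize("shape", [[[3, 4]], [3, 4]])
    def test_memmap_zero_value(device, shape):
        # First case passes a shape list, second a sequence of ints; both land on `device`.
        m = MemmapTensor(*shape, device=device)
        assert m.shape == (3, 4)
        assert torch.all(m == 0)
        assert torch.all(m + torch.ones([3, 4], device=device) == 1)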