Skip to content

Commit

Permalink
fix random mask creation in test_maskedtensor (pytorch#97017)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier authored and pytorchmergebot committed Mar 24, 2023
1 parent 303eb37 commit 7602aad
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions test/test_maskedtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,9 @@ def _compare_mts(mt1, mt2, rtol=1e-05, atol=1e-08):
if not _tensors_match(a, b, exact=False, rtol=rtol, atol=atol):
raise ValueError("The data in MaskedTensor mt1 and MaskedTensor mt2 do not match")

def _make_tensor_mask(shape, device):
return make_tensor(
shape, device=device, dtype=torch.bool, low=0, high=1, requires_grad=False
)

def _create_random_mask(shape, device):
return torch.randint(0, 2, shape, device=device).bool()
return make_tensor(shape, device=device, dtype=torch.bool)

def _generate_sample_data(
device="cpu", dtype=torch.float, requires_grad=True, layout=torch.strided
Expand All @@ -89,7 +85,7 @@ def _generate_sample_data(
inputs = []
for s in shapes:
data = make_tensor(s, device=device, dtype=dtype, requires_grad=requires_grad) # type: ignore[arg-type]
mask = _make_tensor_mask(s, device)
mask = _create_random_mask(s, device)
if layout == torch.sparse_coo:
mask = mask.to_sparse_coo().coalesce()
data = data.sparse_mask(mask).requires_grad_(requires_grad)
Expand Down Expand Up @@ -803,7 +799,7 @@ def _test_unary_binary_equality(self, device, dtype, op, layout=torch.strided):
input = sample.input
sample_args, sample_kwargs = sample.args, sample.kwargs
mask = (
_make_tensor_mask(input.shape, device)
_create_random_mask(input.shape, device)
if "mask" not in sample_kwargs
else sample_kwargs.pop("mask")
)
Expand Down Expand Up @@ -849,7 +845,7 @@ def _test_reduction_equality(self, device, dtype, op, layout=torch.strided):
if input.dim() == 0 or input.numel() == 0:
continue

mask = _make_tensor_mask(input.shape, device)
mask = _create_random_mask(input.shape, device)

if torch.count_nonzero(mask) == 0:
continue
Expand Down

0 comments on commit 7602aad

Please sign in to comment.