Skip to content

Commit df1069b

Browse files
committed
add tests
1 parent ea602f4 commit df1069b

File tree

2 files changed

+59
-7
lines changed

2 files changed

+59
-7
lines changed

pytorch_sparse_utils/batching/batch_utils.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,14 @@ def padded_to_concatenated(
521521
)
522522
return out, batch_offsets
523523

524+
if padding_mask is not None:
525+
if padding_mask.ndim != 2:
526+
raise ValueError(f"Expected padding_mask to be 2D, got {padding_mask.ndim}")
527+
if padding_mask.shape[0] != batch_size:
528+
raise ValueError("Batch size mismatch between tensor and padding_mask")
529+
if padding_mask.shape[1] != max_len:
530+
raise ValueError("Sequence length mismatch between tensor and padding_mask")
531+
524532
# Early return for no padding: All sequences are same length so can just reshape it
525533
if padding_mask is None or not padding_mask.any():
526534
total_len = batch_size * max_len
@@ -530,13 +538,6 @@ def padded_to_concatenated(
530538

531539
return out, batch_offsets
532540

533-
if padding_mask.ndim != 2:
534-
raise ValueError(f"Expected padding_mask to be 2D, got {padding_mask.ndim}")
535-
if padding_mask.shape[0] != batch_size:
536-
raise ValueError("Batch size mismatch between tensor and padding_mask")
537-
if padding_mask.shape[1] != max_len:
538-
raise ValueError("Sequence length mismatch between tensor and padding_mask")
539-
540541
nonpad_mask = padding_mask.logical_not()
541542
seq_lens = nonpad_mask.sum(-1).to(torch.long)
542543
batch_offsets = seq_lengths_to_batch_offsets(seq_lens)

tests/batching/test_utils.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
split_batch_concatenated_tensor,
2121
)
2222
from pytorch_sparse_utils.validation import validate_atleast_nd
23+
from .. import random_sparse_tensor
2324

2425

2526
@pytest.fixture
@@ -504,6 +505,14 @@ def test_scalar_feature(self, device):
504505
assert torch.equal(result, expected_result)
505506
assert torch.equal(batch_offsets, expected_batch_offsets)
506507

508+
def test_empty_tensor(self, device):
509+
tensor = torch.randn(0, 10, 32, device=device)
510+
result, batch_offsets = padded_to_concatenated(tensor)
511+
512+
assert result.numel() == 0
513+
assert result.shape == (0, 32)
514+
assert torch.equal(batch_offsets, torch.tensor([0], device=device))
515+
507516
def test_error_handling(self, device):
508517
"""Test error handling."""
509518
# Test with tensor with less than 3 dimensions
@@ -523,6 +532,26 @@ def test_error_handling(self, device):
523532
):
524533
padded_to_concatenated(tensor, padding_mask_wrong_batch)
525534

535+
# Wrong padding mask dim
536+
padding_mask_3d = torch.zeros(3, 4, 5, device=device, dtype=torch.bool)
537+
padding_mask_3d[0, -1] = True
538+
with pytest.raises(
539+
(ValueError, torch.jit.Error), # pyright: ignore[reportArgumentType]
540+
match="Expected padding_mask to be 2D",
541+
):
542+
padded_to_concatenated(tensor, padding_mask_3d)
543+
544+
# Sequence length mismatch
545+
padding_mask_wrong_seq_length = torch.zeros(
546+
3, 5, device=device, dtype=torch.bool
547+
)
548+
padding_mask_wrong_seq_length[0, -1] = True
549+
with pytest.raises(
550+
(ValueError, torch.jit.Error), # pyright: ignore[reportArgumentType]
551+
match="Sequence length mismatch",
552+
):
553+
padded_to_concatenated(tensor, padding_mask_wrong_seq_length)
554+
526555

527556
@pytest.mark.cpu_and_cuda
528557
class TestBatchDimToLeadingIndex:
@@ -739,3 +768,25 @@ def test_error_not_sparse(self, device):
739768
match="Received non-sparse tensor",
740769
):
741770
sparse_tensor_to_concatenated(tensor)
771+
772+
773+
class TestConcatenatedToSparseTensor:
774+
def test_basic_functionality(self, device):
775+
"""Test basic functions"""
776+
sparse_tensor = random_sparse_tensor([4, 5, 5], [8], 0.5, seed=0, device=device)
777+
778+
values, indices, batch_offsets = sparse_tensor_to_concatenated(sparse_tensor)
779+
780+
out = concatenated_to_sparse_tensor(values, indices, sparse_tensor.shape)
781+
782+
assert isinstance(out, Tensor)
783+
assert out.is_sparse
784+
785+
assert torch.equal(sparse_tensor.indices(), out.indices())
786+
assert torch.equal(sparse_tensor.values(), out.values())
787+
assert sparse_tensor.shape == out.shape
788+
789+
# Test without shape param
790+
out_2 = concatenated_to_sparse_tensor(values, indices)
791+
assert torch.equal(sparse_tensor.indices(), out_2.indices())
792+
assert torch.equal(sparse_tensor.values(), out_2.values())

0 commit comments

Comments
 (0)