    split_batch_concatenated_tensor,
)
from pytorch_sparse_utils.validation import validate_atleast_nd
+from .. import random_sparse_tensor


@pytest.fixture
@@ -504,6 +505,14 @@ def test_scalar_feature(self, device):
        assert torch.equal(result, expected_result)
        assert torch.equal(batch_offsets, expected_batch_offsets)

+    def test_empty_tensor(self, device):
+        tensor = torch.randn(0, 10, 32, device=device)
+        result, batch_offsets = padded_to_concatenated(tensor)
+
+        assert result.numel() == 0
+        assert result.shape == (0, 32)
+        assert torch.equal(batch_offsets, torch.tensor([0], device=device))
+
    def test_error_handling(self, device):
        """Test error handling."""
        # Test with tensor with less than 3 dimensions
@@ -523,6 +532,26 @@ def test_error_handling(self, device):
        ):
            padded_to_concatenated(tensor, padding_mask_wrong_batch)

+        # Wrong padding mask dim
+        padding_mask_3d = torch.zeros(3, 4, 5, device=device, dtype=torch.bool)
+        padding_mask_3d[0, -1] = True
+        with pytest.raises(
+            (ValueError, torch.jit.Error),  # pyright: ignore[reportArgumentType]
+            match="Expected padding_mask to be 2D",
+        ):
+            padded_to_concatenated(tensor, padding_mask_3d)
+
+        # Sequence length mismatch
+        padding_mask_wrong_seq_length = torch.zeros(
+            3, 5, device=device, dtype=torch.bool
+        )
+        padding_mask_wrong_seq_length[0, -1] = True
+        with pytest.raises(
+            (ValueError, torch.jit.Error),  # pyright: ignore[reportArgumentType]
+            match="Sequence length mismatch",
+        ):
+            padded_to_concatenated(tensor, padding_mask_wrong_seq_length)
+

@pytest.mark.cpu_and_cuda
class TestBatchDimToLeadingIndex:
@@ -739,3 +768,25 @@ def test_error_not_sparse(self, device):
            match="Received non-sparse tensor",
        ):
            sparse_tensor_to_concatenated(tensor)
+
+
+class TestConcatenatedToSparseTensor:
+    def test_basic_functionality(self, device):
+        """Test basic functionality."""
+        sparse_tensor = random_sparse_tensor(
+            [4, 5, 5], [8], 0.5, seed=0, device=device
+        )
+
+        values, indices, batch_offsets = sparse_tensor_to_concatenated(sparse_tensor)
+
+        out = concatenated_to_sparse_tensor(values, indices, sparse_tensor.shape)
+
+        assert isinstance(out, Tensor)
+        assert out.is_sparse
+
+        assert torch.equal(sparse_tensor.indices(), out.indices())
+        assert torch.equal(sparse_tensor.values(), out.values())
+        assert sparse_tensor.shape == out.shape
+
+        # Test without shape param
+        out_2 = concatenated_to_sparse_tensor(values, indices)
+        assert torch.equal(sparse_tensor.indices(), out_2.indices())
+        assert torch.equal(sparse_tensor.values(), out_2.values())
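Note: the new TestConcatenatedToSparseTensor test exercises a sparse -> concatenated -> sparse round trip. Below is a minimal, self-contained sketch of that round trip using only core torch APIs; it is not the repository's implementation, and sparse_tensor_to_concatenated / concatenated_to_sparse_tensor may handle batch offsets and hybrid dense dims differently.

# Minimal sketch (not the repo's implementation) of rebuilding a COO sparse
# tensor from concatenated values/indices with plain torch.
import torch

# Small hybrid sparse tensor: 3 sparse dims (batch, 5, 5) and a dense feature dim of 8.
dense = torch.zeros(4, 5, 5, 8)
dense[0, 1, 2] = torch.randn(8)
dense[2, 3, 0] = torch.randn(8)
sparse = dense.to_sparse(3)  # keep the last dim dense

# "Concatenated" form: stacked nonzero values plus their integer coordinates.
indices = sparse.indices()  # (3, nnz)
values = sparse.values()    # (nnz, 8)

# Rebuild from the concatenated pieces and check the round trip.
rebuilt = torch.sparse_coo_tensor(indices, values, size=sparse.shape).coalesce()
assert torch.equal(rebuilt.indices(), sparse.indices())
assert torch.equal(rebuilt.values(), sparse.values())
assert rebuilt.shape == sparse.shape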