11module BlockSparseArraysTensorAlgebraExt
22
3- using BlockSparseArrays: AbstractBlockSparseArray, blockreshape
4- using TensorAlgebra: TensorAlgebra, BlockedTuple, FusionStyle, fuseaxes
5-
6- struct BlockReshapeFusion <: FusionStyle end
3+ using BlockArrays: Block, blocklength, blocks, eachblockaxes1
4+ using BlockSparseArrays: AbstractBlockSparseArray, AbstractBlockSparseMatrix,
5+ BlockUnitRange, blockrange, blocksparse
6+ using SparseArraysBase: eachstoredindex
7+ using TensorAlgebra: TensorAlgebra, BlockReshapeFusion, BlockedTuple, matricize,
8+ matricize_axes, tensor_product_axis, unmatricize
79
8- function TensorAlgebra. FusionStyle (:: Type{<:AbstractBlockSparseArray} )
9- return BlockReshapeFusion ()
10+ function TensorAlgebra. tensor_product_axis (
11+ :: BlockReshapeFusion , r1:: BlockUnitRange , r2:: BlockUnitRange
12+ )
13+ isone (first (r1)) || isone (first (r2)) ||
14+ throw (ArgumentError (" Only one-based axes are supported" ))
15+ blockaxpairs = Iterators. product (eachblockaxes1 (r1), eachblockaxes1 (r2))
16+ blockaxs = vec (splat (tensor_product_axis).(blockaxpairs))
17+ return blockrange (blockaxs)
1018end
1119
12- using BlockArrays: Block, blocklength, blocks
13- using BlockSparseArrays: blocksparse
14- using SparseArraysBase: eachstoredindex
15- using TensorAlgebra: TensorAlgebra, matricize, unmatricize
1620function TensorAlgebra. matricize (
17- :: BlockReshapeFusion , a:: AbstractArray , length1 :: Val , length2 :: Val
21+ style :: BlockReshapeFusion , a:: AbstractBlockSparseArray , length_codomain :: Val
1822 )
19- ax = fuseaxes ( axes (a), length1, length2 )
20- reshaped_blocks_a = reshape (blocks (a), map ( blocklength, ax))
23+ ax = matricize_axes (style, a, length_codomain )
24+ reshaped_blocks_a = reshape (blocks (a), blocklength .( ax))
2125 key (I) = Block (Tuple (I))
22- value (I) = matricize (reshaped_blocks_a[I], length1, length2 )
26+ value (I) = matricize (reshaped_blocks_a[I], length_codomain )
2327 Is = eachstoredindex (reshaped_blocks_a)
2428 bs = if isempty (Is)
2529 # Catch empty case and make sure the type is constrained properly.
@@ -35,16 +39,16 @@ function TensorAlgebra.matricize(
3539 return blocksparse (bs, ax)
3640end
3741
38- using BlockArrays: blocklengths
3942function TensorAlgebra. unmatricize (
4043 :: BlockReshapeFusion ,
41- m:: AbstractMatrix ,
44+ m:: AbstractBlockSparseMatrix ,
4245 codomain_axes:: Tuple{Vararg{AbstractUnitRange}} ,
4346 domain_axes:: Tuple{Vararg{AbstractUnitRange}} ,
4447 )
4548 ax = (codomain_axes... , domain_axes... )
46- reshaped_blocks_m = reshape (blocks (m), map (blocklength, ax))
47- function f (I)
49+ reshaped_blocks_m = reshape (blocks (m), blocklength .(ax))
50+ key (I) = Block (Tuple (I))
51+ function value (I)
4852 block_axes_I = BlockedTuple (
4953 map (ntuple (identity, length (ax))) do i
5054 return Base. axes1 (ax[i][Block (I[i])])
@@ -53,7 +57,7 @@ function TensorAlgebra.unmatricize(
5357 )
5458 return unmatricize (reshaped_blocks_m[I], block_axes_I)
5559 end
56- bs = Dict (Block ( Tuple (I)) => f (I) for I in eachstoredindex (reshaped_blocks_m))
60+ bs = Dict (key (I) => value (I) for I in eachstoredindex (reshaped_blocks_m))
5761 return blocksparse (bs, ax)
5862end
5963
0 commit comments