Skip to content

Commit b22523b

Browse files
authored
Upgrade to TensorAlgebra v0.6 (#190)
1 parent 8c89827 commit b22523b

File tree

4 files changed

+26
-39
lines changed

4 files changed

+26
-39
lines changed

Project.toml

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
3-
version = "0.10.12"
43
authors = ["ITensor developers <support@itensor.org> and contributors"]
4+
version = "0.10.13"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -22,11 +22,9 @@ TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
2222

2323
[weakdeps]
2424
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
25-
TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d"
2625

2726
[extensions]
2827
BlockSparseArraysTensorAlgebraExt = "TensorAlgebra"
29-
BlockSparseArraysTensorProductsExt = "TensorProducts"
3028

3129
[compat]
3230
Adapt = "4.1.1"
@@ -44,8 +42,7 @@ MapBroadcast = "0.1.5"
4442
MatrixAlgebraKit = "0.6"
4543
SparseArraysBase = "0.7.1"
4644
SplitApplyCombine = "1.2.3"
47-
TensorAlgebra = "0.5"
48-
TensorProducts = "0.1.7"
45+
TensorAlgebra = "0.6"
4946
Test = "1.10"
5047
TypeParameterAccessors = "0.4.1"
5148
julia = "1.10"

ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,29 @@
11
module BlockSparseArraysTensorAlgebraExt
22

3-
using BlockSparseArrays: AbstractBlockSparseArray, blockreshape
4-
using TensorAlgebra: TensorAlgebra, BlockedTuple, FusionStyle, fuseaxes
5-
6-
struct BlockReshapeFusion <: FusionStyle end
3+
using BlockArrays: Block, blocklength, blocks, eachblockaxes1
4+
using BlockSparseArrays: AbstractBlockSparseArray, AbstractBlockSparseMatrix,
5+
BlockUnitRange, blockrange, blocksparse
6+
using SparseArraysBase: eachstoredindex
7+
using TensorAlgebra: TensorAlgebra, BlockReshapeFusion, BlockedTuple, matricize,
8+
matricize_axes, tensor_product_axis, unmatricize
79

8-
function TensorAlgebra.FusionStyle(::Type{<:AbstractBlockSparseArray})
9-
return BlockReshapeFusion()
10+
function TensorAlgebra.tensor_product_axis(
11+
::BlockReshapeFusion, r1::BlockUnitRange, r2::BlockUnitRange
12+
)
13+
isone(first(r1)) || isone(first(r2)) ||
14+
throw(ArgumentError("Only one-based axes are supported"))
15+
blockaxpairs = Iterators.product(eachblockaxes1(r1), eachblockaxes1(r2))
16+
blockaxs = vec(splat(tensor_product_axis).(blockaxpairs))
17+
return blockrange(blockaxs)
1018
end
1119

12-
using BlockArrays: Block, blocklength, blocks
13-
using BlockSparseArrays: blocksparse
14-
using SparseArraysBase: eachstoredindex
15-
using TensorAlgebra: TensorAlgebra, matricize, unmatricize
1620
function TensorAlgebra.matricize(
17-
::BlockReshapeFusion, a::AbstractArray, length1::Val, length2::Val
21+
style::BlockReshapeFusion, a::AbstractBlockSparseArray, length_codomain::Val
1822
)
19-
ax = fuseaxes(axes(a), length1, length2)
20-
reshaped_blocks_a = reshape(blocks(a), map(blocklength, ax))
23+
ax = matricize_axes(style, a, length_codomain)
24+
reshaped_blocks_a = reshape(blocks(a), blocklength.(ax))
2125
key(I) = Block(Tuple(I))
22-
value(I) = matricize(reshaped_blocks_a[I], length1, length2)
26+
value(I) = matricize(reshaped_blocks_a[I], length_codomain)
2327
Is = eachstoredindex(reshaped_blocks_a)
2428
bs = if isempty(Is)
2529
# Catch empty case and make sure the type is constrained properly.
@@ -35,16 +39,16 @@ function TensorAlgebra.matricize(
3539
return blocksparse(bs, ax)
3640
end
3741

38-
using BlockArrays: blocklengths
3942
function TensorAlgebra.unmatricize(
4043
::BlockReshapeFusion,
41-
m::AbstractMatrix,
44+
m::AbstractBlockSparseMatrix,
4245
codomain_axes::Tuple{Vararg{AbstractUnitRange}},
4346
domain_axes::Tuple{Vararg{AbstractUnitRange}},
4447
)
4548
ax = (codomain_axes..., domain_axes...)
46-
reshaped_blocks_m = reshape(blocks(m), map(blocklength, ax))
47-
function f(I)
49+
reshaped_blocks_m = reshape(blocks(m), blocklength.(ax))
50+
key(I) = Block(Tuple(I))
51+
function value(I)
4852
block_axes_I = BlockedTuple(
4953
map(ntuple(identity, length(ax))) do i
5054
return Base.axes1(ax[i][Block(I[i])])
@@ -53,7 +57,7 @@ function TensorAlgebra.unmatricize(
5357
)
5458
return unmatricize(reshaped_blocks_m[I], block_axes_I)
5559
end
56-
bs = Dict(Block(Tuple(I)) => f(I) for I in eachstoredindex(reshaped_blocks_m))
60+
bs = Dict(key(I) => value(I) for I in eachstoredindex(reshaped_blocks_m))
5761
return blocksparse(bs, ax)
5862
end
5963

ext/BlockSparseArraysTensorProductsExt/BlockSparseArraysTensorProductsExt.jl

Lines changed: 0 additions & 14 deletions
This file was deleted.

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ SafeTestsets = "0.1"
4040
SparseArraysBase = "0.7"
4141
StableRNGs = "1"
4242
Suppressor = "0.2"
43-
TensorAlgebra = "0.5"
43+
TensorAlgebra = "0.6"
4444
Test = "1"
4545
TestExtras = "0.3"
4646
TypeParameterAccessors = "0.4"

0 commit comments

Comments
 (0)