Skip to content

Commit 14fa276

Browse files
authored
Rework the matricize interface (#104)
1 parent 151f271 commit 14fa276

File tree

3 files changed

+98
-50
lines changed

3 files changed

+98
-50
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TensorAlgebra"
22
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
33
authors = ["ITensor developers <support@itensor.org> and contributors"]
4-
version = "0.6.2"
4+
version = "0.6.3"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

src/blockarrays.jl

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,21 @@ using BlockArrays: AbstractBlockArray, AbstractBlockedUnitRange, BlockedArray, b
44
struct BlockReshapeFusion <: FusionStyle end
55
FusionStyle(::Type{<:AbstractBlockArray}) = BlockReshapeFusion()
66

7-
trivial_axis(::BlockReshapeFusion, a::AbstractArray) = blockedrange([1])
7+
function trivial_axis(
8+
style::BlockReshapeFusion, side::Val{:codomain}, a::AbstractArray,
9+
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
10+
axes_domain::Tuple{Vararg{AbstractUnitRange}},
11+
)
12+
return blockedrange([1])
13+
end
814
function mortar_axis(axs)
915
all(isone first, axs) ||
1016
throw(ArgumentError("Only one-based axes are supported"))
1117
return blockedrange(length.(axs))
1218
end
1319
function tensor_product_axis(
14-
::BlockReshapeFusion, r1::AbstractUnitRange, r2::AbstractUnitRange
20+
style::BlockReshapeFusion, side::Val{:codomain},
21+
r1::AbstractUnitRange, r2::AbstractUnitRange,
1522
)
1623
(isone(first(r1)) && isone(first(r2))) ||
1724
throw(ArgumentError("Only one-based axes are supported"))
@@ -29,35 +36,33 @@ function matricize(style::BlockReshapeFusion, a::AbstractArray, ndims_codomain::
2936
end
3037
using BlockArrays: blocklengths
3138
function unmatricize(
32-
::BlockReshapeFusion,
33-
m::AbstractMatrix,
34-
codomain_axes::Tuple{Vararg{AbstractUnitRange}},
35-
domain_axes::Tuple{Vararg{AbstractUnitRange}},
39+
::BlockReshapeFusion, m::AbstractMatrix,
40+
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
41+
axes_domain::Tuple{Vararg{AbstractUnitRange}},
3642
)
37-
ax = (codomain_axes..., domain_axes...)
43+
ax = (axes_codomain..., axes_domain...)
3844
reshaped_blocks_m = reshape(blocks(m), blocklength.(ax))
3945
bs = map(CartesianIndices(reshaped_blocks_m)) do I
4046
block_axes_I = BlockedTuple(
4147
map(ntuple(identity, length(ax))) do i
4248
return Base.axes1(ax[i][Block(I[i])])
4349
end,
44-
(length(codomain_axes), length(domain_axes)),
50+
(length(axes_codomain), length(axes_domain)),
4551
)
4652
return unmatricize(reshaped_blocks_m[I], block_axes_I)
4753
end
4854
return mortar(bs, ax)
4955
end
5056

51-
struct BlockedReshapeFusion <: FusionStyle end
52-
FusionStyle(::Type{<:BlockedArray}) = BlockedReshapeFusion()
57+
FusionStyle(::Type{<:BlockedArray}) = ReshapeFusion()
5358
unblock(a::BlockedArray) = a.blocks
5459
unblock(a::AbstractBlockArray) = a[Base.OneTo.(size(a))...]
5560
unblock(a::AbstractArray) = a
56-
function matricize(::BlockedReshapeFusion, a::AbstractArray, ndims_codomain::Val)
61+
function matricize(::ReshapeFusion, a::BlockedArray, ndims_codomain::Val)
5762
return matricize(ReshapeFusion(), unblock(a), ndims_codomain)
5863
end
59-
function unmatricize(
60-
style::BlockedReshapeFusion, m::AbstractMatrix,
64+
function unmatricize_blocked(
65+
style::ReshapeFusion, m::AbstractMatrix,
6166
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
6267
axes_domain::Tuple{Vararg{AbstractUnitRange}},
6368
)
@@ -67,3 +72,24 @@ function unmatricize(
6772
)
6873
return BlockedArray(a, (axes_codomain..., axes_domain...))
6974
end
75+
function unmatricize(
76+
style::ReshapeFusion, m::AbstractMatrix,
77+
axes_codomain::Tuple{AbstractBlockedUnitRange, Vararg{AbstractBlockedUnitRange}},
78+
axes_domain::Tuple{AbstractBlockedUnitRange, Vararg{AbstractBlockedUnitRange}},
79+
)
80+
return unmatricize_blocked(style, m, axes_codomain, axes_domain)
81+
end
82+
function unmatricize(
83+
style::ReshapeFusion, m::AbstractMatrix,
84+
axes_codomain::Tuple{AbstractBlockedUnitRange, Vararg{AbstractBlockedUnitRange}},
85+
axes_domain::Tuple{Vararg{AbstractBlockedUnitRange}},
86+
)
87+
return unmatricize_blocked(style, m, axes_codomain, axes_domain)
88+
end
89+
function unmatricize(
90+
style::ReshapeFusion, m::AbstractMatrix,
91+
axes_codomain::Tuple{Vararg{AbstractBlockedUnitRange}},
92+
axes_domain::Tuple{AbstractBlockedUnitRange, Vararg{AbstractBlockedUnitRange}},
93+
)
94+
return unmatricize_blocked(style, m, axes_codomain, axes_domain)
95+
end

src/matricize.jl

Lines changed: 58 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -9,48 +9,59 @@ FusionStyle(T::Type) = throw(MethodError(FusionStyle, (T,)))
99

1010
# ======================================= misc ========================================
1111
function trivial_axis(
12-
style::FusionStyle,
13-
::Val{:codomain},
14-
a::AbstractArray,
12+
style::FusionStyle, side::Val{:codomain}, a::AbstractArray,
1513
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
1614
axes_domain::Tuple{Vararg{AbstractUnitRange}},
1715
)
18-
return trivial_axis(style, a, axes_codomain, axes_domain)
16+
return throw(MethodError(trivial_axis, (style, side, a, axes_codomain, axes_domain)))
1917
end
2018
function trivial_axis(
21-
style::FusionStyle,
22-
::Val{:domain},
23-
a::AbstractArray,
19+
style::FusionStyle, ::Val{:domain}, a::AbstractArray,
2420
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
2521
axes_domain::Tuple{Vararg{AbstractUnitRange}},
2622
)
27-
return trivial_axis(style, a, axes_codomain, axes_domain)
23+
return trivial_axis(style, Val(:codomain), a, axes_codomain, axes_domain)
2824
end
2925
function trivial_axis(
30-
style::FusionStyle,
31-
a::AbstractArray,
26+
style::FusionStyle, a::AbstractArray,
3227
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
3328
axes_domain::Tuple{Vararg{AbstractUnitRange}},
3429
)
35-
return trivial_axis(style, a)
30+
return trivial_axis(style, Val(:codomain), a, axes_codomain, axes_domain)
3631
end
3732
function trivial_axis(style::FusionStyle, a::AbstractArray)
38-
return trivial_axis(ReshapeFusion(), a)
33+
return trivial_axis(style, a, (), ())
34+
end
35+
function trivial_axis(
36+
a::AbstractArray,
37+
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
38+
axes_domain::Tuple{Vararg{AbstractUnitRange}},
39+
)
40+
return trivial_axis(FusionStyle(a), a, axes_codomain, axes_domain)
41+
end
42+
function trivial_axis(side::Val, a::AbstractArray)
43+
return trivial_axis(FusionStyle(a), side, a)
44+
end
45+
function trivial_axis(a::AbstractArray)
46+
return trivial_axis(FusionStyle(a), a)
3947
end
4048

4149
# Tensor product two spaces (ranges) together based on a fusion style.
4250
function tensor_product_axis(
43-
style::FusionStyle, ::Val{:codomain}, r1::AbstractUnitRange, r2::AbstractUnitRange
51+
style::FusionStyle, side::Val{:codomain},
52+
r1::AbstractUnitRange, r2::AbstractUnitRange,
4453
)
45-
return tensor_product_axis(style, r1, r2)
54+
return throw(MethodError(tensor_product_axis, (style, side, r1, r2)))
4655
end
4756
function tensor_product_axis(
4857
style::FusionStyle, ::Val{:domain}, r1::AbstractUnitRange, r2::AbstractUnitRange
4958
)
50-
return tensor_product_axis(style, r1, r2)
59+
return tensor_product_axis(style, Val(:codomain), r1, r2)
5160
end
52-
function tensor_product_axis(::FusionStyle, r1::AbstractUnitRange, r2::AbstractUnitRange)
53-
return tensor_product_axis(ReshapeFusion(), r1, r2)
61+
function tensor_product_axis(
62+
style::FusionStyle, r1::AbstractUnitRange, r2::AbstractUnitRange
63+
)
64+
return tensor_product_axis(style, Val(:codomain), r1, r2)
5465
end
5566
function tensor_product_axis(side::Val, r1::AbstractUnitRange, r2::AbstractUnitRange)
5667
style = tensor_product_fusionstyle(r1, r2)
@@ -68,9 +79,7 @@ function tensor_product_fusionstyle(r1::AbstractUnitRange, r2::AbstractUnitRange
6879
end
6980

7081
function fused_axis(
71-
style::FusionStyle,
72-
side::Val{:codomain},
73-
a::AbstractArray,
82+
style::FusionStyle, side::Val{:codomain}, a::AbstractArray,
7483
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
7584
axes_domain::Tuple{Vararg{AbstractUnitRange}},
7685
)
@@ -80,9 +89,7 @@ function fused_axis(
8089
end
8190
end
8291
function fused_axis(
83-
style::FusionStyle,
84-
side::Val{:domain},
85-
a::AbstractArray,
92+
style::FusionStyle, side::Val{:domain}, a::AbstractArray,
8693
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
8794
axes_domain::Tuple{Vararg{AbstractUnitRange}},
8895
)
@@ -92,15 +99,21 @@ function fused_axis(
9299
end
93100
end
94101
function matricize_axes(
95-
style::FusionStyle,
96-
a::AbstractArray,
102+
style::FusionStyle, a::AbstractArray,
97103
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
98104
axes_domain::Tuple{Vararg{AbstractUnitRange}},
99105
)
100106
axis_codomain = fused_axis(style, Val(:codomain), a, axes_codomain, axes_domain)
101107
axis_domain = fused_axis(style, Val(:domain), a, axes_codomain, axes_domain)
102108
return axis_codomain, axis_domain
103109
end
110+
function matricize_axes(
111+
a::AbstractArray,
112+
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
113+
axes_domain::Tuple{Vararg{AbstractUnitRange}},
114+
)
115+
return matricize_axes(FusionStyle(a), a, axes_codomain, axes_domain)
116+
end
104117
function matricize_axes(style::FusionStyle, a::AbstractArray, ndims_codomain::Val)
105118
unval(ndims_codomain) ndims(a) ||
106119
throw(ArgumentError("Codomain length exceeds number of dimensions."))
@@ -136,15 +149,15 @@ end
136149
# matrix factorizations assume copy
137150
# maybe: copy=false kwarg
138151

139-
function matricize(a::AbstractArray, ndims_codomain::Val)
140-
return matricize(FusionStyle(a), a, ndims_codomain)
141-
end
142152
# This is the primary function that should be overloaded for new fusion styles.
143153
# This assumes the permutation was already performed.
144154
function matricize(
145155
style::FusionStyle, a::AbstractArray, ndims_codomain::Val
146156
)
147-
return matricize(ReshapeFusion(), a, ndims_codomain)
157+
return throw(MethodError(matricize, (style, a, ndims_codomain)))
158+
end
159+
function matricize(a::AbstractArray, ndims_codomain::Val)
160+
return matricize(FusionStyle(a), a, ndims_codomain)
148161
end
149162

150163
function matricize(
@@ -207,20 +220,20 @@ function matricize(
207220
end
208221

209222
# ==================================== unmatricize =======================================
223+
# This is the primary function that should be overloaded for new fusion styles.
210224
function unmatricize(
211-
m::AbstractMatrix,
225+
style::FusionStyle, m::AbstractMatrix,
212226
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
213227
axes_domain::Tuple{Vararg{AbstractUnitRange}},
214228
)
215-
return unmatricize(FusionStyle(m), m, axes_codomain, axes_domain)
229+
return throw(MethodError(unmatricize, (style, m, axes_codomain, axes_domain)))
216230
end
217-
# This is the primary function that should be overloaded for new fusion styles.
218231
function unmatricize(
219-
style::FusionStyle, m::AbstractMatrix,
232+
m::AbstractMatrix,
220233
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
221234
axes_domain::Tuple{Vararg{AbstractUnitRange}},
222235
)
223-
return unmatricize(ReshapeFusion(), m, axes_codomain, axes_domain)
236+
return unmatricize(FusionStyle(m), m, axes_codomain, axes_domain)
224237
end
225238

226239
function unmatricize(m::AbstractMatrix, blocked_axes::AbstractBlockTuple{2})
@@ -331,8 +344,17 @@ end
331344
# Defaults to ReshapeFusion, a simple reshape
332345
struct ReshapeFusion <: FusionStyle end
333346
FusionStyle(::Type{<:AbstractArray}) = ReshapeFusion()
334-
trivial_axis(::ReshapeFusion, a::AbstractArray) = Base.OneTo(1)
335-
function tensor_product_axis(::ReshapeFusion, r1::AbstractUnitRange, r2::AbstractUnitRange)
347+
function trivial_axis(
348+
style::ReshapeFusion, side::Val{:codomain}, a::AbstractArray,
349+
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
350+
axes_domain::Tuple{Vararg{AbstractUnitRange}},
351+
)
352+
return Base.OneTo(1)
353+
end
354+
function tensor_product_axis(
355+
style::ReshapeFusion, side::Val{:codomain},
356+
r1::AbstractUnitRange, r2::AbstractUnitRange,
357+
)
336358
(isone(first(r1)) && isone(first(r2))) ||
337359
throw(ArgumentError("Only one-based axes are supported"))
338360
return Base.OneTo(length(r1) * length(r2))

0 commit comments

Comments
 (0)