@@ -9,48 +9,59 @@ FusionStyle(T::Type) = throw(MethodError(FusionStyle, (T,)))
99
1010# ======================================= misc ========================================
1111function trivial_axis (
12- style:: FusionStyle ,
13- :: Val{:codomain} ,
14- a:: AbstractArray ,
12+ style:: FusionStyle , side:: Val{:codomain} , a:: AbstractArray ,
1513 axes_codomain:: Tuple{Vararg{AbstractUnitRange}} ,
1614 axes_domain:: Tuple{Vararg{AbstractUnitRange}} ,
1715 )
18- return trivial_axis (style, a, axes_codomain, axes_domain)
16+ return throw ( MethodError ( trivial_axis, (style, side, a, axes_codomain, axes_domain)) )
1917end
2018function trivial_axis (
21- style:: FusionStyle ,
22- :: Val{:domain} ,
23- a:: AbstractArray ,
19+ style:: FusionStyle , :: Val{:domain} , a:: AbstractArray ,
2420 axes_codomain:: Tuple{Vararg{AbstractUnitRange}} ,
2521 axes_domain:: Tuple{Vararg{AbstractUnitRange}} ,
2622 )
27- return trivial_axis (style, a, axes_codomain, axes_domain)
23+ return trivial_axis (style, Val ( :codomain ), a, axes_codomain, axes_domain)
2824end
2925function trivial_axis (
30- style:: FusionStyle ,
31- a:: AbstractArray ,
26+ style:: FusionStyle , a:: AbstractArray ,
3227 axes_codomain:: Tuple{Vararg{AbstractUnitRange}} ,
3328 axes_domain:: Tuple{Vararg{AbstractUnitRange}} ,
3429 )
35- return trivial_axis (style, a )
30+ return trivial_axis (style, Val ( :codomain ), a, axes_codomain, axes_domain )
3631end
3732function trivial_axis (style:: FusionStyle , a:: AbstractArray )
38- return trivial_axis (ReshapeFusion (), a)
33+ return trivial_axis (style, a, (), ())
34+ end
35+ function trivial_axis (
36+ a:: AbstractArray ,
37+ axes_codomain:: Tuple{Vararg{AbstractUnitRange}} ,
38+ axes_domain:: Tuple{Vararg{AbstractUnitRange}} ,
39+ )
40+ return trivial_axis (FusionStyle (a), a, axes_codomain, axes_domain)
41+ end
42+ function trivial_axis (side:: Val , a:: AbstractArray )
43+ return trivial_axis (FusionStyle (a), side, a)
44+ end
45+ function trivial_axis (a:: AbstractArray )
46+ return trivial_axis (FusionStyle (a), a)
3947end
4048
4149# Tensor product two spaces (ranges) together based on a fusion style.
4250function tensor_product_axis (
43- style:: FusionStyle , :: Val{:codomain} , r1:: AbstractUnitRange , r2:: AbstractUnitRange
51+ style:: FusionStyle , side:: Val{:codomain} ,
52+ r1:: AbstractUnitRange , r2:: AbstractUnitRange ,
4453 )
45- return tensor_product_axis (style, r1, r2)
54+ return throw ( MethodError ( tensor_product_axis, (style, side, r1, r2)) )
4655end
4756function tensor_product_axis (
4857 style:: FusionStyle , :: Val{:domain} , r1:: AbstractUnitRange , r2:: AbstractUnitRange
4958 )
50- return tensor_product_axis (style, r1, r2)
59+ return tensor_product_axis (style, Val ( :codomain ), r1, r2)
5160end
52- function tensor_product_axis (:: FusionStyle , r1:: AbstractUnitRange , r2:: AbstractUnitRange )
53- return tensor_product_axis (ReshapeFusion (), r1, r2)
61+ function tensor_product_axis (
62+ style:: FusionStyle , r1:: AbstractUnitRange , r2:: AbstractUnitRange
63+ )
64+ return tensor_product_axis (style, Val (:codomain ), r1, r2)
5465end
5566function tensor_product_axis (side:: Val , r1:: AbstractUnitRange , r2:: AbstractUnitRange )
5667 style = tensor_product_fusionstyle (r1, r2)
@@ -68,9 +79,7 @@ function tensor_product_fusionstyle(r1::AbstractUnitRange, r2::AbstractUnitRange
6879end
6980
7081function fused_axis (
71- style:: FusionStyle ,
72- side:: Val{:codomain} ,
73- a:: AbstractArray ,
82+ style:: FusionStyle , side:: Val{:codomain} , a:: AbstractArray ,
7483 axes_codomain:: Tuple{Vararg{AbstractUnitRange}} ,
7584 axes_domain:: Tuple{Vararg{AbstractUnitRange}} ,
7685 )
@@ -80,9 +89,7 @@ function fused_axis(
8089 end
8190end
8291function fused_axis (
83- style:: FusionStyle ,
84- side:: Val{:domain} ,
85- a:: AbstractArray ,
92+ style:: FusionStyle , side:: Val{:domain} , a:: AbstractArray ,
8693 axes_codomain:: Tuple{Vararg{AbstractUnitRange}} ,
8794 axes_domain:: Tuple{Vararg{AbstractUnitRange}} ,
8895 )
@@ -92,15 +99,21 @@ function fused_axis(
9299 end
93100end
94101function matricize_axes (
95- style:: FusionStyle ,
96- a:: AbstractArray ,
102+ style:: FusionStyle , a:: AbstractArray ,
97103 axes_codomain:: Tuple{Vararg{AbstractUnitRange}} ,
98104 axes_domain:: Tuple{Vararg{AbstractUnitRange}} ,
99105 )
100106 axis_codomain = fused_axis (style, Val (:codomain ), a, axes_codomain, axes_domain)
101107 axis_domain = fused_axis (style, Val (:domain ), a, axes_codomain, axes_domain)
102108 return axis_codomain, axis_domain
103109end
110+ function matricize_axes (
111+ a:: AbstractArray ,
112+ axes_codomain:: Tuple{Vararg{AbstractUnitRange}} ,
113+ axes_domain:: Tuple{Vararg{AbstractUnitRange}} ,
114+ )
115+ return matricize_axes (FusionStyle (a), a, axes_codomain, axes_domain)
116+ end
104117function matricize_axes (style:: FusionStyle , a:: AbstractArray , ndims_codomain:: Val )
105118 unval (ndims_codomain) ≤ ndims (a) ||
106119 throw (ArgumentError (" Codomain length exceeds number of dimensions." ))
@@ -136,15 +149,15 @@ end
136149# matrix factorizations assume copy
137150# maybe: copy=false kwarg
138151
139- function matricize (a:: AbstractArray , ndims_codomain:: Val )
140- return matricize (FusionStyle (a), a, ndims_codomain)
141- end
142152# This is the primary function that should be overloaded for new fusion styles.
143153# This assumes the permutation was already performed.
144154function matricize (
145155 style:: FusionStyle , a:: AbstractArray , ndims_codomain:: Val
146156 )
147- return matricize (ReshapeFusion (), a, ndims_codomain)
157+ return throw (MethodError (matricize, (style, a, ndims_codomain)))
158+ end
159+ function matricize (a:: AbstractArray , ndims_codomain:: Val )
160+ return matricize (FusionStyle (a), a, ndims_codomain)
148161end
149162
150163function matricize (
@@ -207,20 +220,20 @@ function matricize(
207220end
208221
209222# ==================================== unmatricize =======================================
223+ # This is the primary function that should be overloaded for new fusion styles.
210224function unmatricize (
211- m:: AbstractMatrix ,
225+ style :: FusionStyle , m:: AbstractMatrix ,
212226 axes_codomain:: Tuple{Vararg{AbstractUnitRange}} ,
213227 axes_domain:: Tuple{Vararg{AbstractUnitRange}} ,
214228 )
215- return unmatricize ( FusionStyle (m), m, axes_codomain, axes_domain)
229+ return throw ( MethodError (unmatricize, (style, m, axes_codomain, axes_domain)) )
216230end
217- # This is the primary function that should be overloaded for new fusion styles.
218231function unmatricize (
219- style :: FusionStyle , m:: AbstractMatrix ,
232+ m:: AbstractMatrix ,
220233 axes_codomain:: Tuple{Vararg{AbstractUnitRange}} ,
221234 axes_domain:: Tuple{Vararg{AbstractUnitRange}} ,
222235 )
223- return unmatricize (ReshapeFusion ( ), m, axes_codomain, axes_domain)
236+ return unmatricize (FusionStyle (m ), m, axes_codomain, axes_domain)
224237end
225238
226239function unmatricize (m:: AbstractMatrix , blocked_axes:: AbstractBlockTuple{2} )
331344# Defaults to ReshapeFusion, a simple reshape
332345struct ReshapeFusion <: FusionStyle end
333346FusionStyle (:: Type{<:AbstractArray} ) = ReshapeFusion ()
334- trivial_axis (:: ReshapeFusion , a:: AbstractArray ) = Base. OneTo (1 )
335- function tensor_product_axis (:: ReshapeFusion , r1:: AbstractUnitRange , r2:: AbstractUnitRange )
347+ function trivial_axis (
348+ style:: ReshapeFusion , side:: Val{:codomain} , a:: AbstractArray ,
349+ axes_codomain:: Tuple{Vararg{AbstractUnitRange}} ,
350+ axes_domain:: Tuple{Vararg{AbstractUnitRange}} ,
351+ )
352+ return Base. OneTo (1 )
353+ end
354+ function tensor_product_axis (
355+ style:: ReshapeFusion , side:: Val{:codomain} ,
356+ r1:: AbstractUnitRange , r2:: AbstractUnitRange ,
357+ )
336358 (isone (first (r1)) && isone (first (r2))) ||
337359 throw (ArgumentError (" Only one-based axes are supported" ))
338360 return Base. OneTo (length (r1) * length (r2))
0 commit comments