-
Notifications
You must be signed in to change notification settings - Fork 26
/
to_vec.jl
320 lines (273 loc) · 9.1 KB
/
to_vec.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
"""
    to_vec(x)

Transform `x` into a `Vector`, and return that vector together with a closure which
inverts the transformation, i.e. maps a vector of the same length back to a value shaped
like `x`.
"""
function to_vec(x::Real)
    # A scalar is stored as a one-element vector; the inverse reads that element back.
    Real_from_vec(x_vec) = first(x_vec)
    return [x], Real_from_vec
end
function to_vec(z::Complex)
    # The real and imaginary parts occupy two consecutive entries, in that order.
    Complex_from_vec(z_vec) = Complex(z_vec[1], z_vec[2])
    return [real(z), imag(z)], Complex_from_vec
end
# Base case: a `Vector` of reals is already in flattened form, so it is returned as-is
# (no copy is made) with `identity` as its inverse.
function to_vec(x::Vector{<:Real})
    return x, identity
end
# get around the constructors and make the type directly
# Note this is moderately evil accessing julia's internals
# `_force_construct(T, args...)` builds an instance of `T` from `args` by emitting a raw
# `:splatnew` / `:new` expression, so no inner or outer constructor of `T` is invoked.
# These expression heads are compiler internals, not public API.
if VERSION >= v"1.3"
# On Julia >= 1.3, `:splatnew` takes the whole argument tuple, so one expression
# serves every arity.
@generated function _force_construct(T, args...)
return Expr(:splatnew, :T, :args)
end
else
# On older Julia, spell out one `args[i]` per argument. Inside a generated function the
# number of arguments is known when the code is generated.
@generated function _force_construct(T, args...)
return Expr(:new, :T, Any[:(args[$i]) for i in 1:length(args)]...)
end
end
# Fallback method for `to_vec`. Won't always do what you wanted, but should be fine a decent
# chunk of the time.
#
# Recursively flattens every field of the struct `x` and concatenates the results. The
# inverse rebuilds each field and then calls `T(fields...)`. If `T` has no constructor
# accepting its fields positionally (a `MethodError`), the instance is built directly with
# `_force_construct`, bypassing constructors. Any other exception raised by a constructor
# (e.g. a failed invariant check) is propagated to the caller.
#
# Throws an `ErrorException` if `T` is not a struct type.
function to_vec(x::T) where {T}
    Base.isstructtype(T) || error("Expected a struct type")
    isempty(fieldnames(T)) && return (Bool[], _ -> x) # Singleton types
    field_vecs_and_backs = map(name -> to_vec(getfield(x, name)), fieldnames(T))
    field_vecs = first.(field_vecs_and_backs)
    field_backs = last.(field_vecs_and_backs)
    v, fields_from_vec = to_vec(field_vecs)
    function structtype_from_vec(x_vec::Vector{<:Real})
        field_vals = map((back, fv) -> back(fv), field_backs, fields_from_vec(x_vec))
        try
            return T(field_vals...)
        catch err
            # `catch MethodError` would bind a variable named `MethodError` and catch
            # everything; filter explicitly so only a missing constructor falls back.
            err isa MethodError || rethrow()
            return _force_construct(T, field_vals...)
        end
    end
    return v, structtype_from_vec
end
# Flatten a dense vector by flattening each element and concatenating the results.
# The inverse slices the flat vector back into per-element chunks and rebuilds each
# element with its own inverse, then converts the result to the type of `x`.
function to_vec(x::DenseVector)
    x_vecs_and_backs = map(to_vec, x)
    x_vecs, backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs)
    # End offset of each element's chunk in the flat vector. It depends only on `x`, so it
    # is computed once here instead of on every call of the inverse.
    sz = cumsum(map(length, x_vecs))
    function Vector_from_vec(flat)
        x_Vec = [backs[n](flat[sz[n] - length(x_vecs[n]) + 1:sz[n]]) for n in eachindex(x)]
        return oftype(x, x_Vec)
    end
    # handle empty x
    x_vec = isempty(x_vecs) ? eltype(eltype(x_vecs))[] : reduce(vcat, x_vecs)
    return x_vec, Vector_from_vec
end
# Flatten a multi-dimensional dense array through `vec(x)`; the inverse reshapes the
# rebuilt vector to `size(x)` and converts it to the type of `x`.
function to_vec(x::DenseArray)
    flat, vector_back = to_vec(vec(x))
    Array_from_vec(x_vec) = oftype(x, reshape(vector_back(x_vec), size(x)))
    return flat, Array_from_vec
end
# Some specific subtypes of AbstractArray.
# A 1-dimensional `ReshapedArray` is flattened via its parent. The inverse re-wraps the
# rebuilt parent using the original array's `dims` and `mi` fields.
function to_vec(x::Base.ReshapedArray{<:Any, 1})
    parent_vec, parent_back = to_vec(parent(x))
    ReshapedArray_from_vec(x_vec) = Base.ReshapedArray(parent_back(x_vec), x.dims, x.mi)
    return parent_vec, ReshapedArray_from_vec
end
# Returning a `SubArray` from `from_vec` would mean copying the view's `parent`, which
# doesn't seem particularly useful, so the view is converted into a copy instead. Something
# more performant may be possible, but this seems good for now.
function to_vec(x::Base.SubArray)
    return to_vec(copy(x))
end
# Triangular matrices are flattened densely via `Matrix(x)`. The inverse wraps the rebuilt
# matrix in the same triangular type `T`.
function to_vec(x::T) where {T<:LinearAlgebra.AbstractTriangular}
    dense_vec, dense_back = to_vec(Matrix(x))
    AbstractTriangular_from_vec(x_vec) = T(reshape(dense_back(x_vec), size(x)))
    return dense_vec, AbstractTriangular_from_vec
end
# Hermitian/Symmetric wrappers are flattened densely. The inverse rewraps the rebuilt
# matrix in `T`, reusing the original `uplo` flag.
function to_vec(x::T) where {T<:LinearAlgebra.HermOrSym}
    dense_vec, dense_back = to_vec(Matrix(x))
    HermOrSym_from_vec(x_vec) = T(dense_back(x_vec), x.uplo)
    return dense_vec, HermOrSym_from_vec
end
# Diagonal matrices are flattened densely (off-diagonal zeros included). The inverse uses
# `Diagonal(::AbstractMatrix)`, which keeps only the diagonal of the rebuilt matrix.
function to_vec(x::Diagonal)
    dense_vec, dense_back = to_vec(Matrix(x))
    Diagonal_from_vec(x_vec) = Diagonal(dense_back(x_vec))
    return dense_vec, Diagonal_from_vec
end
function to_vec(x::Tridiagonal)
    # Only the three bands are flattened. The other field (`du2`) is not part of the
    # matrix's value and is really a kind of cache, so it is left out.
    bands_vec, bands_back = to_vec((x.dl, x.d, x.du))
    Tridiagonal_from_vec(x_vec) = Tridiagonal(bands_back(x_vec)...)
    return bands_vec, Tridiagonal_from_vec
end
function to_vec(X::Transpose)
    dense_vec, dense_back = to_vec(Matrix(X))
    # `dense_back` yields the materialized transpose; permute it to recover the parent.
    Transpose_from_vec(x_vec) = Transpose(permutedims(dense_back(x_vec)))
    return dense_vec, Transpose_from_vec
end
function to_vec(x::Transpose{<:Any, <:AbstractVector})
    row_vec, row_back = to_vec(Matrix(x))
    # The materialized single-row matrix is flattened back to a vector before re-wrapping.
    function Transpose_from_vec(x_vec)
        return Transpose(vec(row_back(x_vec)))
    end
    return row_vec, Transpose_from_vec
end
function to_vec(X::Adjoint)
    dense_vec, dense_back = to_vec(Matrix(X))
    # Undo the materialized adjoint: permute back, then conjugate in place to get the parent.
    function Adjoint_from_vec(x_vec)
        parent_mat = permutedims(dense_back(x_vec))
        return Adjoint(conj!(parent_mat))
    end
    return dense_vec, Adjoint_from_vec
end
function to_vec(x::Adjoint{<:Any, <:AbstractVector})
    row_vec, row_back = to_vec(Matrix(x))
    # Flatten the rebuilt row to a vector, conjugate it in place, and re-wrap it.
    function Adjoint_from_vec(x_vec)
        parent_vec = vec(row_back(x_vec))
        return Adjoint(conj!(parent_vec))
    end
    return row_vec, Adjoint_from_vec
end
function to_vec(X::T) where {T<:PermutedDimsArray}
    parent_vec, parent_back = to_vec(parent(X))
    # Rebuild the parent, then construct the exact wrapper type `T` around it.
    PermutedDimsArray_from_vec(x_vec) = T(parent_back(x_vec))
    return parent_vec, PermutedDimsArray_from_vec
end
function to_vec(v::SparseVector)
    # Indices of the stored entries of `v`; only these are read back in the inverse.
    nz_inds = first(findnz(v))
    dims = size(v)
    dense_vec, dense_back = to_vec(collect(v))
    function SparseVector_from_vec(x_v)
        dense = dense_back(x_v)
        return sparsevec(nz_inds, dense[nz_inds], dims...)
    end
    return dense_vec, SparseVector_from_vec
end
function to_vec(m::SparseMatrixCSC)
    # Row/column positions of the stored entries; the sparsity pattern is preserved.
    rows, cols, _ = findnz(m)
    dims = size(m)
    dense_vec, dense_back = to_vec(collect(m))
    function SparseMatrixCSC_from_vec(x_v)
        dense = dense_back(x_v)
        stored_vals = [dense[r, c] for (r, c) in zip(rows, cols)]
        return sparse(rows, cols, stored_vals, dims...)
    end
    return dense_vec, SparseMatrixCSC_from_vec
end
# Factorizations
function to_vec(x::F) where {F <: SVD}
    # Store the singular values as an n×1 matrix so every component is a matrix; a
    # homogeneous vector of matrices keeps inference working.
    components = [x.U, reshape(x.S, length(x.S), 1), x.Vt]
    flat, components_back = to_vec(components)
    function SVD_from_vec(v)
        U, S_mat, Vt = components_back(v)
        return F(U, vec(S_mat), Vt)
    end
    return flat, SVD_from_vec
end
function to_vec(x::Cholesky)
    factors_vec, factors_back = to_vec(x.factors)
    # Only the factor matrix is flattened; `uplo` and `info` are reused from `x`.
    Cholesky_from_vec(v) = Cholesky(factors_back(v), x.uplo, x.info)
    return factors_vec, Cholesky_from_vec
end
# Flatten a compact-WY QR factorization (or its Q factor) as the pair `[x.factors, T]`,
# where `T` is a copy of `x.T` in which only the upper triangle of each column block is
# kept and every other entry is zero. The inverse rebuilds the same concrete type `S`.
function to_vec(x::S) where {U, S <: Union{LinearAlgebra.QRCompactWYQ{U}, LinearAlgebra.QRCompactWY{U}}}
# x.T is composed of upper triangular blocks. The subdiagonals elements
# of the blocks are arbitrary. We make sure to set all of them to zero
# to avoid NaN.
blocksize, cols = size(x.T)
T = zeros(U, blocksize, cols)
# Visit column blocks of width `blocksize`; the last block may be narrower (width `n`).
for i in 0:div(cols - 1, blocksize)
used_cols = i * blocksize
n = min(blocksize, cols - used_cols)
# Copy only the upper triangle of this n×n block into the zero-initialized `T`.
T[1:n, (1:n) .+ used_cols] = UpperTriangular(view(x.T, 1:n, (1:n) .+ used_cols))
end
x_vec, back = to_vec([x.factors, T])
function QRCompact_from_vec(v)
factors, Tback = back(v)
# Construct the same concrete type `S` from the rebuilt factors and `T`.
return S(factors, Tback)
end
return x_vec, QRCompact_from_vec
end
# Non-array data structures
function to_vec(::Tuple{})
    # The empty tuple carries no data: it flattens to an empty `Bool[]` and is rebuilt
    # as `()` regardless of the input.
    Tuple_from_vec(_) = ()
    return Bool[], Tuple_from_vec
end
# Flatten each tuple element and concatenate the pieces; the inverse cuts the flat vector
# back into per-element chunks and rebuilds each element with its own inverse.
function to_vec(x::Tuple)
    x_vecs_and_backs = map(to_vec, x)
    x_vecs = first.(x_vecs_and_backs)
    x_backs = last.(x_vecs_and_backs)
    lengths = map(length, x_vecs)
    # Running end offset of each chunk, kept as a tuple of the same type as `lengths`.
    ends = typeof(lengths)(cumsum(collect(lengths)))
    function Tuple_from_vec(v)
        return map(x_backs, lengths, ends) do x_back, len, stop
            x_back(v[stop - len + 1:stop])
        end
    end
    return reduce(vcat, x_vecs), Tuple_from_vec
end
function to_vec(x::NamedTuple)
    # Flatten the values as a plain tuple; the inverse rebuilds via `typeof(x)`, which
    # carries the field names.
    values_vec, values_back = to_vec(values(x))
    NamedTuple_from_vec(v) = typeof(x)(values_back(v))
    return values_vec, NamedTuple_from_vec
end
# Convert to a vector-of-vectors to make use of existing functionality.
function to_vec(d::Dict)
    vals_vec, vals_back = to_vec(collect(values(d)))
    function Dict_from_vec(v)
        new_vals = vals_back(v)
        # Pair keys with values by position. This assumes `keys(d)` iterates in the same
        # order as `values(d)` did above, which holds for an unmodified `Dict`.
        return Dict(k => new_vals[i] for (i, k) in enumerate(keys(d)))
    end
    return vals_vec, Dict_from_vec
end
# Non-perturbable types: values of these types flatten to an empty `Bool[]`, and the
# generated inverse (named e.g. `DataType_from_vec`) returns the original value unchanged.
for type_sym in (:DataType, :CartesianIndex, :AbstractZero)
    back_name = Symbol(type_sym, :_from_vec)
    @eval function FiniteDifferences.to_vec(x::$type_sym)
        $back_name(x_vec::Vector) = x
        return Bool[], $back_name
    end
end
# ChainRulesCore Differentials
function FiniteDifferences.to_vec(x::Tangent{P}) where {P}
    # To be safe, canonicalize first so every field is filled in and in primal order.
    x_canon = canonicalize(x)
    backing_vec, backing_back = FiniteDifferences.to_vec(ChainRulesCore.backing(x_canon))
    function Tangent_from_vec(y_vec)
        y_backing = backing_back(y_vec)
        return Tangent{P, typeof(y_backing)}(y_backing)
    end
    return backing_vec, Tangent_from_vec
end
# A thunk is flattened through its unthunked value; the inverse defers reconstruction
# inside a new thunk.
function FiniteDifferences.to_vec(t::Thunk)
    v, back = to_vec(unthunk(t))
    function Thunk_from_vec(x_vec)
        return @thunk(back(x_vec))
    end
    return v, Thunk_from_vec
end
# An `InplaceableThunk` is flattened through its unthunked value. The inverse builds a new
# `InplaceableThunk` whose accumulation function and lazy value both use the value
# rebuilt from the given vector.
function FiniteDifferences.to_vec(t::InplaceableThunk)
    v, back = to_vec(unthunk(t))
    function InplaceableThunk_from_vec(x_vec)
        # The accumulation closure must use the inverse's argument. It previously
        # referenced an undefined `b` and raised `UndefVarError` when invoked. It returns
        # `Δ` plus the rebuilt value, as `Δ += ...` did, without mutating `Δ`.
        return InplaceableThunk(
            Δ -> Δ + back(x_vec),
            @thunk(back(x_vec))
        )
    end
    return v, InplaceableThunk_from_vec
end