Skip to content

Commit 2bb9067

Browse files
committed
add view transformations
1 parent 4795704 commit 2bb9067

File tree

2 files changed

+61
-0
lines changed

2 files changed

+61
-0
lines changed

src/aggregation.jl

+47
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,53 @@ function _domain_label(transformation::ArrayTransformation, index::Int)
106106
_array_domain_label(inner_transformation, dims, index)
107107
end
108108

109+
####
110+
#### array view
111+
####
112+
113+
"""
114+
$(TYPEDEF)
115+
116+
View of an array with `dims`.
117+
118+
!!! note
119+
This feature is experimental, and not part of the stable API; it may disappear or change without
120+
relevant changes in SemVer or deprecations. Inner transformations are not supported.
121+
"""
122+
struct ViewTransformation{M} <: VectorTransform
123+
dims::NTuple{M, Int}
124+
end
125+
126+
function as(::typeof(view), dims::Tuple{Vararg{Int}})
127+
@argcheck all(d -> d 0, dims) "All dimensions need to be non-negative."
128+
ViewTransformation(dims)
129+
end
130+
131+
as(::typeof(view), dims::Int...) = as(view, dims)
132+
133+
dimension(transformation::ViewTransformation) = prod(transformation.dims)
134+
135+
function transform_with(flag::LogJacFlag, t::ViewTransformation, x, index)
136+
index′ = index + dimension(t)
137+
y = reshape(@view(x[index:(index′-1)]), t.dims)
138+
y, logjac_zero(flag, robust_eltype(x)), index′
139+
end
140+
141+
function _domain_label(transformation::ViewTransformation, index::Int)
142+
@unpack dims = transformation
143+
_array_domain_label(asℝ, dims, index)
144+
end
145+
146+
inverse_eltype(transformation::ViewTransformation, y) = eltype(y)
147+
148+
function inverse_at!(x::AbstractVector, index, transformation::ViewTransformation,
149+
y::AbstractArray)
150+
@argcheck size(y) == transformation.dims
151+
index′ = index + dimension(transformation)
152+
copy!(@view(x[index:(index′-1)]), vec(y))
153+
index′
154+
end
155+
109156
####
110157
#### static array
111158
####

test/runtests.jl

+14
Original file line numberDiff line numberDiff line change
@@ -644,3 +644,17 @@ end
644644
@testset "static arrays inference" begin
645645
@test @inferred transform_with(NOLOGJAC, as(SVector{3, Float64}), zeros(3), 1) == (SVector(0.0, 0.0, 0.0), NOLOGJAC, 4)
646646
end
647+
648+
@testset "view transformations" begin
649+
x = randn(10)
650+
t = as((a = asℝ, b = as(view, 2, 4), c = asℝ))
651+
y, lj = transform_and_logjac(t, x)
652+
@test typeof(y.b) <: AbstractMatrix
653+
@test size(y.b) == (2, 4)
654+
# test inverse
655+
@test inverse(t, y) == x
656+
# test that it is a view
657+
z = y.b[3]
658+
y.b[3] += 1
659+
@test x[4] == z + 1
660+
end

0 commit comments

Comments
 (0)