Skip to content

Commit

Permalink
Merge pull request #44 from bionanoimaging/czt_cuda
Browse files Browse the repository at this point in the history
addec CUDA.jl support for czt
  • Loading branch information
roflmaostc authored Jul 21, 2024
2 parents a3f96fc + 271b6af commit ce67ccb
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 62 deletions.
139 changes: 82 additions & 57 deletions src/czt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
export czt, iczt, plan_czt

"""
get_kernel_1d(RT::Type, N::Integer, M::Integer; a= 1.0, w = cispi(-2/N), extra_phase=0.0, global_phase=0.0)
get_kernel_1d(arr::AbstractArray{T,D}, N::Integer, M::Integer; a= 1.0, w = cispi(-2/N), extra_phase=0.0, global_phase=0.0) where {T,D}
calculates the kernel for the Bluestein algorithm. Note that the length depends on the destination size.
Note the the resulting kernel-size is computed based on the minimum required length for the task.
Expand All @@ -21,12 +21,16 @@ The code is based on Rabiner, Schafer & Rader 1969, IEEE Trans. on Audio and El
returns: a tuple of three arrays for the initial multiplication (A*W), the convolution
(already fourier-transformed) and the post multiplication.
"""
function get_kernel_1d(RT::Type, N::Integer, M::Integer; a= 1.0, w = cispi(-2/N), extra_phase=0.0, global_phase=0.0)
# intorduce also sscale ??
# the size needed to avoid wrap
function get_kernel_1d(arr::AT, N::Integer, M::Integer; a= 1.0, w = cispi(-2/N), extra_phase=0.0, global_phase=0.0) where {T,D, AT <: AbstractArray{T,D}}
# the size is needed to avoid wrap
RT = real(T)
CT = (RT <: Real) ? Complex{RT} : RT
RT = real(CT)
# nowrap_size = N + ceil(N÷2)

# converts ShiftedArrays.CircShiftedArray into a plain array type:
tmp = similar(arr, RT, (1,))
RAT = real_arr_type(typeof(tmp), Val(1))

# the maximal size where the convolution does not yield zero
# max_size = 2*N-1
# the minimum size needed for the convolution
Expand All @@ -35,13 +39,15 @@ function get_kernel_1d(RT::Type, N::Integer, M::Integer; a= 1.0, w = cispi(-2/N)

#Note that the source size ssz is used here
# W of the source for FFTs.
n = (0:N-1)
n = RAT(0:N-1)
# pre-calculate the product of a.^(-n) and w to be later multiplied with the input x
# late casting is important, since the rounding errors are overly large if the original calculations are done in Float32.
aw = CT.((a .^ (-n)) .* w .^ ((n .^ 2) ./ 2))

conv_kernel = zeros(CT, L) # Array{CT}(undef, L)
m = (0:M-1)
conv_kernel = similar(arr, CT, L) # Array{CT}(undef, L)
fill!(conv_kernel, zero(CT))

m = RAT(0:M-1)
conv_kernel[1:M] .= w .^ (-(m .^ 2) ./ 2)
right_start = L-N+1
n = (1:N-1)
Expand All @@ -59,7 +65,7 @@ end

# type for planning. The arrays are 1D but oriented
"""
CZTPlan_1D{CT, D} # <: AbstractArray{T,D}
CZTPlan_1D{CT<:Complex, D<:Integer, AT<:AbstractArray{CT, D}, PT<:Number, PFFT<:AbstractFFTs.Plan, PIFFT<:AbstractFFTs.ScaledPlan}
type used for the onedimensional plan of the chirp Z transformation (CZT).
containing
Expand All @@ -70,20 +76,18 @@ containing
`aw`: factor to multiply input with
`fft_fv`: fourier-transform (FFTW) of the convolutio kernel
`wd`: factor to multiply the result of the convolution by
`fftw_plan`: plan for the forward FFTW of the convolution kernel
`ifftw_plan`: plan for the inverse FFTW of the convolution kernel
`fftw_plan!`: plan for the forward FFTW of the convolution kernel
`ifftw_plan!`: plan for the inverse FFTW of the convolution kernel
"""
struct CZTPlan_1D{CT, PT, D} # <: AbstractArray{T,D}
struct CZTPlan_1D{CT<:Complex, AT<:AbstractArray{CT}, PT<:Number, PFFT<:AbstractFFTs.Plan, PIFFT<:AbstractFFTs.ScaledPlan}
d :: Int
pad_value :: PT
pad_ranges :: NTuple{2,UnitRange{Int64}}
aw :: Array{CT, D}
fft_fv :: Array{CT, D}
wd :: Array{CT, D}
fftw_plan :: FFTW.cFFTWPlan
ifftw_plan :: AbstractFFTs.ScaledPlan
# dimension of this transformation
# as :: Array{T, D} # not needed since it is just the conjugate of ws
pad_ranges :: NTuple{2, UnitRange{Int64}}
aw :: AT
fft_fv :: AT
wd :: AT
fftw_plan! :: PFFT
ifftw_plan! :: PIFFT
end

"""
Expand All @@ -94,8 +98,8 @@ containing
# Members:
`plans`: vector of CZTPlan_1D for each of the directions of the ND array to transform
"""
struct CZTPlan_ND{CT, PT, D} # <: AbstractArray{T,D}
plans :: Vector{CZTPlan_1D{CT,PT, D}}
struct CZTPlan_ND{CT<:Complex, AT<:AbstractArray{CT}, PT<:Number, PFFT<:AbstractFFTs.Plan, PIFFT<:AbstractFFTs.ScaledPlan}
plans :: Vector{CZTPlan_1D{CT, AT, PT, PFFT, PIFFT}}
end

function get_invalid_ranges(sz, scaled, dsize, dst_center)
Expand Down Expand Up @@ -138,8 +142,8 @@ end
creates a plan for an one-dimensional chirp z-transformation (CZT). The generated plan is then applied via
muliplication. For details about the arguments, see `czt_1d()`.
"""
function plan_czt_1d(xin, scaled, d, dsize=size(xin,d); a=nothing, w=nothing, extra_phase=nothing, global_phase=nothing, damp=1.0, src_center=(size(xin,d)+1)/2,
dst_center=dsize÷2+1, remove_wrap=false, pad_value=zero(eltype(xin)), fft_flags=FFTW.ESTIMATE)
function plan_czt_1d(xin::AT, scaled, d, dsize=size(xin,d); a=nothing, w=nothing, extra_phase=nothing, global_phase=nothing, damp=1.0, src_center=(size(xin,d)+1)/2,
dst_center=dsize÷2+1, remove_wrap=false, pad_value=zero(eltype(xin)), fft_flags=FFTW.ESTIMATE) where {AT}

a = isnothing(a) ? exp(-1im*(dst_center-1)*2pi/(scaled*size(xin,d))) : a
w = isnothing(w) ? cispi(-2/(scaled*size(xin,d))) : w
Expand All @@ -148,23 +152,26 @@ function plan_czt_1d(xin, scaled, d, dsize=size(xin,d); a=nothing, w=nothing, ex
extra_phase = isnothing(extra_phase) ? exp(1im*2pi*(src_center-1)/(scaled*size(xin,d))) : extra_phase
global_phase = isnothing(global_phase) ? a ^ (src_center-1) : global_phase

aw, fft_fv, wd = get_kernel_1d(eltype(xin), size(xin, d), dsize; a=a, w=w, extra_phase=extra_phase, global_phase=global_phase)
aw, fft_fv, wd = get_kernel_1d(xin, size(xin, d), dsize; a=a, w=w, extra_phase=extra_phase, global_phase=global_phase)

# set pad ranges to empty ranges:
start_range = 1:0
end_range = 1:0
stop_range = 1:0

if remove_wrap
start_range, stop_range = get_invalid_ranges(size(xin, d), scaled, dsize, dst_center)

wd[start_range] .= zero(eltype(wd))
wd[stop_range] .= zero(eltype(wd))
end

nsz = ntuple((dd) -> (d==dd) ? size(fft_fv, 1) : size(xin, dd), Val(ndims(xin)))
y = Array{eltype(aw), ndims(xin)}(undef, nsz)
y = similar(xin, eltype(aw), nsz)

fft_p = plan_fft(y, (d,); flags=fft_flags)
ifft_p = plan_ifft(y, (d,); flags=fft_flags) # inv(fft_p)
fft_p! = (typeof(y) <: Array) ? plan_fft!(y, (d,); flags=fft_flags) : plan_fft!(y, (d,))
ifft_p! = (typeof(y) <: Array) ? plan_ifft!(y, (d,); flags=fft_flags) : plan_ifft!(y, (d,))

plan = CZTPlan_1D(d, pad_value, (start_range, end_range), reorient(aw, d, Val(ndims(xin))), reorient(fft_fv, d, Val(ndims(xin))), reorient(wd, d, Val(ndims(xin))), fft_p, ifft_p)
plan = CZTPlan_1D(d, pad_value, (start_range, stop_range), reorient(aw, d, Val(ndims(xin))), reorient(fft_fv, d, Val(ndims(xin))), reorient(wd, d, Val(ndims(xin))), fft_p!, ifft_p!)
return plan
end

Expand All @@ -175,21 +182,30 @@ end
creates a plan for an N-dimensional chirp z-transformation (CZT). The generated plan is then applied via
muliplication. For details about the arguments, see `czt()`.
"""
function plan_czt(xin, scale, dims, dsize=size(xin); a=nothing, w=nothing, damp=ones(ndims(xin)), src_center=size(xin)2 .+1, dst_center=dsize2 .+1, remove_wrap=false, pad_value=zero(eltype(xin)), fft_flags=FFTW.ESTIMATE)
function plan_czt(xin::AbstractArray{U,D}, scale, dims, dsize=size(xin); a=nothing, w=nothing, damp=ones(ndims(xin)),
src_center=size(xin)2 .+1, dst_center=dsize2 .+1, remove_wrap=false, pad_value=zero(eltype(xin)), fft_flags=FFTW.ESTIMATE) where {U,D}
CT = (eltype(xin) <: Real) ? Complex{eltype(xin)} : eltype(xin)
D = ndims(xin)
plans = [] # Vector{CZT1DPlan{CT,D}}
sz = size(xin)
for d in dims
xin = Array{eltype(xin)}(undef, sz)

d = dims[1]
p = plan_czt_1d(xin, scale[d], d, dsize[d]; a=a, w=w, damp=damp[d], src_center=src_center[d], dst_center=dst_center[d], remove_wrap=remove_wrap, pad_value=pad_value, fft_flags=fft_flags)
plans = Vector{typeof(p)}(undef, length(dims))
sz = ntuple((dd)-> (dd==d) ? dsize[d] : sz[dd], ndims(xin))
n=1
plans[n]=p
n+=1
for d in dims[2:end]
xin = Array{eltype(xin)}(undef, sz)
p = plan_czt_1d(xin, scale[d], d, dsize[d]; a=a, w=w, damp=damp[d], src_center=src_center[d], dst_center=dst_center[d], remove_wrap=remove_wrap, pad_value=pad_value, fft_flags=fft_flags)
sz = ntuple((dd)-> (dd==d) ? dsize[d] : sz[dd], ndims(xin))
push!(plans, p)
plans[n]=p
n += 1
end
return CZTPlan_ND{CT, typeof(pad_value),D}(plans)
return CZTPlan_ND(plans)
end

function Base.:*(p::CZTPlan_ND, xin::AbstractArray{U,D}; kargs...) where {U,D} # Complex{U}
function Base.:*(p::CZTPlan_ND, xin::AbstractArray{U,D}; kargs...)::AbstractArray{complex(U),D} where {U,D}
xout = xin
for pd in p.plans
xout = czt_1d(xout, pd)
Expand Down Expand Up @@ -230,13 +246,13 @@ The code is based on Rabiner, Schafer & Rader 1969, IEEE Trans. on Audio and El
+ `remove_wrap`: if true, the positions that represent a wrap-around will be set to zero
+ `pad_value`: the value to pad wrapped data with.
"""
function czt_1d(xin, scaled, d, dsize=size(xin,d); a=nothing, w=nothing, damp=1.0, src_center=size(xin,d)÷2+1,
dst_center=dsize÷2+1, extra_phase=nothing, global_phase=nothing, remove_wrap=false, pad_value=zero(eltype(xin)), fft_flags=FFTW.ESTIMATE)
function czt_1d(xin::AbstractArray{U,D}, scaled, d, dsize=size(xin,d); a=nothing, w=nothing, damp=1.0, src_center=size(xin,d)÷2+1,
dst_center=dsize÷2+1, extra_phase=nothing, global_phase=nothing, remove_wrap=false, pad_value=zero(U), fft_flags=FFTW.ESTIMATE)::AbstractArray{complex(U), D} where {U,D}
plan = plan_czt_1d(xin, scaled, d, dsize; a=a, w=w, extra_phase=extra_phase, global_phase=global_phase, damp, src_center=src_center, dst_center=dst_center, remove_wrap=remove_wrap, pad_value=pad_value, fft_flags=fft_flags);
return plan * xin
end

function Base.:*(p::CZTPlan_1D, xin::AbstractArray{U,D}; kargs...) where {U,D} # Complex{U}
function Base.:*(p::CZTPlan_1D, xin::AbstractArray{U,D}; kargs...)::AbstractArray{complex(U), D} where {U,D} # Complex{U}
return czt_1d(xin, p)
end

Expand All @@ -258,7 +274,7 @@ The code is based on Rabiner, Schafer & Rader 1969, IEEE Trans. on Audio and El
# Arguments
`plan`: A plan created via plan_czt_1d()
"""
function czt_1d(xin, plan::CZTPlan_1D)
function czt_1d(xin::AbstractArray{U,D}, plan::CZTPlan_1D)::AbstractArray{complex(U), D} where {U,D}
# destination position
# cispi(-1/scaled * half_pix_shift)
#
Expand All @@ -267,25 +283,30 @@ function czt_1d(xin, plan::CZTPlan_1D)
# which (intentionally) leads to non-real results for even-sized arrays at non-unit zoom

L = size(plan.fft_fv, plan.d)
nsz = ntuple((dd) -> (dd==plan.d) ? L : size(xin, dd), Val(ndims(xin)))
nsz = ntuple((dd) -> (dd==plan.d) ? L : size(xin, dd), Val(D))
# append zeros
y = zeros(eltype(plan.aw), nsz)
myrange = ntuple((dd) -> (1:size(xin,dd)), Val(ndims(xin)))
y[myrange...] = xin .* plan.aw
# corner = ntuple((x)->1, Val(ndims(xin)))
# select_region(xin .* plan.aw, new_size=nsz, center=corner, dst_center=corner)

# g = ifft(fft(y, d) .* plan.fft_fv, d)
g = plan.ifftw_plan * (plan.fftw_plan * y .* plan.fft_fv)
tmp = eltype(plan.aw).(xin .* plan.aw)

corner = ntuple((x)->1, Val(D))
y = NDTools.select_region(tmp, nsz; center=corner, dst_center=corner)

# in-place application to y:
plan.fftw_plan! * y
y .*= plan.fft_fv
# in-place application to y:
plan.ifftw_plan! * y

# dsz = ntuple((dd) -> (d==dd) ? dsize : size(xin), Val(ndims(xin)))
# return only the wanted (valid) part
myrange = ntuple((dd) -> (dd==plan.d) ? (1:size(plan.wd,plan.d)) : (1:size(xin, dd)), Val(ndims(xin)))
res = g[myrange...] .* plan.wd
myrange = ntuple((dd) -> (dd==plan.d) ? (1:size(plan.wd, plan.d)) : (1:size(xin, dd)), Val(D))
res = y[myrange...] .* plan.wd
# pad_value=0 means that it is either already handled by plan.wd or no padding is wanted.
if plan.pad_value != 0
myrange = ntuple((dd) -> (dd==plan.d) ? plan.pad_ranges[1] : Colon(), Val(ndims(xin)))
# first the start_range (plan.pad_ranges[1]):
myrange = ntuple((dd) -> (dd==plan.d) ? plan.pad_ranges[1] : Colon(), Val(D))
res[myrange...] .= plan.pad_value
myrange = ntuple((dd) -> (dd==plan.d) ? plan.pad_ranges[2] : Colon(), Val(ndims(xin)))
# first the stop_range (plan.pad_ranges[2]):
myrange = ntuple((dd) -> (dd==plan.d) ? plan.pad_ranges[2] : Colon(), Val(D))
res[myrange...] .= plan.pad_value
end
return res
Expand Down Expand Up @@ -355,18 +376,22 @@ julia> zoomed = real.(ift(xft))
0.0239759 -0.028264 0.0541186 -0.0116475 -0.261294 0.312719 -0.261294 -0.0116475 0.0541186 -0.028264
```
"""
function czt(xin::AbstractArray{T,N}, scale, dims=1:ndims(xin), dsize=size(xin);
a=nothing, w=nothing, damp=ones(ndims(xin)), src_center=size(xin)2 .+1, dst_center=dsize2 .+1, remove_wrap=false, pad_value=zero(eltype(xin)), fft_flags=FFTW.ESTIMATE)::AbstractArray{complex(T),N} where {T,N}
function czt(xin::AbstractArray{T,D}, scale, dims=1:D, dsize=size(xin);
a=nothing, w=nothing, damp=ones(D), src_center=size(xin)2 .+1, dst_center=dsize2 .+1,
remove_wrap=false, pad_value=zero(T), fft_flags=FFTW.ESTIMATE)::AbstractArray{complex(T), D} where {T,D}
xout = xin
if length(scale) != ndims(xin)
error("Every of the $(ndims(xin)) dimension needs exactly one corresponding scale (zoom) factor, which should be equal to 1.0 for dimensions not contained in the dims argument.")
end
for d = 1:ndims(xin)
# check all the dims:
for d = 1:D
if !(d in dims) && scale[d] != 1.0 && !isnothing(scale[d])
error("The scale factor $(scale[d]) needs to be nothing or 1.0, if this dimension is not in the list of dimensions to transform.")
end
end

for d in dims
# in-place assignement is not possible, since with a zoom the size always changes.
xout = czt_1d(xout, scale[d], d, dsize[d]; a=a, w=w, damp=damp[d], src_center=src_center[d], dst_center=dst_center[d], remove_wrap=remove_wrap, pad_value=pad_value, fft_flags=fft_flags)
end
return xout
Expand Down
10 changes: 5 additions & 5 deletions test/czt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,17 @@ using NDTools # this is needed for the select_region! function below.
# @test ≈(czt(y,zoom, src_center=(size(y).+1)./2), select_region(upsample2(ft(y), fix_center=true), new_size=size(y)), rtol=1e-5)

# for uneven sizes this works:
@test (czt(y[1:5,1:5],zoom, (1,2), (10,10)), upsample2(ft(y[1:5,1:5]), fix_center=true), rtol=1e-5)
@test (czt(y[1:5,1:5], zoom, (1,2), (10,10)), upsample2(ft(y[1:5,1:5]), fix_center=true), rtol=1e-5)
p_czt = plan_czt(y, zoom, (1,2), (11,12))
@test (p_czt * y, czt(y, zoom, (1,2), (11,12)))
# zoom smaller 1.0 causes wrap around:
zoom = (0.5,2.0)
@test abs(czt(y,zoom)[1,1]) > 1e-5
zoom = (2.0, 0.5)
# check if the remove_wrap works
@test abs(czt(y,zoom; remove_wrap=true)[1,1]) == 0.0
@test abs(iczt(y,zoom; remove_wrap=true)[1,1]) == 0.0
@test abs(czt(y,zoom; pad_value=0.2, remove_wrap=true)[1,1]) == 0.2f0
@test abs(iczt(y,zoom; pad_value=0.5f0, remove_wrap=true)[1,1]) == 0.5f0
@test abs(czt(y, zoom; remove_wrap=true)[1,1]) == 0.0
@test abs(iczt(y, zoom; remove_wrap=true)[1,1]) == 0.0
@test abs(czt(y, zoom; pad_value=0.2, remove_wrap=true)[1,1]) == 0.2f0
@test abs(iczt(y, zoom; pad_value=0.5f0, remove_wrap=true)[1,1]) == 0.5f0
end
end

0 comments on commit ce67ccb

Please sign in to comment.