From 0a29b0b12dcec7abbe019d3b73cd0c1ecf01e346 Mon Sep 17 00:00:00 2001 From: RainerHeintzmann Date: Sat, 13 Jul 2024 20:10:37 +0200 Subject: [PATCH 1/3] addec CUDA.jl support for czt --- src/czt.jl | 46 ++++++++++++++++++++++++++++------------------ 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/src/czt.jl b/src/czt.jl index 5874ee7..88ed28f 100644 --- a/src/czt.jl +++ b/src/czt.jl @@ -1,7 +1,7 @@ export czt, iczt, plan_czt """ - get_kernel_1d(RT::Type, N::Integer, M::Integer; a= 1.0, w = cispi(-2/N), extra_phase=0.0, global_phase=0.0) + get_kernel_1d(arr::AbstractArray{T,D}, N::Integer, M::Integer; a= 1.0, w = cispi(-2/N), extra_phase=0.0, global_phase=0.0) where {T,D} calculates the kernel for the Bluestein algorithm. Note that the length depends on the destination size. Note the the resulting kernel-size is computed based on the minimum required length for the task. @@ -21,11 +21,15 @@ The code is based on Rabiner, Schafer & Rader 1969, IEEE Trans. on Audio and El returns: a tuple of three arrays for the initial multiplication (A*W), the convolution (already fourier-transformed) and the post multiplication. """ -function get_kernel_1d(RT::Type, N::Integer, M::Integer; a= 1.0, w = cispi(-2/N), extra_phase=0.0, global_phase=0.0) +function get_kernel_1d(arr::AT, N::Integer, M::Integer; a= 1.0, w = cispi(-2/N), extra_phase=0.0, global_phase=0.0) where {T,D, AT <: AbstractArray{T,D}} # intorduce also sscale ?? # the size needed to avoid wrap + RT = real(T) CT = (RT <: Real) ? Complex{RT} : RT RT = real(CT) + + tmp = similar(arr, RT, (1,)) # converts ShiftedArrays.CircShiftedArray into a plain array type + RAT = real_arr_type(typeof(tmp), Val(1)) # nowrap_size = N + ceil(N÷2) # the maximal size where the convolution does not yield zero # max_size = 2*N-1 @@ -35,13 +39,15 @@ function get_kernel_1d(RT::Type, N::Integer, M::Integer; a= 1.0, w = cispi(-2/N) #Note that the source size ssz is used here # W of the source for FFTs. - n = (0:N-1) + n = RAT(0:N-1) # pre-calculate the product of a.^(-n) and w to be later multiplied with the input x # late casting is important, since the rounding errors are overly large if the original calculations are done in Float32. aw = CT.((a .^ (-n)) .* w .^ ((n .^ 2) ./ 2)) - conv_kernel = zeros(CT, L) # Array{CT}(undef, L) - m = (0:M-1) + conv_kernel = similar(arr, CT, L) # Array{CT}(undef, L) + fill!(conv_kernel, zero(CT)) + + m = RAT(0:M-1) conv_kernel[1:M] .= w .^ (-(m .^ 2) ./ 2) right_start = L-N+1 n = (1:N-1) @@ -77,10 +83,10 @@ struct CZTPlan_1D{CT, PT, D} # <: AbstractArray{T,D} d :: Int pad_value :: PT pad_ranges :: NTuple{2,UnitRange{Int64}} - aw :: Array{CT, D} - fft_fv :: Array{CT, D} - wd :: Array{CT, D} - fftw_plan :: FFTW.cFFTWPlan + aw :: AbstractArray{CT, D} + fft_fv :: AbstractArray{CT, D} + wd :: AbstractArray{CT, D} + fftw_plan :: AbstractFFTs.Plan ifftw_plan :: AbstractFFTs.ScaledPlan # dimension of this transformation # as :: Array{T, D} # not needed since it is just the conjugate of ws @@ -138,8 +144,8 @@ end creates a plan for an one-dimensional chirp z-transformation (CZT). The generated plan is then applied via muliplication. For details about the arguments, see `czt_1d()`. """ -function plan_czt_1d(xin, scaled, d, dsize=size(xin,d); a=nothing, w=nothing, extra_phase=nothing, global_phase=nothing, damp=1.0, src_center=(size(xin,d)+1)/2, - dst_center=dsize÷2+1, remove_wrap=false, pad_value=zero(eltype(xin)), fft_flags=FFTW.ESTIMATE) +function plan_czt_1d(xin::AT, scaled, d, dsize=size(xin,d); a=nothing, w=nothing, extra_phase=nothing, global_phase=nothing, damp=1.0, src_center=(size(xin,d)+1)/2, + dst_center=dsize÷2+1, remove_wrap=false, pad_value=zero(eltype(xin)), fft_flags=FFTW.ESTIMATE) where {AT} a = isnothing(a) ? exp(-1im*(dst_center-1)*2pi/(scaled*size(xin,d))) : a w = isnothing(w) ? cispi(-2/(scaled*size(xin,d))) : w @@ -148,7 +154,7 @@ function plan_czt_1d(xin, scaled, d, dsize=size(xin,d); a=nothing, w=nothing, ex extra_phase = isnothing(extra_phase) ? exp(1im*2pi*(src_center-1)/(scaled*size(xin,d))) : extra_phase global_phase = isnothing(global_phase) ? a ^ (src_center-1) : global_phase - aw, fft_fv, wd = get_kernel_1d(eltype(xin), size(xin, d), dsize; a=a, w=w, extra_phase=extra_phase, global_phase=global_phase) + aw, fft_fv, wd = get_kernel_1d(xin, size(xin, d), dsize; a=a, w=w, extra_phase=extra_phase, global_phase=global_phase) start_range = 1:0 end_range = 1:0 @@ -159,10 +165,10 @@ function plan_czt_1d(xin, scaled, d, dsize=size(xin,d); a=nothing, w=nothing, ex end nsz = ntuple((dd) -> (d==dd) ? size(fft_fv, 1) : size(xin, dd), Val(ndims(xin))) - y = Array{eltype(aw), ndims(xin)}(undef, nsz) + y = similar(xin, eltype(aw), nsz) # Array{eltype(aw), ndims(xin)}(undef, nsz) - fft_p = plan_fft(y, (d,); flags=fft_flags) - ifft_p = plan_ifft(y, (d,); flags=fft_flags) # inv(fft_p) + fft_p = (typeof(y) <: Array) ? plan_fft(y, (d,); flags=fft_flags) : plan_fft(y, (d,)) + ifft_p = (typeof(y) <: Array) ? plan_ifft(y, (d,); flags=fft_flags) : plan_ifft(y, (d,)) # inv(fft_p) plan = CZTPlan_1D(d, pad_value, (start_range, end_range), reorient(aw, d, Val(ndims(xin))), reorient(fft_fv, d, Val(ndims(xin))), reorient(wd, d, Val(ndims(xin))), fft_p, ifft_p) return plan @@ -269,9 +275,13 @@ function czt_1d(xin, plan::CZTPlan_1D) L = size(plan.fft_fv, plan.d) nsz = ntuple((dd) -> (dd==plan.d) ? L : size(xin, dd), Val(ndims(xin))) # append zeros - y = zeros(eltype(plan.aw), nsz) - myrange = ntuple((dd) -> (1:size(xin,dd)), Val(ndims(xin))) - y[myrange...] = xin .* plan.aw + tmp = eltype(plan.aw).(xin .* plan.aw) + y = NDTools.select_region(tmp, nsz; center=size(tmp).÷2 .+1, dst_center=size(tmp).÷2 .+1) + + # y = zeros(eltype(plan.aw), nsz) + # myrange = ntuple((dd) -> (1:size(xin,dd)), Val(ndims(xin))) + # y[myrange...] = xin .* plan.aw + # corner = ntuple((x)->1, Val(ndims(xin))) # select_region(xin .* plan.aw, new_size=nsz, center=corner, dst_center=corner) From c68de0cb8d285406475564ebd9d1965c9ef263ff Mon Sep 17 00:00:00 2001 From: RainerHeintzmann Date: Sun, 14 Jul 2024 14:36:24 +0200 Subject: [PATCH 2/3] merged with previous PR --- src/czt.jl | 84 +++++++++++++++++++++++++++++------------------------ test/czt.jl | 10 +++---- 2 files changed, 51 insertions(+), 43 deletions(-) diff --git a/src/czt.jl b/src/czt.jl index 88ed28f..6207fd7 100644 --- a/src/czt.jl +++ b/src/czt.jl @@ -22,15 +22,15 @@ The code is based on Rabiner, Schafer & Rader 1969, IEEE Trans. on Audio and El (already fourier-transformed) and the post multiplication. """ function get_kernel_1d(arr::AT, N::Integer, M::Integer; a= 1.0, w = cispi(-2/N), extra_phase=0.0, global_phase=0.0) where {T,D, AT <: AbstractArray{T,D}} - # intorduce also sscale ?? - # the size needed to avoid wrap + # the size is needed to avoid wrap RT = real(T) CT = (RT <: Real) ? Complex{RT} : RT RT = real(CT) - tmp = similar(arr, RT, (1,)) # converts ShiftedArrays.CircShiftedArray into a plain array type + # converts ShiftedArrays.CircShiftedArray into a plain array type: + tmp = similar(arr, RT, (1,)) RAT = real_arr_type(typeof(tmp), Val(1)) - # nowrap_size = N + ceil(N÷2) + # the maximal size where the convolution does not yield zero # max_size = 2*N-1 # the minimum size needed for the convolution @@ -76,8 +76,8 @@ containing `aw`: factor to multiply input with `fft_fv`: fourier-transform (FFTW) of the convolutio kernel `wd`: factor to multiply the result of the convolution by - `fftw_plan`: plan for the forward FFTW of the convolution kernel - `ifftw_plan`: plan for the inverse FFTW of the convolution kernel + `fftw_plan!`: plan for the forward FFTW of the convolution kernel + `ifftw_plan!`: plan for the inverse FFTW of the convolution kernel """ struct CZTPlan_1D{CT, PT, D} # <: AbstractArray{T,D} d :: Int @@ -86,8 +86,8 @@ struct CZTPlan_1D{CT, PT, D} # <: AbstractArray{T,D} aw :: AbstractArray{CT, D} fft_fv :: AbstractArray{CT, D} wd :: AbstractArray{CT, D} - fftw_plan :: AbstractFFTs.Plan - ifftw_plan :: AbstractFFTs.ScaledPlan + fftw_plan! :: AbstractFFTs.Plan + ifftw_plan! :: AbstractFFTs.ScaledPlan # dimension of this transformation # as :: Array{T, D} # not needed since it is just the conjugate of ws end @@ -156,21 +156,24 @@ function plan_czt_1d(xin::AT, scaled, d, dsize=size(xin,d); a=nothing, w=nothing aw, fft_fv, wd = get_kernel_1d(xin, size(xin, d), dsize; a=a, w=w, extra_phase=extra_phase, global_phase=global_phase) + # set pad ranges to empty ranges: start_range = 1:0 - end_range = 1:0 + stop_range = 1:0 + if remove_wrap start_range, stop_range = get_invalid_ranges(size(xin, d), scaled, dsize, dst_center) + wd[start_range] .= zero(eltype(wd)) wd[stop_range] .= zero(eltype(wd)) end nsz = ntuple((dd) -> (d==dd) ? size(fft_fv, 1) : size(xin, dd), Val(ndims(xin))) - y = similar(xin, eltype(aw), nsz) # Array{eltype(aw), ndims(xin)}(undef, nsz) + y = similar(xin, eltype(aw), nsz) - fft_p = (typeof(y) <: Array) ? plan_fft(y, (d,); flags=fft_flags) : plan_fft(y, (d,)) - ifft_p = (typeof(y) <: Array) ? plan_ifft(y, (d,); flags=fft_flags) : plan_ifft(y, (d,)) # inv(fft_p) + fft_p! = (typeof(y) <: Array) ? plan_fft!(y, (d,); flags=fft_flags) : plan_fft!(y, (d,)) + ifft_p! = (typeof(y) <: Array) ? plan_ifft!(y, (d,); flags=fft_flags) : plan_ifft!(y, (d,)) - plan = CZTPlan_1D(d, pad_value, (start_range, end_range), reorient(aw, d, Val(ndims(xin))), reorient(fft_fv, d, Val(ndims(xin))), reorient(wd, d, Val(ndims(xin))), fft_p, ifft_p) + plan = CZTPlan_1D(d, pad_value, (start_range, stop_range), reorient(aw, d, Val(ndims(xin))), reorient(fft_fv, d, Val(ndims(xin))), reorient(wd, d, Val(ndims(xin))), fft_p!, ifft_p!) return plan end @@ -181,9 +184,9 @@ end creates a plan for an N-dimensional chirp z-transformation (CZT). The generated plan is then applied via muliplication. For details about the arguments, see `czt()`. """ -function plan_czt(xin, scale, dims, dsize=size(xin); a=nothing, w=nothing, damp=ones(ndims(xin)), src_center=size(xin).÷2 .+1, dst_center=dsize.÷2 .+1, remove_wrap=false, pad_value=zero(eltype(xin)), fft_flags=FFTW.ESTIMATE) +function plan_czt(xin::AbstractArray{U,D}, scale, dims, dsize=size(xin); a=nothing, w=nothing, damp=ones(ndims(xin)), + src_center=size(xin).÷2 .+1, dst_center=dsize.÷2 .+1, remove_wrap=false, pad_value=zero(eltype(xin)), fft_flags=FFTW.ESTIMATE) where {U,D} CT = (eltype(xin) <: Real) ? Complex{eltype(xin)} : eltype(xin) - D = ndims(xin) plans = [] # Vector{CZT1DPlan{CT,D}} sz = size(xin) for d in dims @@ -192,10 +195,10 @@ function plan_czt(xin, scale, dims, dsize=size(xin); a=nothing, w=nothing, damp= sz = ntuple((dd)-> (dd==d) ? dsize[d] : sz[dd], ndims(xin)) push!(plans, p) end - return CZTPlan_ND{CT, typeof(pad_value),D}(plans) + return CZTPlan_ND{CT, typeof(pad_value), D}(plans) end -function Base.:*(p::CZTPlan_ND, xin::AbstractArray{U,D}; kargs...) where {U,D} # Complex{U} +function Base.:*(p::CZTPlan_ND, xin::AbstractArray{U,D}; kargs...)::AbstractArray{complex(U),D} where {U,D} xout = xin for pd in p.plans xout = czt_1d(xout, pd) @@ -236,13 +239,13 @@ The code is based on Rabiner, Schafer & Rader 1969, IEEE Trans. on Audio and El + `remove_wrap`: if true, the positions that represent a wrap-around will be set to zero + `pad_value`: the value to pad wrapped data with. """ -function czt_1d(xin, scaled, d, dsize=size(xin,d); a=nothing, w=nothing, damp=1.0, src_center=size(xin,d)÷2+1, - dst_center=dsize÷2+1, extra_phase=nothing, global_phase=nothing, remove_wrap=false, pad_value=zero(eltype(xin)), fft_flags=FFTW.ESTIMATE) +function czt_1d(xin::AbstractArray{U,D}, scaled, d, dsize=size(xin,d); a=nothing, w=nothing, damp=1.0, src_center=size(xin,d)÷2+1, + dst_center=dsize÷2+1, extra_phase=nothing, global_phase=nothing, remove_wrap=false, pad_value=zero(U), fft_flags=FFTW.ESTIMATE)::AbstractArray{complex(U), D} where {U,D} plan = plan_czt_1d(xin, scaled, d, dsize; a=a, w=w, extra_phase=extra_phase, global_phase=global_phase, damp, src_center=src_center, dst_center=dst_center, remove_wrap=remove_wrap, pad_value=pad_value, fft_flags=fft_flags); return plan * xin end -function Base.:*(p::CZTPlan_1D, xin::AbstractArray{U,D}; kargs...) where {U,D} # Complex{U} +function Base.:*(p::CZTPlan_1D, xin::AbstractArray{U,D}; kargs...)::AbstractArray{complex(U), D} where {U,D} # Complex{U} return czt_1d(xin, p) end @@ -264,7 +267,7 @@ The code is based on Rabiner, Schafer & Rader 1969, IEEE Trans. on Audio and El # Arguments `plan`: A plan created via plan_czt_1d() """ -function czt_1d(xin, plan::CZTPlan_1D) +function czt_1d(xin::AbstractArray{U,D}, plan::CZTPlan_1D)::AbstractArray{complex(U), D} where {U,D} # destination position # cispi(-1/scaled * half_pix_shift) # @@ -273,29 +276,30 @@ function czt_1d(xin, plan::CZTPlan_1D) # which (intentionally) leads to non-real results for even-sized arrays at non-unit zoom L = size(plan.fft_fv, plan.d) - nsz = ntuple((dd) -> (dd==plan.d) ? L : size(xin, dd), Val(ndims(xin))) + nsz = ntuple((dd) -> (dd==plan.d) ? L : size(xin, dd), Val(D)) # append zeros tmp = eltype(plan.aw).(xin .* plan.aw) - y = NDTools.select_region(tmp, nsz; center=size(tmp).÷2 .+1, dst_center=size(tmp).÷2 .+1) - - # y = zeros(eltype(plan.aw), nsz) - # myrange = ntuple((dd) -> (1:size(xin,dd)), Val(ndims(xin))) - # y[myrange...] = xin .* plan.aw + + corner = ntuple((x)->1, Val(D)) + y = NDTools.select_region(tmp, nsz; center=corner, dst_center=corner) - # corner = ntuple((x)->1, Val(ndims(xin))) - # select_region(xin .* plan.aw, new_size=nsz, center=corner, dst_center=corner) + # in-place application to y: + plan.fftw_plan! * y + y .*= plan.fft_fv + # in-place application to y: + plan.ifftw_plan! * y - # g = ifft(fft(y, d) .* plan.fft_fv, d) - g = plan.ifftw_plan * (plan.fftw_plan * y .* plan.fft_fv) # dsz = ntuple((dd) -> (d==dd) ? dsize : size(xin), Val(ndims(xin))) # return only the wanted (valid) part - myrange = ntuple((dd) -> (dd==plan.d) ? (1:size(plan.wd,plan.d)) : (1:size(xin, dd)), Val(ndims(xin))) - res = g[myrange...] .* plan.wd + myrange = ntuple((dd) -> (dd==plan.d) ? (1:size(plan.wd, plan.d)) : (1:size(xin, dd)), Val(D)) + res = y[myrange...] .* plan.wd # pad_value=0 means that it is either already handled by plan.wd or no padding is wanted. if plan.pad_value != 0 - myrange = ntuple((dd) -> (dd==plan.d) ? plan.pad_ranges[1] : Colon(), Val(ndims(xin))) + # first the start_range (plan.pad_ranges[1]): + myrange = ntuple((dd) -> (dd==plan.d) ? plan.pad_ranges[1] : Colon(), Val(D)) res[myrange...] .= plan.pad_value - myrange = ntuple((dd) -> (dd==plan.d) ? plan.pad_ranges[2] : Colon(), Val(ndims(xin))) + # first the stop_range (plan.pad_ranges[2]): + myrange = ntuple((dd) -> (dd==plan.d) ? plan.pad_ranges[2] : Colon(), Val(D)) res[myrange...] .= plan.pad_value end return res @@ -365,18 +369,22 @@ julia> zoomed = real.(ift(xft)) 0.0239759 -0.028264 0.0541186 -0.0116475 -0.261294 0.312719 -0.261294 -0.0116475 0.0541186 -0.028264 ``` """ -function czt(xin::AbstractArray{T,N}, scale, dims=1:ndims(xin), dsize=size(xin); - a=nothing, w=nothing, damp=ones(ndims(xin)), src_center=size(xin).÷2 .+1, dst_center=dsize.÷2 .+1, remove_wrap=false, pad_value=zero(eltype(xin)), fft_flags=FFTW.ESTIMATE)::AbstractArray{complex(T),N} where {T,N} +function czt(xin::AbstractArray{T,D}, scale, dims=1:D, dsize=size(xin); + a=nothing, w=nothing, damp=ones(D), src_center=size(xin).÷2 .+1, dst_center=dsize.÷2 .+1, + remove_wrap=false, pad_value=zero(T), fft_flags=FFTW.ESTIMATE)::AbstractArray{complex(T), D} where {T,D} xout = xin if length(scale) != ndims(xin) error("Every of the $(ndims(xin)) dimension needs exactly one corresponding scale (zoom) factor, which should be equal to 1.0 for dimensions not contained in the dims argument.") end - for d = 1:ndims(xin) + # check all the dims: + for d = 1:D if !(d in dims) && scale[d] != 1.0 && !isnothing(scale[d]) error("The scale factor $(scale[d]) needs to be nothing or 1.0, if this dimension is not in the list of dimensions to transform.") end end + for d in dims + # in-place assignement is not possible, since with a zoom the size always changes. xout = czt_1d(xout, scale[d], d, dsize[d]; a=a, w=w, damp=damp[d], src_center=src_center[d], dst_center=dst_center[d], remove_wrap=remove_wrap, pad_value=pad_value, fft_flags=fft_flags) end return xout diff --git a/test/czt.jl b/test/czt.jl index 9836d07..3ec04bf 100644 --- a/test/czt.jl +++ b/test/czt.jl @@ -23,7 +23,7 @@ using NDTools # this is needed for the select_region! function below. # @test ≈(czt(y,zoom, src_center=(size(y).+1)./2), select_region(upsample2(ft(y), fix_center=true), new_size=size(y)), rtol=1e-5) # for uneven sizes this works: - @test ≈(czt(y[1:5,1:5],zoom, (1,2), (10,10)), upsample2(ft(y[1:5,1:5]), fix_center=true), rtol=1e-5) + @test ≈(czt(y[1:5,1:5], zoom, (1,2), (10,10)), upsample2(ft(y[1:5,1:5]), fix_center=true), rtol=1e-5) p_czt = plan_czt(y, zoom, (1,2), (11,12)) @test ≈(p_czt * y, czt(y, zoom, (1,2), (11,12))) # zoom smaller 1.0 causes wrap around: @@ -31,9 +31,9 @@ using NDTools # this is needed for the select_region! function below. @test abs(czt(y,zoom)[1,1]) > 1e-5 zoom = (2.0, 0.5) # check if the remove_wrap works - @test abs(czt(y,zoom; remove_wrap=true)[1,1]) == 0.0 - @test abs(iczt(y,zoom; remove_wrap=true)[1,1]) == 0.0 - @test abs(czt(y,zoom; pad_value=0.2, remove_wrap=true)[1,1]) == 0.2f0 - @test abs(iczt(y,zoom; pad_value=0.5f0, remove_wrap=true)[1,1]) == 0.5f0 + @test abs(czt(y, zoom; remove_wrap=true)[1,1]) == 0.0 + @test abs(iczt(y, zoom; remove_wrap=true)[1,1]) == 0.0 + @test abs(czt(y, zoom; pad_value=0.2, remove_wrap=true)[1,1]) == 0.2f0 + @test abs(iczt(y, zoom; pad_value=0.5f0, remove_wrap=true)[1,1]) == 0.5f0 end end From 271b6af46d189838fbfc9f3e3e633925d11359a0 Mon Sep 17 00:00:00 2001 From: RainerHeintzmann Date: Mon, 15 Jul 2024 18:50:24 +0200 Subject: [PATCH 3/3] type improvements for CZTPlan_1D and CZTPlan_ND --- src/czt.jl | 39 +++++++++++++++++++++++---------------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/src/czt.jl b/src/czt.jl index 6207fd7..91ff20c 100644 --- a/src/czt.jl +++ b/src/czt.jl @@ -65,7 +65,7 @@ end # type for planning. The arrays are 1D but oriented """ - CZTPlan_1D{CT, D} # <: AbstractArray{T,D} + CZTPlan_1D{CT<:Complex, D<:Integer, AT<:AbstractArray{CT, D}, PT<:Number, PFFT<:AbstractFFTs.Plan, PIFFT<:AbstractFFTs.ScaledPlan} type used for the onedimensional plan of the chirp Z transformation (CZT). containing @@ -79,17 +79,15 @@ containing `fftw_plan!`: plan for the forward FFTW of the convolution kernel `ifftw_plan!`: plan for the inverse FFTW of the convolution kernel """ -struct CZTPlan_1D{CT, PT, D} # <: AbstractArray{T,D} +struct CZTPlan_1D{CT<:Complex, AT<:AbstractArray{CT}, PT<:Number, PFFT<:AbstractFFTs.Plan, PIFFT<:AbstractFFTs.ScaledPlan} d :: Int pad_value :: PT - pad_ranges :: NTuple{2,UnitRange{Int64}} - aw :: AbstractArray{CT, D} - fft_fv :: AbstractArray{CT, D} - wd :: AbstractArray{CT, D} - fftw_plan! :: AbstractFFTs.Plan - ifftw_plan! :: AbstractFFTs.ScaledPlan - # dimension of this transformation - # as :: Array{T, D} # not needed since it is just the conjugate of ws + pad_ranges :: NTuple{2, UnitRange{Int64}} + aw :: AT + fft_fv :: AT + wd :: AT + fftw_plan! :: PFFT + ifftw_plan! :: PIFFT end """ @@ -100,8 +98,8 @@ containing # Members: `plans`: vector of CZTPlan_1D for each of the directions of the ND array to transform """ -struct CZTPlan_ND{CT, PT, D} # <: AbstractArray{T,D} - plans :: Vector{CZTPlan_1D{CT,PT, D}} +struct CZTPlan_ND{CT<:Complex, AT<:AbstractArray{CT}, PT<:Number, PFFT<:AbstractFFTs.Plan, PIFFT<:AbstractFFTs.ScaledPlan} + plans :: Vector{CZTPlan_1D{CT, AT, PT, PFFT, PIFFT}} end function get_invalid_ranges(sz, scaled, dsize, dst_center) @@ -187,15 +185,24 @@ muliplication. For details about the arguments, see `czt()`. function plan_czt(xin::AbstractArray{U,D}, scale, dims, dsize=size(xin); a=nothing, w=nothing, damp=ones(ndims(xin)), src_center=size(xin).÷2 .+1, dst_center=dsize.÷2 .+1, remove_wrap=false, pad_value=zero(eltype(xin)), fft_flags=FFTW.ESTIMATE) where {U,D} CT = (eltype(xin) <: Real) ? Complex{eltype(xin)} : eltype(xin) - plans = [] # Vector{CZT1DPlan{CT,D}} sz = size(xin) - for d in dims + xin = Array{eltype(xin)}(undef, sz) + + d = dims[1] + p = plan_czt_1d(xin, scale[d], d, dsize[d]; a=a, w=w, damp=damp[d], src_center=src_center[d], dst_center=dst_center[d], remove_wrap=remove_wrap, pad_value=pad_value, fft_flags=fft_flags) + plans = Vector{typeof(p)}(undef, length(dims)) + sz = ntuple((dd)-> (dd==d) ? dsize[d] : sz[dd], ndims(xin)) + n=1 + plans[n]=p + n+=1 + for d in dims[2:end] xin = Array{eltype(xin)}(undef, sz) p = plan_czt_1d(xin, scale[d], d, dsize[d]; a=a, w=w, damp=damp[d], src_center=src_center[d], dst_center=dst_center[d], remove_wrap=remove_wrap, pad_value=pad_value, fft_flags=fft_flags) sz = ntuple((dd)-> (dd==d) ? dsize[d] : sz[dd], ndims(xin)) - push!(plans, p) + plans[n]=p + n += 1 end - return CZTPlan_ND{CT, typeof(pad_value), D}(plans) + return CZTPlan_ND(plans) end function Base.:*(p::CZTPlan_ND, xin::AbstractArray{U,D}; kargs...)::AbstractArray{complex(U),D} where {U,D}