Skip to content

Commit

Permalink
reduce allocations in dims_howmany (#269)
Browse files Browse the repository at this point in the history
* reduce allocations in dims_howmany

* Update src/fft.jl

Co-authored-by: Steven G. Johnson <stevenj@mit.edu>

* Dont collect size tuple

* filter for Int/Tuple regions

* use tuple instead of vector region at more places

* remove unused methods

* test region collections

* bump version to v1.7.0

---------

Co-authored-by: Steven G. Johnson <stevenj@mit.edu>
  • Loading branch information
jishnub and stevengj authored Apr 12, 2023
1 parent 82a99dc commit e4a00b1
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 26 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "FFTW"
uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
version = "1.6.1"
version = "1.7.0"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand Down
87 changes: 62 additions & 25 deletions src/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -566,20 +566,51 @@ unsafe_execute!(plan::r2rFFTWPlan{T},
# re-use the table of trigonometric constants from the first plan.

# Compute dims and howmany for FFTW guru planner
function dims_howmany(X::StridedArray, Y::StridedArray,
sz::Vector{Int}, region)
reg = Int[region...]::Vector{Int}
if length(unique(reg)) < length(reg)
_anyrepeated(::Union{Number, AbstractUnitRange}) = false
function _anyrepeated(region)
any(region) do x
count(==(x), region) > 1
end
end

# Utility methods to reduce allocations in dims_howmany
@inline _setindex(oreg, v, n) = (oreg[n] = v; oreg)
@inline _setindex(oreg::Tuple, v, n) = Base.setindex(oreg, v, n)
@inline _filtercoll(region::Union{Int, Tuple}, len) = ntuple(zero, len)
@inline _filtercoll(region, len) = Vector{Int}(undef, len)
# Optimized filter(∉(region), 1:ndims(X))
function _filter_notin_region(region, ::Val{ndimsX}) where {ndimsX}
oreg = _filtercoll(region, ndimsX - length(region))
n = 1
for dim in 1:ndimsX
dim in region && continue
oreg = _setindex(oreg, dim, n)
n += 1
end
oreg
end
function dims_howmany(X::StridedArray, Y::StridedArray, sz, region)
if _anyrepeated(region)
throw(ArgumentError("each dimension can be transformed at most once"))
end
ist = [strides(X)...]
ost = [strides(Y)...]
dims = Matrix(transpose([sz[reg] ist[reg] ost[reg]]))
oreg = [1:ndims(X);]
oreg[reg] .= 0
oreg = filter(d -> d > 0, oreg)
howmany = Matrix(transpose([sz[oreg] ist[oreg] ost[oreg]]))
return (dims, howmany)
ist = strides(X)
ost = strides(Y)
dims = Matrix{Int}(undef, 3, length(region))
for (ind, i) in enumerate(region)
dims[1, ind] = sz[i]
dims[2, ind] = ist[i]
dims[3, ind] = ost[i]
end

oreg = _filter_notin_region(region, Val(ndims(X)))
howmany = Matrix{Int}(undef, 3, length(oreg))
for (ind, i) in enumerate(oreg)
howmany[1, ind] = sz[i]
howmany[2, ind] = ist[i]
howmany[3, ind] = ost[i]
end

return dims, howmany
end

function fix_kinds(region, kinds)
Expand All @@ -604,6 +635,10 @@ function fix_kinds(region, kinds)
return k
end

_circshiftmin1(v) = circshift(collect(Int, v), -1)
_circshiftmin1(t::Tuple) = (t[2:end]..., t[1])
_circshiftmin1(x::Integer) = x

# low-level FFTWPlan creation (for internal use in FFTW module)
for (Tr,Tc,fftw,lib) in ((:Float64,:(Complex{Float64}),"fftw",:libfftw3),
(:Float32,:(Complex{Float32}),"fftwf",:libfftw3f))
Expand All @@ -613,7 +648,7 @@ for (Tr,Tc,fftw,lib) in ((:Float64,:(Complex{Float64}),"fftw",:libfftw3),
direction = K
unsafe_set_timelimit($Tr, timelimit)
R = isa(region, Tuple) ? region : copy(region)
dims, howmany = dims_howmany(X, Y, [size(X)...], R)
dims, howmany = dims_howmany(X, Y, size(X), R)
plan = ccall(($(string(fftw,"_plan_guru64_dft")),$lib[]),
PlanPtr,
(Int32, Ptr{Int}, Int32, Ptr{Int},
Expand All @@ -631,9 +666,9 @@ for (Tr,Tc,fftw,lib) in ((:Float64,:(Complex{Float64}),"fftw",:libfftw3),
Y::StridedArray{$Tc,N},
region, flags::Integer, timelimit::Real) where {inplace,N}
R = isa(region, Tuple) ? region : copy(region)
region = circshift(Int[region...],-1) # FFTW halves last dim
regionshft = _circshiftmin1(region) # FFTW halves last dim
unsafe_set_timelimit($Tr, timelimit)
dims, howmany = dims_howmany(X, Y, [size(X)...], region)
dims, howmany = dims_howmany(X, Y, size(X), regionshft)
plan = ccall(($(string(fftw,"_plan_guru64_dft_r2c")),$lib[]),
PlanPtr,
(Int32, Ptr{Int}, Int32, Ptr{Int},
Expand All @@ -651,9 +686,9 @@ for (Tr,Tc,fftw,lib) in ((:Float64,:(Complex{Float64}),"fftw",:libfftw3),
Y::StridedArray{$Tr,N},
region, flags::Integer, timelimit::Real) where {inplace,N}
R = isa(region, Tuple) ? region : copy(region)
region = circshift(Int[region...],-1) # FFTW halves last dim
regionshft = _circshiftmin1(region) # FFTW halves last dim
unsafe_set_timelimit($Tr, timelimit)
dims, howmany = dims_howmany(X, Y, [size(Y)...], region)
dims, howmany = dims_howmany(X, Y, size(Y), regionshft)
plan = ccall(($(string(fftw,"_plan_guru64_dft_c2r")),$lib[]),
PlanPtr,
(Int32, Ptr{Int}, Int32, Ptr{Int},
Expand All @@ -675,7 +710,7 @@ for (Tr,Tc,fftw,lib) in ((:Float64,:(Complex{Float64}),"fftw",:libfftw3),
R = isa(region, Tuple) ? region : copy(region)
knd = fix_kinds(region, kinds)
unsafe_set_timelimit($Tr, timelimit)
dims, howmany = dims_howmany(X, Y, [size(X)...], region)
dims, howmany = dims_howmany(X, Y, size(X), region)
plan = ccall(($(string(fftw,"_plan_guru64_r2r")),$lib[]),
PlanPtr,
(Int32, Ptr{Int}, Int32, Ptr{Int},
Expand All @@ -698,9 +733,11 @@ for (Tr,Tc,fftw,lib) in ((:Float64,:(Complex{Float64}),"fftw",:libfftw3),
R = isa(region, Tuple) ? region : copy(region)
knd = fix_kinds(region, kinds)
unsafe_set_timelimit($Tr, timelimit)
dims, howmany = dims_howmany(X, Y, [size(X)...], region)
dims[2:3, 1:size(dims,2)] *= 2
howmany[2:3, 1:size(howmany,2)] *= 2
dims, howmany = dims_howmany(X, Y, size(X), region)
@views begin
dims[2:3, :] .*= 2
howmany[2:3, :] .*= 2
end
howmany = [howmany [2,1,1]] # append loop over real/imag parts
plan = ccall(($(string(fftw,"_plan_guru64_r2r")),$lib[]),
PlanPtr,
Expand Down Expand Up @@ -759,9 +796,9 @@ for (f,direction) in ((:fft,FORWARD), (:bfft,BACKWARD))
cFFTWPlan{T,$direction,true,N}(X, X, region, flags, timelimit)
end
$plan_f(X::StridedArray{<:fftwComplex}; kws...) =
$plan_f(X, 1:ndims(X); kws...)
$plan_f(X, ntuple(identity, ndims(X)); kws...)
$plan_f!(X::StridedArray{<:fftwComplex}; kws...) =
$plan_f!(X, 1:ndims(X); kws...)
$plan_f!(X, ntuple(identity, ndims(X)); kws...)

function plan_inv(p::cFFTWPlan{T,$direction,inplace,N};
num_threads::Union{Nothing, Integer} = nothing) where {T<:fftwComplex,N,inplace}
Expand Down Expand Up @@ -845,8 +882,8 @@ for (Tr,Tc) in ((:Float32,:(Complex{Float32})),(:Float64,:(Complex{Float64})))
end
end

plan_rfft(X::StridedArray{$Tr};kws...)=plan_rfft(X,1:ndims(X);kws...)
plan_brfft(X::StridedArray{$Tr};kws...)=plan_brfft(X,1:ndims(X);kws...)
plan_rfft(X::StridedArray{$Tr};kws...)=plan_rfft(X,ntuple(identity, ndims(X));kws...)
plan_brfft(X::StridedArray{$Tr};kws...)=plan_brfft(X,ntuple(identity, ndims(X));kws...)

function plan_inv(p::rFFTWPlan{$Tr,$FORWARD,false,N},
num_threads::Union{Nothing, Integer} = nothing) where N
Expand Down
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,9 @@ true_fftd3_m3d[:,:,2] .= -15
end

@testset "rfft/rfftn" begin
# Test regions as int/collection
@test rfft(m4,1) == rfft(m4,1:1) == rfft(m4,(1,)) == rfft(m4, [1])

rfft_m4 = rfft(m4,1)
rfftd2_m4 = rfft(m4,2)
rfftn_m4 = rfft(m4)
Expand Down

6 comments on commit e4a00b1

@jishnub
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@stevengj could you tag a release? Thanks!

@stevengj
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like it should be a patch release, since it looks like there are no new features or API additions/changes?

@jishnub
Copy link
Contributor Author

@jishnub jishnub commented on e4a00b1 Apr 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made it a minor version bump, as in 82a99dc, the type parameters of r2rFFTWPlan were changed. This will be mildly breaking, e.g. in FastTransforms.jl. This shouldn't break Hadamard.jl, as the type parameter is now uniformly a Vector{Int32}, as in https://github.com/JuliaMath/Hadamard.jl/pull/18/files. Ideally, the version bump should have been in 82a99dc, but the version was not bumped correctly in that commit. It's fixed here.

@jishnub
Copy link
Contributor Author

@jishnub jishnub commented on e4a00b1 Jun 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gentle bump @stevengj, this would help with type-inference downstream in FastTransforms.jl

@stevengj
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/85320

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.7.0 -m "<description of version>" e4a00b1d410a78288d4c3a4fd19202a78536458b
git push origin v1.7.0

Please sign in to comment.