Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove recursive pow2 check #10

Merged
merged 2 commits into from
Aug 21, 2017
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Fix up comment
  • Loading branch information
Alex Mason authored and Alex Mason committed Feb 7, 2017
commit 5d5a037e447509b18bedfb462e0f78867173f114
56 changes: 27 additions & 29 deletions repa-algorithms/Data/Array/Repa/Algorithms/FFT.hs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
{-# LANGUAGE TypeOperators, PatternGuards, RankNTypes, ScopedTypeVariables, BangPatterns, FlexibleContexts #-}
{-# OPTIONS -fno-warn-incomplete-patterns #-}

-- | Fast computation of Discrete Fourier Transforms using the Cooley-Tuckey algorithm.
-- Time complexity is O(n log n) in the size of the input.
-- | Fast computation of Discrete Fourier Transforms using the Cooley-Tuckey algorithm.
-- Time complexity is O(n log n) in the size of the input.
--
-- This uses a naive divide-and-conquer algorithm, the absolute performance is about
-- 50x slower than FFTW in estimate mode.
Expand Down Expand Up @@ -38,8 +38,7 @@ signOfMode mode
{-# INLINE signOfMode #-}


-- | Check if an `Int` is a power of two. Assumes `n` is positive
-- and will return `True` when given 0.
-- | Check if an `Int` is a power of two. Assumes `n` is a natural number.

-- The implementation can be found in Henry S. Warren, Jr.'s book
-- Hacker's delight, Chapter 2.
Expand All @@ -57,29 +56,29 @@ fft3dP :: (Source r Complex, Monad m)
fft3dP mode arr
= let _ :. depth :. height :. width = extent arr
!sign = signOfMode mode
!scale = fromIntegral (depth * width * height)
!scale = fromIntegral (depth * width * height)

in if not (isPowerOfTwo depth && isPowerOfTwo height && isPowerOfTwo width)
then error $ unlines
[ "Data.Array.Repa.Algorithms.FFT: fft3d"
, " Array dimensions must be powers of two,"
, " but the provided array is "
, " but the provided array is "
P.++ show height P.++ "x" P.++ show width P.++ "x" P.++ show depth ]
else arr `deepSeqArray`

else arr `deepSeqArray`
case mode of
Forward -> now $ fftTrans3d sign $ fftTrans3d sign $ fftTrans3d sign arr
Reverse -> now $ fftTrans3d sign $ fftTrans3d sign $ fftTrans3d sign arr
Inverse -> computeP
$ R.map (/ scale)
$ R.map (/ scale)
$ fftTrans3d sign $ fftTrans3d sign $ fftTrans3d sign arr
{-# INLINE fft3dP #-}


fftTrans3d
fftTrans3d
:: Source r Complex
=> Double
-> Array r DIM3 Complex
-> Array r DIM3 Complex
-> Array U DIM3 Complex

fftTrans3d sign arr
Expand All @@ -88,7 +87,7 @@ fftTrans3d sign arr
{-# INLINE fftTrans3d #-}


rotate3d
rotate3d
:: Source r Complex
=> Array r DIM3 Complex -> Array D DIM3 Complex
rotate3d arr
Expand All @@ -108,15 +107,15 @@ fft2dP :: (Source r Complex, Monad m)
fft2dP mode arr
= let _ :. height :. width = extent arr
sign = signOfMode mode
scale = fromIntegral (width * height)
scale = fromIntegral (width * height)

in if not (isPowerOfTwo height && isPowerOfTwo width)
then error $ unlines
[ "Data.Array.Repa.Algorithms.FFT: fft2d"
, " Array dimensions must be powers of two,"
, " but the provided array is " P.++ show height P.++ "x" P.++ show width ]
else arr `deepSeqArray`

else arr `deepSeqArray`
case mode of
Forward -> now $ fftTrans2d sign $ fftTrans2d sign arr
Reverse -> now $ fftTrans2d sign $ fftTrans2d sign arr
Expand All @@ -127,7 +126,7 @@ fft2dP mode arr
fftTrans2d
:: Source r Complex
=> Double
-> Array r DIM2 Complex
-> Array r DIM2 Complex
-> Array U DIM2 Complex

fftTrans2d sign arr
Expand All @@ -139,20 +138,20 @@ fftTrans2d sign arr
-- Vector Transform -------------------------------------------------------------------------------
-- | Compute the DFT of a vector. Array dimensions must be powers of two else `error`.
fft1dP :: (Source r Complex, Monad m)
=> Mode
-> Array r DIM1 Complex
=> Mode
-> Array r DIM1 Complex
-> m (Array U DIM1 Complex)
fft1dP mode arr
= let _ :. len = extent arr
sign = signOfMode mode
scale = fromIntegral len

in if not $ isPowerOfTwo len
then error $ unlines
then error $ unlines
[ "Data.Array.Repa.Algorithms.FFT: fft1d"
, " Array dimensions must be powers of two, "
, " but the provided array is " P.++ show len ]

else arr `deepSeqArray`
case mode of
Forward -> now $ fftTrans1d sign arr
Expand All @@ -163,7 +162,7 @@ fft1dP mode arr

fftTrans1d
:: Source r Complex
=> Double
=> Double
-> Array r DIM1 Complex
-> Array U DIM1 Complex

Expand All @@ -175,7 +174,7 @@ fftTrans1d sign arr

-- Rank Generalised Worker ------------------------------------------------------------------------
fft :: (Shape sh, Source r Complex)
=> Double -> sh -> Int
=> Double -> sh -> Int
-> Array r (sh :. Int) Complex
-> Array U (sh :. Int) Complex

Expand All @@ -184,9 +183,9 @@ fft !sign !sh !lenVec !vec
where go !len !offset !stride
| len == 2
= suspendedComputeP $ fromFunction (sh :. 2) swivel

| otherwise
= combine len
= combine len
(go (len `div` 2) offset (stride * 2))
(go (len `div` 2) (offset + stride) (stride * 2))

Expand All @@ -198,7 +197,7 @@ fft !sign !sh !lenVec !vec
{-# INLINE combine #-}
combine !len' evens odds
= evens `deepSeqArray` odds `deepSeqArray`
let odds' = unsafeTraverse odds id (\get ix@(_ :. k) -> twiddle sign k len' * get ix)
let odds' = unsafeTraverse odds id (\get ix@(_ :. k) -> twiddle sign k len' * get ix)
in suspendedComputeP $ (evens +^ odds') R.++ (evens -^ odds')
{-# INLINE fft #-}

Expand All @@ -214,4 +213,3 @@ twiddle sign k' n'
where k = fromIntegral k'
n = fromIntegral n'
{-# INLINE twiddle #-}