diff --git a/repa-algorithms/Data/Array/Repa/Algorithms/FFT.hs b/repa-algorithms/Data/Array/Repa/Algorithms/FFT.hs index 26629872..22cadf86 100644 --- a/repa-algorithms/Data/Array/Repa/Algorithms/FFT.hs +++ b/repa-algorithms/Data/Array/Repa/Algorithms/FFT.hs @@ -1,8 +1,8 @@ {-# LANGUAGE TypeOperators, PatternGuards, RankNTypes, ScopedTypeVariables, BangPatterns, FlexibleContexts #-} {-# OPTIONS -fno-warn-incomplete-patterns #-} --- | Fast computation of Discrete Fourier Transforms using the Cooley-Tuckey algorithm. --- Time complexity is O(n log n) in the size of the input. +-- | Fast computation of Discrete Fourier Transforms using the Cooley-Tuckey algorithm. +-- Time complexity is O(n log n) in the size of the input. -- -- This uses a naive divide-and-conquer algorithm, the absolute performance is about -- 50x slower than FFTW in estimate mode. @@ -38,8 +38,7 @@ signOfMode mode {-# INLINE signOfMode #-} --- | Check if an `Int` is a power of two. Assumes `n` is positive --- and will return `True` when given 0. +-- | Check if an `Int` is a power of two. Assumes `n` is a natural number. -- The implementation can be found in Henry S. Warren, Jr.'s book -- Hacker's delight, Chapter 2. @@ -57,29 +56,29 @@ fft3dP :: (Source r Complex, Monad m) fft3dP mode arr = let _ :. depth :. height :. width = extent arr !sign = signOfMode mode - !scale = fromIntegral (depth * width * height) - + !scale = fromIntegral (depth * width * height) + in if not (isPowerOfTwo depth && isPowerOfTwo height && isPowerOfTwo width) then error $ unlines [ "Data.Array.Repa.Algorithms.FFT: fft3d" , " Array dimensions must be powers of two," - , " but the provided array is " + , " but the provided array is " P.++ show height P.++ "x" P.++ show width P.++ "x" P.++ show depth ] - - else arr `deepSeqArray` + + else arr `deepSeqArray` case mode of Forward -> now $ fftTrans3d sign $ fftTrans3d sign $ fftTrans3d sign arr Reverse -> now $ fftTrans3d sign $ fftTrans3d sign $ fftTrans3d sign arr Inverse -> computeP - $ R.map (/ scale) + $ R.map (/ scale) $ fftTrans3d sign $ fftTrans3d sign $ fftTrans3d sign arr {-# INLINE fft3dP #-} -fftTrans3d +fftTrans3d :: Source r Complex => Double - -> Array r DIM3 Complex + -> Array r DIM3 Complex -> Array U DIM3 Complex fftTrans3d sign arr @@ -88,7 +87,7 @@ fftTrans3d sign arr {-# INLINE fftTrans3d #-} -rotate3d +rotate3d :: Source r Complex => Array r DIM3 Complex -> Array D DIM3 Complex rotate3d arr @@ -108,15 +107,15 @@ fft2dP :: (Source r Complex, Monad m) fft2dP mode arr = let _ :. height :. width = extent arr sign = signOfMode mode - scale = fromIntegral (width * height) - + scale = fromIntegral (width * height) + in if not (isPowerOfTwo height && isPowerOfTwo width) then error $ unlines [ "Data.Array.Repa.Algorithms.FFT: fft2d" , " Array dimensions must be powers of two," , " but the provided array is " P.++ show height P.++ "x" P.++ show width ] - - else arr `deepSeqArray` + + else arr `deepSeqArray` case mode of Forward -> now $ fftTrans2d sign $ fftTrans2d sign arr Reverse -> now $ fftTrans2d sign $ fftTrans2d sign arr @@ -127,7 +126,7 @@ fft2dP mode arr fftTrans2d :: Source r Complex => Double - -> Array r DIM2 Complex + -> Array r DIM2 Complex -> Array U DIM2 Complex fftTrans2d sign arr @@ -139,20 +138,20 @@ fftTrans2d sign arr -- Vector Transform ------------------------------------------------------------------------------- -- | Compute the DFT of a vector. Array dimensions must be powers of two else `error`. fft1dP :: (Source r Complex, Monad m) - => Mode - -> Array r DIM1 Complex + => Mode + -> Array r DIM1 Complex -> m (Array U DIM1 Complex) fft1dP mode arr = let _ :. len = extent arr sign = signOfMode mode scale = fromIntegral len - + in if not $ isPowerOfTwo len - then error $ unlines + then error $ unlines [ "Data.Array.Repa.Algorithms.FFT: fft1d" , " Array dimensions must be powers of two, " , " but the provided array is " P.++ show len ] - + else arr `deepSeqArray` case mode of Forward -> now $ fftTrans1d sign arr @@ -163,7 +162,7 @@ fft1dP mode arr fftTrans1d :: Source r Complex - => Double + => Double -> Array r DIM1 Complex -> Array U DIM1 Complex @@ -175,7 +174,7 @@ fftTrans1d sign arr -- Rank Generalised Worker ------------------------------------------------------------------------ fft :: (Shape sh, Source r Complex) - => Double -> sh -> Int + => Double -> sh -> Int -> Array r (sh :. Int) Complex -> Array U (sh :. Int) Complex @@ -184,9 +183,9 @@ fft !sign !sh !lenVec !vec where go !len !offset !stride | len == 2 = suspendedComputeP $ fromFunction (sh :. 2) swivel - + | otherwise - = combine len + = combine len (go (len `div` 2) offset (stride * 2)) (go (len `div` 2) (offset + stride) (stride * 2)) @@ -198,7 +197,7 @@ fft !sign !sh !lenVec !vec {-# INLINE combine #-} combine !len' evens odds = evens `deepSeqArray` odds `deepSeqArray` - let odds' = unsafeTraverse odds id (\get ix@(_ :. k) -> twiddle sign k len' * get ix) + let odds' = unsafeTraverse odds id (\get ix@(_ :. k) -> twiddle sign k len' * get ix) in suspendedComputeP $ (evens +^ odds') R.++ (evens -^ odds') {-# INLINE fft #-} @@ -214,4 +213,3 @@ twiddle sign k' n' where k = fromIntegral k' n = fromIntegral n' {-# INLINE twiddle #-} -