|
| 1 | +{-# LANGUAGE BangPatterns, ExistentialQuantification #-} |
| 2 | +module Fusion.Common where |
| 3 | + |
| 4 | + |
| 5 | + |
| 6 | +data Step s a = Done |
| 7 | + | Skip !s |
| 8 | + | Yield !a !s |
| 9 | + |
| 10 | + |
| 11 | +data Stream a = forall s . Stream (s -> Step s a) !s |
| 12 | + |
| 13 | +eq :: (Eq a) => Stream a -> Stream a -> Bool |
| 14 | +eq (Stream func1 st1) (Stream func2 st2) = loop (func1 st1) (func2 st2) where |
| 15 | + loop Done Done = True |
| 16 | + loop Done _ = False |
| 17 | + loop _ Done = False |
| 18 | + loop (Skip ns1) (Skip ns2) = loop (func1 ns1) (func2 ns2) |
| 19 | + loop (Skip ns1) x = loop (func1 ns1) x |
| 20 | + loop x (Skip ns2) = loop x (func2 ns2) |
| 21 | + loop (Yield a1 ns1) (Yield a2 ns2) = a1 == a1 && loop (func1 ns1) (func2 ns2) |
| 22 | + |
| 23 | +cmp :: (Ord a) => Stream a -> Stream a -> Ordering |
| 24 | +cmp (Stream func1 st1) (Stream func2 st2) = loop (func1 st1) (func2 st2) where |
| 25 | + loop Done Done = EQ |
| 26 | + loop Done _ = LT |
| 27 | + loop _ Done = GT |
| 28 | + loop (Skip ns1) (Skip ns2) = loop (func1 ns1) (func2 ns2) |
| 29 | + loop (Skip ns1) x = loop (func1 ns1) x |
| 30 | + loop x (Skip ns2) = loop x (func2 ns2) |
| 31 | + loop (Yield a1 ns1) (Yield a2 ns2) = case compare a1 a2 of |
| 32 | + EQ -> loop (func1 ns1) (func2 ns2) |
| 33 | + other -> other |
| 34 | + |
| 35 | + |
| 36 | +instance (Eq a) => Eq (Stream a) where |
| 37 | + (==) = eq |
| 38 | +instance (Ord a) => Ord (Stream a) where |
| 39 | + compare = cmp |
| 40 | + |
| 41 | +empty :: Stream a |
| 42 | +empty = Stream nextf () |
| 43 | + where nextf _ = Done |
| 44 | + |
| 45 | +singleton :: a -> Stream a |
| 46 | +singleton x = Stream nextf False |
| 47 | + where nextf False = Yield x True |
| 48 | + nextf True = Done |
| 49 | + |
| 50 | +streamList :: [a] -> Stream a |
| 51 | +streamList s = Stream nextf s |
| 52 | + where nextf [] = Done |
| 53 | + nextf (x:xs) = Yield x xs |
| 54 | + |
| 55 | +unstreamList :: Stream a -> [a] |
| 56 | +unstreamList (Stream next s) = unfold s |
| 57 | + where unfold !s = case next s of |
| 58 | + Done -> [] |
| 59 | + Skip s' -> unfold s' |
| 60 | + Yield x s' -> x : unfold s' |
| 61 | + |
| 62 | +{-# RULES "STREAM streamList/unstreamList fusion" forall s. streamList (unstreamList s) = s #-} |
| 63 | + |
| 64 | + |
| 65 | +data C s = C0 !s |
| 66 | + | C1 !s |
| 67 | + |
| 68 | +cons :: a -> Stream a -> Stream a |
| 69 | +cons !w (Stream nextf0 st) = Stream nextf (C1 st) |
| 70 | + where nextf (C1 s) = Yield w (C0 s) |
| 71 | + nextf (C0 s) = case nextf0 s of |
| 72 | + Done -> Done |
| 73 | + Skip s' -> Skip (C0 s') |
| 74 | + Yield x s' -> Yield x (C0 s') |
| 75 | + |
| 76 | +streamLast :: Stream a -> a |
| 77 | +streamLast (Stream next s0) = loop0_last s0 |
| 78 | + where |
| 79 | + loop0_last !s = case next s of |
| 80 | + Done -> emptyError "last" |
| 81 | + Skip s' -> loop0_last s' |
| 82 | + Yield x s' -> loop_last x s' |
| 83 | + loop_last !x !s = case next s of |
| 84 | + Done -> x |
| 85 | + Skip s' -> loop_last x s' |
| 86 | + Yield x' s' -> loop_last x' s' |
| 87 | +{-# INLINE streamLast #-} |
| 88 | + |
| 89 | +streamEnumFromTo :: (Enum a) => a -> a -> Stream a |
| 90 | +streamEnumFromTo from to = Stream nextf (fromEnum from) |
| 91 | + where !enumTo = fromEnum to |
| 92 | + nextf i = if i > enumTo |
| 93 | + then Done |
| 94 | + else Yield (toEnum i) (i + 1) |
| 95 | + |
| 96 | +streamForM :: Monad m => (a -> m ()) -> Stream a -> m () |
| 97 | +streamForM func (Stream nextf s0) = loop s0 where |
| 98 | + loop !s = case nextf s of |
| 99 | + Done -> return () |
| 100 | + Skip s' -> loop s' |
| 101 | + Yield v s' -> do { func v; loop s'; } |
| 102 | + |
| 103 | +streamError :: String -> String -> a |
| 104 | +streamError func msg = error $ "Fusion.Common." ++ func ++ ": " ++ msg |
| 105 | + |
| 106 | +emptyError :: String -> a |
| 107 | +emptyError func = internalError func "Empty input" |
| 108 | + |
| 109 | +internalError :: String -> a |
| 110 | +internalError func = streamError func "Internal error" |
0 commit comments