Skip to content

Commit 3172f5c

Browse files
committed
feat: Avoid multiple passes when grouping dataframe.
1 parent 31d648f commit 3172f5c

File tree

1 file changed

+92
-10
lines changed

1 file changed

+92
-10
lines changed

src/DataFrame/Operations/Aggregation.hs

Lines changed: 92 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
{-# LANGUAGE BangPatterns #-}
12
{-# LANGUAGE ExplicitNamespaces #-}
23
{-# LANGUAGE FlexibleContexts #-}
34
{-# LANGUAGE GADTs #-}
@@ -55,21 +56,102 @@ groupBy names df
5556
df
5657
names
5758
(VG.map fst valueIndices)
58-
(VU.fromList (reverse (changingPoints valueIndices)))
59+
(changingPoints valueIndices)
5960
where
6061
indicesToGroup = M.elems $ M.filterWithKey (\k _ -> k `elem` names) (columnIndices df)
61-
rowRepresentations = computeRowHashes indicesToGroup df
62-
62+
doubleToInt :: Double -> Int
63+
doubleToInt = floor . (* 1000)
6364
valueIndices = runST $ do
64-
withIndexes <- VG.thaw $ VG.indexed rowRepresentations
65-
VA.sortBy (\(a, b) (a', b') -> compare b' b) withIndexes
66-
VG.unsafeFreeze withIndexes
65+
let n = fst (dimensions df)
66+
mv <- VUM.new n
67+
68+
let selectedCols = map (columns df V.!) indicesToGroup
69+
70+
forM_ selectedCols $ \case
71+
UnboxedColumn (v :: VU.Vector a) ->
72+
case testEquality (typeRep @a) (typeRep @Int) of
73+
Just Refl ->
74+
VU.imapM_
75+
( \i (x :: Int) -> do
76+
(_, !h) <- VUM.unsafeRead mv i
77+
VUM.unsafeWrite mv i (i, hashWithSalt h x)
78+
)
79+
v
80+
Nothing ->
81+
case testEquality (typeRep @a) (typeRep @Double) of
82+
Just Refl ->
83+
VU.imapM_
84+
( \i (d :: Double) -> do
85+
(_, !h) <- VUM.unsafeRead mv i
86+
VUM.unsafeWrite mv i (i, hashWithSalt h (doubleToInt d))
87+
)
88+
v
89+
Nothing ->
90+
case sIntegral @a of
91+
STrue ->
92+
VU.imapM_
93+
( \i d -> do
94+
let x :: Int
95+
x = fromIntegral @a @Int d
96+
(_, !h) <- VUM.unsafeRead mv i
97+
VUM.unsafeWrite mv i (i, hashWithSalt h x)
98+
)
99+
v
100+
SFalse ->
101+
case sFloating @a of
102+
STrue ->
103+
VU.imapM_
104+
( \i d -> do
105+
let x :: Int
106+
x = doubleToInt (realToFrac d :: Double)
107+
(_, !h) <- VUM.unsafeRead mv i
108+
VUM.unsafeWrite mv i (i, hashWithSalt h x)
109+
)
110+
v
111+
SFalse ->
112+
VU.imapM_
113+
( \i d -> do
114+
let x = hash (show d)
115+
(_, !h) <- VUM.unsafeRead mv i
116+
VUM.unsafeWrite mv i (i, hashWithSalt h x)
117+
)
118+
v
119+
BoxedColumn (v :: V.Vector a) ->
120+
case testEquality (typeRep @a) (typeRep @T.Text) of
121+
Just Refl ->
122+
V.imapM_
123+
( \i (t :: T.Text) -> do
124+
(_, !h) <- VUM.unsafeRead mv i
125+
VUM.unsafeWrite mv i (i, hashWithSalt h t)
126+
)
127+
v
128+
Nothing ->
129+
V.imapM_
130+
( \i d -> do
131+
let x = hash (show d)
132+
(_, !h) <- VUM.unsafeRead mv i
133+
VUM.unsafeWrite mv i (i, hashWithSalt h x)
134+
)
135+
v
136+
OptionalColumn v ->
137+
V.imapM_
138+
( \i d -> do
139+
let x = hash (show d)
140+
(_, !h) <- VUM.unsafeRead mv i
141+
VUM.unsafeWrite mv i (i, hashWithSalt h x)
142+
)
143+
v
144+
145+
VA.sortBy (\(a, b) (a', b') -> compare b' b) mv
146+
VG.unsafeFreeze mv
67147

68-
changingPoints :: (Eq a, VU.Unbox a) => VU.Vector (Int, a) -> [Int]
69-
changingPoints vs = VG.length vs : fst (VU.ifoldl findChangePoints initialState vs)
148+
changingPoints :: VU.Vector (Int, Int) -> VU.Vector Int
149+
changingPoints vs =
150+
VU.reverse
151+
(VU.fromList (VG.length vs : fst (VU.ifoldl' findChangePoints initialState vs)))
70152
where
71-
initialState = ([0], snd (VG.head vs))
72-
findChangePoints (offsets, currentVal) index (_, newVal)
153+
initialState = ([0], snd (VU.head vs))
154+
findChangePoints (!offsets, !currentVal) index (_, !newVal)
73155
| currentVal == newVal = (offsets, currentVal)
74156
| otherwise = (index : offsets, newVal)
75157

0 commit comments

Comments
 (0)