Skip to content

Commit

Permalink
Name "unexplode" and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
martindurant committed Nov 12, 2024
1 parent 954bf8e commit f80f443
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 4 deletions.
22 changes: 18 additions & 4 deletions src/akimbo/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,19 +334,33 @@ def pack(self):

def unpack(self):
"""Make dataframe out of a series of record type"""
# TODO: what to do when passed a dataframe, partial unpack of record fields?
arr = self.array
if not arr.fields:
raise ValueError("Not array-of-records")
# TODO: partial unpack when (some) fields are records
out = {k: self.to_output(arr[k]) for k in arr.fields}
return self.dataframe_type(out)

def group_lists(self, *cols, outname="grouped"):
def unexplode(self, *cols, outname="grouped"):
"""Repack "exploded" form dataframes into lists of structs
This is the inverse of the regular dataframe explode() process.
"""
# TODO: this does not work on cuDF as here we use arrow directly
# TODO: pandas indexes are pre-grouped cat-like structures
cols = list(cols)
outcols = [(_, "list") for _ in self.arrow.column_names if _ not in cols]
arr = self.arrow
if set(cols) - set(arr.column_names):
raise ValueError(
"One or more rouping column (%s) not in available columns %s",
cols,
arr.column_names,
)
outcols = [(_, "list") for _ in arr.column_names if _ not in cols]
if not outcols:
raise ValueError("Cannot group on all available columns")
outcols2 = [f"{_[0]}_list" for _ in outcols]
grouped = self.arrow.group_by(cols).aggregate(outcols)
grouped = arr.group_by(cols).aggregate(outcols)
akarr = ak.from_arrow(grouped)
akarr2 = akarr[outcols2]
akarr2.layout._fields = [_[0] for _ in outcols]
Expand Down
4 changes: 4 additions & 0 deletions src/akimbo/polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,7 @@ def _arrow_to_series(cls, arr):
@classmethod
def to_arrow(cls, data):
return data.to_arrow()

def pack(self):
# polars already implements this directly
return self._obj.to_struct()
39 changes: 39 additions & 0 deletions tests/test_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,42 @@ def test_rename():

s2 = s.ak.rename(("a", "b", "c"), "d")
assert s2.tolist() == [{"a": [{"b": {"d": 0}}] * 2}] * 3


def test_unexplode():
df = pd.DataFrame(
{
"x": [1, 1, 1, 2, 1, 3, 3, 1],
"y": [1, 1, 1, 2, 1, 3, 3, 1],
"z": [1, 1, 1, 2, 1, 3, 3, 2],
}
)
out = df.ak.unexplode("x")
compact = out["grouped"].tolist()
expected = [
[
{"y": 1, "z": 1},
{"y": 1, "z": 1},
{"y": 1, "z": 1},
{"y": 1, "z": 1},
{"y": 1, "z": 2},
],
[{"y": 2, "z": 2}],
[{"y": 3, "z": 3}, {"y": 3, "z": 3}],
]
assert compact == expected

out = df.ak.unexplode("x", "y")
compact = out["grouped"].tolist()
expected = [
[{"z": 1}, {"z": 1}, {"z": 1}, {"z": 1}, {"z": 2}],
[{"z": 2}],
[{"z": 3}, {"z": 3}],
]
assert compact == expected

with pytest.raises(ValueError):
df.ak.unexplode("x", "y", "z")

with pytest.raises(ValueError):
df.ak.unexplode("unknown")
39 changes: 39 additions & 0 deletions tests/test_polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,42 @@ def test_ufunc():
df = pl.DataFrame({"a": s})
df2 = df.ak + 1
assert df2["a"].to_list() == [[2, 3, 4], [], [5, 6]]


def test_unexplode():
df = pl.DataFrame(
{
"x": [1, 1, 1, 2, 1, 3, 3, 1],
"y": [1, 1, 1, 2, 1, 3, 3, 1],
"z": [1, 1, 1, 2, 1, 3, 3, 2],
}
)
out = df.ak.unexplode("x")
compact = out["grouped"].to_list()
expected = [
[
{"y": 1, "z": 1},
{"y": 1, "z": 1},
{"y": 1, "z": 1},
{"y": 1, "z": 1},
{"y": 1, "z": 2},
],
[{"y": 2, "z": 2}],
[{"y": 3, "z": 3}, {"y": 3, "z": 3}],
]
assert compact == expected

out = df.ak.unexplode("x", "y")
compact = out["grouped"].to_list()
expected = [
[{"z": 1}, {"z": 1}, {"z": 1}, {"z": 1}, {"z": 2}],
[{"z": 2}],
[{"z": 3}, {"z": 3}],
]
assert compact == expected

with pytest.raises(ValueError):
df.ak.unexplode("x", "y", "z")

with pytest.raises(ValueError):
df.ak.unexplode("unknown")

0 comments on commit f80f443

Please sign in to comment.