Skip to content

Commit

Permalink
fix(python): Consistent expansion of nested struct data during `DataFrame` init from dict (#15217)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie authored Mar 24, 2024
1 parent a352ee0 commit 2484dd2
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 12 deletions.
21 changes: 16 additions & 5 deletions py-polars/polars/_utils/construction/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def dict_to_pydf(
else:
data_series = [
s._s
for s in _expand_dict_scalars(
for s in _expand_dict_values(
data,
schema_overrides=schema_overrides,
strict=strict,
Expand Down Expand Up @@ -310,7 +310,7 @@ def _post_apply_columns(
return pydf


def _expand_dict_scalars(
def _expand_dict_values(
data: Mapping[str, Sequence[object] | Mapping[str, Sequence[object]] | Series],
*,
schema_overrides: SchemaDict | None = None,
Expand All @@ -337,9 +337,20 @@ def _expand_dict_scalars(
for name, val in data.items():
dtype = dtypes.get(name)
if isinstance(val, dict) and dtype != Struct:
updated_data[name] = pl.DataFrame(val, strict=strict).to_struct(
name
)
vdf = pl.DataFrame(val, strict=strict)
if (
len(vdf) == 1
and array_len > 1
and all(not d.is_nested() for d in vdf.schema.values())
):
s_vals = {
nm: vdf[nm].extend_constant(v, n=(array_len - 1))
for nm, v in val.items()
}
st = pl.DataFrame(s_vals).to_struct(name)
else:
st = vdf.to_struct(name)
updated_data[name] = st

elif isinstance(val, pl.Series):
s = val.rename(name) if name != val.name else val
Expand Down
4 changes: 2 additions & 2 deletions py-polars/polars/_utils/various.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,9 @@ def _in_notebook() -> bool:


def arrlen(obj: Any) -> int | None:
"""Return length of (non-string) sequence object; returns None for non-sequences."""
"""Return length of (non-string/dict) sequence; returns None for non-sequences."""
try:
return None if isinstance(obj, str) else len(obj)
return None if isinstance(obj, (str, dict)) else len(obj)
except TypeError:
return None

Expand Down
41 changes: 36 additions & 5 deletions py-polars/tests/unit/dataframe/test_from_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def test_from_dict_with_scalars() -> None:


@pytest.mark.slow()
def test_from_dict_with_scalars_mixed() -> None:
def test_from_dict_with_values_mixed() -> None:
# a bit of everything
mixed_dtype_data: dict[str, Any] = {
"a": 0,
Expand All @@ -164,11 +164,10 @@ def test_from_dict_with_scalars_mixed() -> None:
# note: deliberately set this value large; if all dtypes are
# on the fast-path it'll only take ~0.03secs. if it becomes
# even remotely noticeable that will indicate a regression.
# TODO: This is now slow (~0.15 seconds). Needs to be looked into.
n_range = 1_000_000
index_and_data: dict[str, Any] = {"idx": range(n_range)}
index_and_data.update(mixed_dtype_data.items())
df8 = pl.DataFrame(
df = pl.DataFrame(
data=index_and_data,
schema={
"idx": pl.Int32,
Expand All @@ -185,14 +184,46 @@ def test_from_dict_with_scalars_mixed() -> None:
"k": pl.String,
},
)
dfx = df8.select(pl.exclude("idx"))
dfx = df.select(pl.exclude("idx"))

assert len(df8) == n_range
assert len(df) == n_range
assert dfx[:5].rows() == dfx[5:10].rows()
assert dfx[-10:-5].rows() == dfx[-5:].rows()
assert dfx.row(n_range // 2, named=True) == mixed_dtype_data


def test_from_dict_expand_nested_struct() -> None:
    """Check that dict-valued columns are consistently initialised as structs."""
    # confirm consistent init of nested struct from dict data: every
    # combination of scalar/list-wrapped inputs should give the same frame
    dt = date(2077, 10, 10)
    expected = pl.DataFrame(
        [
            pl.Series("x", [dt]),
            pl.Series("nested", [{"y": -1, "z": 1}]),
        ]
    )
    for df in (
        pl.DataFrame({"x": dt, "nested": {"y": -1, "z": 1}}),
        pl.DataFrame({"x": dt, "nested": [{"y": -1, "z": 1}]}),
        pl.DataFrame({"x": [dt], "nested": {"y": -1, "z": 1}}),
        pl.DataFrame({"x": [dt], "nested": [{"y": -1, "z": 1}]}),
    ):
        assert_frame_equal(expected, df)

    # confirm expansion to 'n' nested values: a single dict alongside a
    # length-3 column should yield three identical struct rows
    nested_values = [{"y": -1, "z": 1}, {"y": -1, "z": 1}, {"y": -1, "z": 1}]
    expected = pl.DataFrame(
        [
            pl.Series("x", [0, 1, 2]),
            pl.Series("nested", nested_values),
        ]
    )
    for df in (
        pl.DataFrame({"x": range(3), "nested": {"y": -1, "z": 1}}),
        pl.DataFrame({"x": [0, 1, 2], "nested": {"y": -1, "z": 1}}),
    ):
        assert_frame_equal(expected, df)


def test_from_dict_duration_subseconds() -> None:
d = {"duration": [timedelta(seconds=1, microseconds=1000)]}
result = pl.from_dict(d)
Expand Down

0 comments on commit 2484dd2

Please sign in to comment.