From 4e9520050eae81de7ec67d16d386713af58ced4b Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Fri, 31 May 2024 15:28:49 +0200 Subject: [PATCH] Revert "feat(python): Add `replace_all` expression to complement `replace`" (#16630) --- .../reference/expressions/modify_select.rst | 1 - .../source/reference/series/computation.rst | 1 - py-polars/polars/_utils/various.py | 5 +- py-polars/polars/expr/expr.py | 192 +--------- py-polars/polars/series/series.py | 133 +------ .../tests/unit/operations/test_replace.py | 324 ++++++++++++++++- .../tests/unit/operations/test_replace_all.py | 343 ------------------ 7 files changed, 353 insertions(+), 646 deletions(-) delete mode 100644 py-polars/tests/unit/operations/test_replace_all.py diff --git a/py-polars/docs/source/reference/expressions/modify_select.rst b/py-polars/docs/source/reference/expressions/modify_select.rst index eed974efb1f7..31e82ba56fc2 100644 --- a/py-polars/docs/source/reference/expressions/modify_select.rst +++ b/py-polars/docs/source/reference/expressions/modify_select.rst @@ -44,7 +44,6 @@ Manipulation/selection Expr.reinterpret Expr.repeat_by Expr.replace - Expr.replace_all Expr.reshape Expr.reverse Expr.rle diff --git a/py-polars/docs/source/reference/series/computation.rst b/py-polars/docs/source/reference/series/computation.rst index 3f6b324b54d5..68f3ca57e640 100644 --- a/py-polars/docs/source/reference/series/computation.rst +++ b/py-polars/docs/source/reference/series/computation.rst @@ -50,7 +50,6 @@ Computation Series.peak_min Series.rank Series.replace - Series.replace_all Series.rolling_apply Series.rolling_map Series.rolling_max diff --git a/py-polars/polars/_utils/various.py b/py-polars/polars/_utils/various.py index 225b2ecd5056..5644b1cd2a86 100644 --- a/py-polars/polars/_utils/various.py +++ b/py-polars/polars/_utils/various.py @@ -312,7 +312,10 @@ def str_duration_(td: str | None) -> int | None: .cast(tp) ) elif tp == Boolean: - cast_cols[c] = F.col(c).replace_all({"true": True, "false": False}) + cast_cols[c] = F.col(c).replace( + {"true": True, "false": False}, + default=None, + ) elif tp in INTEGER_DTYPES: int_string = F.col(c).str.replace_all(r"[^\d+-]", "") cast_cols[c] = ( diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index fa5a43266450..2a089b43e127 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -11582,26 +11582,16 @@ def replace( Accepts expression input. Sequences are parsed as Series, other non-expression inputs are parsed as literals. Length must match the length of `old` or have length 1. - default Set values that were not replaced to this value. Defaults to keeping the original value. Accepts expression input. Non-expression inputs are parsed as literals. - - .. deprecated:: 0.20.31 - Use :meth:`replace_all` instead to set a default while replacing values. - return_dtype The data type of the resulting expression. If set to `None` (default), the data type is determined automatically based on the other inputs. - .. deprecated:: 0.20.31 - Use :meth:`replace_all` instead to set a return data type while - replacing values. - See Also -------- - replace_all str.replace Notes @@ -11643,23 +11633,25 @@ def replace( └─────┴──────────┘ Passing a mapping with replacements is also supported as syntactic sugar. + Specify a default to set all values that were not matched. >>> mapping = {2: 100, 3: 200} - >>> df.with_columns(replaced=pl.col("a").replace(mapping)) + >>> df.with_columns(replaced=pl.col("a").replace(mapping, default=-1)) shape: (4, 2) ┌─────┬──────────┐ │ a ┆ replaced │ │ --- ┆ --- │ │ i64 ┆ i64 │ ╞═════╪══════════╡ - │ 1 ┆ 1 │ + │ 1 ┆ -1 │ │ 2 ┆ 100 │ │ 2 ┆ 100 │ │ 3 ┆ 200 │ └─────┴──────────┘ Replacing by values of a different data type sets the return type based on - a combination of the `new` data type and the original data type. + a combination of the `new` data type and either the original data type or the + default data type if it was set. >>> df = pl.DataFrame({"a": ["x", "y", "z"]}) >>> mapping = {"x": 1, "y": 2, "z": 3} @@ -11674,156 +11666,7 @@ def replace( │ y ┆ 2 │ │ z ┆ 3 │ └─────┴──────────┘ - - Expression input is supported. - - >>> df = pl.DataFrame({"a": [1, 2, 2, 3], "b": [1.5, 2.5, 5.0, 1.0]}) - >>> df.with_columns( - ... replaced=pl.col("a").replace( - ... old=pl.col("a").max(), - ... new=pl.col("b").sum(), - ... ) - ... ) - shape: (4, 3) - ┌─────┬─────┬──────────┐ - │ a ┆ b ┆ replaced │ - │ --- ┆ --- ┆ --- │ - │ i64 ┆ f64 ┆ f64 │ - ╞═════╪═════╪══════════╡ - │ 1 ┆ 1.5 ┆ 1.0 │ - │ 2 ┆ 2.5 ┆ 2.0 │ - │ 2 ┆ 5.0 ┆ 2.0 │ - │ 3 ┆ 1.0 ┆ 10.0 │ - └─────┴─────┴──────────┘ - """ - if new is no_default and isinstance(old, Mapping): - new = pl.Series(old.values()) - old = pl.Series(old.keys()) - else: - if isinstance(old, Sequence) and not isinstance(old, (str, pl.Series)): - old = pl.Series(old) - if isinstance(new, Sequence) and not isinstance(new, (str, pl.Series)): - new = pl.Series(new) - - old = parse_as_expression(old, str_as_lit=True) # type: ignore[arg-type] - new = parse_as_expression(new, str_as_lit=True) # type: ignore[arg-type] - - if default is no_default: - default = None - else: - issue_deprecation_warning( - "The `default` parameter for `replace` is deprecated." - " Use `replace_all` instead to set a default while replacing values.", - version="0.20.31", - ) - default = parse_as_expression(default, str_as_lit=True) - - if return_dtype is not None: - issue_deprecation_warning( - "The `return_dtype` parameter for `replace` is deprecated." - " Use `replace_all` instead to set a return data type while replacing values.", - version="0.20.31", - ) - - return self._from_pyexpr(self._pyexpr.replace(old, new, default, return_dtype)) - - def replace_all( - self, - old: IntoExpr | Sequence[Any] | Mapping[Any, Any], - new: IntoExpr | Sequence[Any] | NoDefault = no_default, - *, - default: IntoExpr = None, - return_dtype: PolarsDataType | None = None, - ) -> Self: - """ - Replace all values by different values. - - Parameters - ---------- - old - Value or sequence of values to replace. - Accepts expression input. Sequences are parsed as Series, - other non-expression inputs are parsed as literals. - Also accepts a mapping of values to their replacement as syntactic sugar for - `replace_all(old=Series(mapping.keys()), new=Series(mapping.values()))`. - new - Value or sequence of values to replace by. - Accepts expression input. Sequences are parsed as Series, - other non-expression inputs are parsed as literals. - Length must match the length of `old` or have length 1. - default - Set values that were not replaced to this value. Defaults to null. - Accepts expression input. Non-expression inputs are parsed as literals. - return_dtype - The data type of the resulting expression. If set to `None` (default), - the data type is determined automatically based on the other inputs. - - See Also - -------- - replace - str.replace - - Notes - ----- - The global string cache must be enabled when replacing categorical values. - - Examples - -------- - Replace a single value by another value. Values that were not replaced are set - to null. - - >>> df = pl.DataFrame({"a": [1, 2, 2, 3]}) - >>> df.with_columns(replaced=pl.col("a").replace_all(2, 100)) - shape: (4, 2) - ┌─────┬──────────┐ - │ a ┆ replaced │ - │ --- ┆ --- │ - │ i64 ┆ i32 │ - ╞═════╪══════════╡ - │ 1 ┆ null │ - │ 2 ┆ 100 │ - │ 2 ┆ 100 │ - │ 3 ┆ null │ - └─────┴──────────┘ - - Replace multiple values by passing sequences to the `old` and `new` parameters. - - >>> df.with_columns(replaced=pl.col("a").replace_all([2, 3], [100, 200])) - shape: (4, 2) - ┌─────┬──────────┐ - │ a ┆ replaced │ - │ --- ┆ --- │ - │ i64 ┆ i64 │ - ╞═════╪══════════╡ - │ 1 ┆ null │ - │ 2 ┆ 100 │ - │ 2 ┆ 100 │ - │ 3 ┆ 200 │ - └─────┴──────────┘ - - Passing a mapping with replacements is also supported as syntactic sugar. - Specify a default to set all values that were not matched. - - >>> mapping = {2: 100, 3: 200} - >>> df.with_columns(replaced=pl.col("a").replace_all(mapping, default=-1)) - shape: (4, 2) - ┌─────┬──────────┐ - │ a ┆ replaced │ - │ --- ┆ --- │ - │ i64 ┆ i64 │ - ╞═════╪══════════╡ - │ 1 ┆ -1 │ - │ 2 ┆ 100 │ - │ 2 ┆ 100 │ - │ 3 ┆ 200 │ - └─────┴──────────┘ - - Replacing by values of a different data type sets the return type based on - a combination of the `new` data type and the `default` data type. - - >>> df = pl.DataFrame({"a": ["x", "y", "z"]}) - >>> mapping = {"x": 1, "y": 2, "z": 3} - >>> df.with_columns(replaced=pl.col("a").replace_all(mapping)) + >>> df.with_columns(replaced=pl.col("a").replace(mapping, default=None)) shape: (3, 2) ┌─────┬──────────┐ │ a ┆ replaced │ @@ -11834,22 +11677,11 @@ def replace_all( │ y ┆ 2 │ │ z ┆ 3 │ └─────┴──────────┘ - >>> df.with_columns(replaced=pl.col("a").replace_all(mapping, default="x")) - shape: (3, 2) - ┌─────┬──────────┐ - │ a ┆ replaced │ - │ --- ┆ --- │ - │ str ┆ str │ - ╞═════╪══════════╡ - │ x ┆ 1 │ - │ y ┆ 2 │ - │ z ┆ 3 │ - └─────┴──────────┘ Set the `return_dtype` parameter to control the resulting data type directly. >>> df.with_columns( - ... replaced=pl.col("a").replace_all(mapping, return_dtype=pl.UInt8) + ... replaced=pl.col("a").replace(mapping, return_dtype=pl.UInt8) ... ) shape: (3, 2) ┌─────┬──────────┐ @@ -11866,7 +11698,7 @@ def replace_all( >>> df = pl.DataFrame({"a": [1, 2, 2, 3], "b": [1.5, 2.5, 5.0, 1.0]}) >>> df.with_columns( - ... replaced=pl.col("a").replace_all( + ... replaced=pl.col("a").replace( ... old=pl.col("a").max(), ... new=pl.col("b").sum(), ... default=pl.col("b"), @@ -11896,7 +11728,11 @@ def replace_all( old = parse_as_expression(old, str_as_lit=True) # type: ignore[arg-type] new = parse_as_expression(new, str_as_lit=True) # type: ignore[arg-type] - default = parse_as_expression(default, str_as_lit=True) + default = ( + None + if default is no_default + else parse_as_expression(default, str_as_lit=True) + ) return self._from_pyexpr(self._pyexpr.replace(old, new, default, return_dtype)) @@ -12349,7 +12185,7 @@ def map_dict( return_dtype Set return dtype to override automatic return dtype determination. """ - return self.replace_all(mapping, default=default, return_dtype=return_dtype) + return self.replace(mapping, default=default, return_dtype=return_dtype) @classmethod def from_json(cls, value: str) -> Self: diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 518d0666031e..4f7c5e58fa6d 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -6866,27 +6866,16 @@ def replace( new Value or sequence of values to replace by. Length must match the length of `old` or have length 1. - default Set values that were not replaced to this value. Defaults to keeping the original value. Accepts expression input. Non-expression inputs are parsed as literals. - - .. deprecated:: 0.20.31 - Use :meth:`replace_all` instead to set a default while replacing values. - return_dtype - The data type of the resulting expression. If set to `None` (default), + The data type of the resulting Series. If set to `None` (default), the data type is determined automatically based on the other inputs. - .. deprecated:: 0.20.31 - Use :meth:`replace_all` instead to set a return data type while - replacing values. - - See Also -------- - replace_all str.replace Notes @@ -6921,103 +6910,11 @@ def replace( 200 ] - Passing a mapping with replacements is also supported as syntactic sugar. - - >>> mapping = {2: 100, 3: 200} - >>> s.replace(mapping) - shape: (4,) - Series: '' [i64] - [ - 1 - 100 - 100 - 200 - ] - - Replacing by values of a different data type sets the return type based on - a combination of the `new` data type and the original data type. - - >>> s = pl.Series(["x", "y", "z"]) - >>> mapping = {"x": 1, "y": 2, "z": 3} - >>> s.replace(mapping) - shape: (3,) - Series: '' [str] - [ - "1" - "2" - "3" - ] - """ - - def replace_all( - self, - old: IntoExpr | Sequence[Any] | Mapping[Any, Any], - new: IntoExpr | Sequence[Any] | NoDefault = no_default, - *, - default: IntoExpr = None, - return_dtype: PolarsDataType | None = None, - ) -> Self: - """ - Replace all values by different values. - - Parameters - ---------- - old - Value or sequence of values to replace. - Also accepts a mapping of values to their replacement as syntactic sugar for - `replace_all(old=Series(mapping.keys()), new=Series(mapping.values()))`. - new - Value or sequence of values to replace by. - Length must match the length of `old` or have length 1. - default - Set values that were not replaced to this value. Defaults to null. - Accepts expression input. Non-expression inputs are parsed as literals. - return_dtype - The data type of the resulting Series. If set to `None` (default), - the data type is determined automatically based on the other inputs. - - See Also - -------- - replace - str.replace - - Notes - ----- - The global string cache must be enabled when replacing categorical values. - - Examples - -------- - Replace a single value by another value. Values that were not replaced are set - to null. - - >>> s = pl.Series([1, 2, 2, 3]) - >>> s.replace_all(2, 100) - shape: (4,) - Series: '' [i32] - [ - null - 100 - 100 - null - ] - - Replace multiple values by passing sequences to the `old` and `new` parameters. - - >>> s.replace_all([2, 3], [100, 200]) - shape: (4,) - Series: '' [i64] - [ - null - 100 - 100 - 200 - ] - Passing a mapping with replacements is also supported as syntactic sugar. Specify a default to set all values that were not matched. >>> mapping = {2: 100, 3: 200} - >>> s.replace_all(mapping, default=-1) + >>> s.replace(mapping, default=-1) shape: (4,) Series: '' [i64] [ @@ -7027,10 +6924,11 @@ def replace_all( 200 ] + The default can be another Series. >>> default = pl.Series([2.5, 5.0, 7.5, 10.0]) - >>> s.replace_all(2, 100, default=default) + >>> s.replace(2, 100, default=default) shape: (4,) Series: '' [f64] [ @@ -7041,19 +6939,12 @@ def replace_all( ] Replacing by values of a different data type sets the return type based on - a combination of the `new` data type and the `default` data type. + a combination of the `new` data type and either the original data type or the + default data type if it was set. >>> s = pl.Series(["x", "y", "z"]) >>> mapping = {"x": 1, "y": 2, "z": 3} - >>> s.replace_all(mapping) - shape: (3,) - Series: '' [i64] - [ - 1 - 2 - 3 - ] - >>> s.replace_all(mapping, default="x") + >>> s.replace(mapping) shape: (3,) Series: '' [str] [ @@ -7061,10 +6952,18 @@ def replace_all( "2" "3" ] + >>> s.replace(mapping, default=None) + shape: (3,) + Series: '' [i64] + [ + 1 + 2 + 3 + ] Set the `return_dtype` parameter to control the resulting data type directly. - >>> s.replace_all(mapping, return_dtype=pl.UInt8) + >>> s.replace(mapping, return_dtype=pl.UInt8) shape: (3,) Series: '' [u8] [ diff --git a/py-polars/tests/unit/operations/test_replace.py b/py-polars/tests/unit/operations/test_replace.py index f2d1fa7f4681..cd5f80a73364 100644 --- a/py-polars/tests/unit/operations/test_replace.py +++ b/py-polars/tests/unit/operations/test_replace.py @@ -1,10 +1,12 @@ from __future__ import annotations +import contextlib from typing import Any import pytest import polars as pl +from polars.exceptions import CategoricalRemappingWarning from polars.testing import assert_frame_equal, assert_series_equal @@ -25,6 +27,44 @@ def test_replace_str_to_str(str_mapping: dict[str | None, str]) -> None: assert_frame_equal(result, expected) +def test_replace_str_to_str_default_self(str_mapping: dict[str | None, str]) -> None: + df = pl.DataFrame({"country_code": ["FR", None, "ES", "DE"]}) + result = df.select( + replaced=pl.col("country_code").replace( + str_mapping, default=pl.col("country_code") + ) + ) + expected = pl.DataFrame({"replaced": ["France", "Not specified", "ES", "Germany"]}) + assert_frame_equal(result, expected) + + +def test_replace_str_to_str_default_null(str_mapping: dict[str | None, str]) -> None: + df = pl.DataFrame({"country_code": ["FR", None, "ES", "DE"]}) + result = df.select( + replaced=pl.col("country_code").replace(str_mapping, default=None) + ) + expected = pl.DataFrame({"replaced": ["France", "Not specified", None, "Germany"]}) + assert_frame_equal(result, expected) + + +def test_replace_str_to_str_default_other(str_mapping: dict[str | None, str]) -> None: + df = pl.DataFrame({"country_code": ["FR", None, "ES", "DE"]}) + + result = df.with_row_index().select( + replaced=pl.col("country_code").replace(str_mapping, default=pl.col("index")) + ) + expected = pl.DataFrame({"replaced": ["France", "Not specified", "2", "Germany"]}) + assert_frame_equal(result, expected) + + +def test_replace_str_to_cat() -> None: + s = pl.Series(["a", "b", "c"]) + mapping = {"a": "c", "b": "d"} + result = s.replace(mapping, return_dtype=pl.Categorical) + expected = pl.Series(["c", "d", "c"], dtype=pl.Categorical) + assert_series_equal(result, expected, categorical_as_str=True) + + def test_replace_enum() -> None: dtype = pl.Enum(["a", "b", "c", "d"]) s = pl.Series(["a", "b", "c"], dtype=dtype) @@ -47,6 +87,19 @@ def test_replace_enum_to_str() -> None: assert_series_equal(result, expected) +def test_replace_enum_to_new_enum() -> None: + s = pl.Series(["a", "b", "c"], dtype=pl.Enum(["a", "b", "c", "d"])) + old = ["a", "b"] + + new_dtype = pl.Enum(["a", "b", "c", "d", "e"]) + new = pl.Series(["c", "e"], dtype=new_dtype) + + result = s.replace(old, new, return_dtype=new_dtype) + + expected = pl.Series(["c", "e", "c"], dtype=new_dtype) + assert_series_equal(result, expected) + + @pl.StringCache() def test_replace_cat_to_cat(str_mapping: dict[str | None, str]) -> None: lf = pl.LazyFrame( @@ -112,6 +165,42 @@ def test_replace_int_to_str_with_null() -> None: assert_frame_equal(result, expected) +def test_replace_int_to_int_null() -> None: + df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16}) + mapping = {3: None} + result = df.select( + replaced=pl.col("int").replace(mapping, default=pl.lit(6).cast(pl.Int16)) + ) + expected = pl.DataFrame( + {"replaced": [6, 6, 6, None]}, schema={"replaced": pl.Int16} + ) + assert_frame_equal(result, expected) + + +def test_replace_int_to_int_null_default_null() -> None: + df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16}) + mapping = {3: None} + result = df.select(replaced=pl.col("int").replace(mapping, default=None)) + expected = pl.DataFrame( + {"replaced": [None, None, None, None]}, schema={"replaced": pl.Null} + ) + assert_frame_equal(result, expected) + + +def test_replace_int_to_int_null_return_dtype() -> None: + df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16}) + mapping = {3: None} + + result = df.select( + replaced=pl.col("int").replace(mapping, default=6, return_dtype=pl.Int32) + ) + + expected = pl.DataFrame( + {"replaced": [6, 6, 6, None]}, schema={"replaced": pl.Int32} + ) + assert_frame_equal(result, expected) + + def test_replace_empty_mapping() -> None: df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16}) mapping: dict[Any, Any] = {} @@ -119,6 +208,14 @@ def test_replace_empty_mapping() -> None: assert_frame_equal(result, df) +def test_replace_empty_mapping_default() -> None: + df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16}) + mapping: dict[Any, Any] = {} + result = df.select(pl.col("int").replace(mapping, default=pl.lit("A"))) + expected = pl.DataFrame({"int": ["A", "A", "A", "A"]}) + assert_frame_equal(result, expected) + + def test_replace_mapping_different_dtype_str_int() -> None: df = pl.DataFrame({"int": [None, "1", None, "3"]}) mapping = {1: "b", 3: "d"} @@ -153,6 +250,60 @@ def test_replace_str_to_str_replace_all() -> None: assert_frame_equal(result, expected) +def test_replace_int_to_int_df() -> None: + lf = pl.LazyFrame({"a": [1, 2, 3]}, schema={"a": pl.UInt8}) + mapping = {1: 11, 2: 22} + + result = lf.select( + pl.col("a").replace( + old=pl.Series(mapping.keys()), + new=pl.Series(mapping.values(), dtype=pl.UInt8), + default=pl.lit(99).cast(pl.UInt8), + ) + ) + expected = pl.LazyFrame({"a": [11, 22, 99]}, schema_overrides={"a": pl.UInt8}) + assert_frame_equal(result, expected) + + +def test_replace_str_to_int_fill_null() -> None: + lf = pl.LazyFrame({"a": ["one", "two"]}) + mapping = {"one": 1} + + result = lf.select( + pl.col("a") + .replace(mapping, default=None, return_dtype=pl.UInt32) + .fill_null(999) + ) + + expected = pl.LazyFrame({"a": pl.Series([1, 999], dtype=pl.UInt32)}) + assert_frame_equal(result, expected) + + +def test_replace_mix() -> None: + df = pl.DataFrame( + [ + pl.Series("float_to_boolean", [1.0, None]), + pl.Series("boolean_to_int", [True, False]), + pl.Series("boolean_to_str", [True, False]), + ] + ) + + result = df.with_columns( + pl.col("float_to_boolean").replace({1.0: True}, default=None), + pl.col("boolean_to_int").replace({True: 1, False: 0}), + pl.col("boolean_to_str").replace({True: "1", False: "0"}), + ) + + expected = pl.DataFrame( + [ + pl.Series("float_to_boolean", [True, None], dtype=pl.Boolean), + pl.Series("boolean_to_int", [1, 0], dtype=pl.Int64), + pl.Series("boolean_to_str", ["1", "0"], dtype=pl.String), + ] + ) + assert_frame_equal(result, expected) + + @pytest.fixture(scope="module") def int_mapping() -> dict[int, int]: return {1: 11, 2: 22, 3: 33, 4: 44, 5: 55} @@ -165,6 +316,20 @@ def test_replace_int_to_int1(int_mapping: dict[int, int]) -> None: assert_series_equal(result, expected) +def test_replace_int_to_int2(int_mapping: dict[int, int]) -> None: + s = pl.Series([1, 22, None, 44, -5]) + result = s.replace(int_mapping, default=None) + expected = pl.Series([11, None, None, None, None], dtype=pl.Int64) + assert_series_equal(result, expected) + + +def test_replace_int_to_int3(int_mapping: dict[int, int]) -> None: + s = pl.Series([1, 22, None, 44, -5], dtype=pl.Int16) + result = s.replace(int_mapping, default=9) + expected = pl.Series([11, 9, 9, 9, 9], dtype=pl.Int64) + assert_series_equal(result, expected) + + def test_replace_int_to_int4(int_mapping: dict[int, int]) -> None: s = pl.Series([-1, 22, None, 44, -5]) result = s.replace(int_mapping) @@ -172,12 +337,33 @@ def test_replace_int_to_int4(int_mapping: dict[int, int]) -> None: assert_series_equal(result, expected) -# https://github.com/pola-rs/polars/issues/12728 -def test_replace_str_to_int2() -> None: - s = pl.Series(["a", "b"]) - mapping = {"a": 1, "b": 2} +def test_replace_int_to_int4_return_dtype(int_mapping: dict[int, int]) -> None: + s = pl.Series([-1, 22, None, 44, -5], dtype=pl.Int16) + result = s.replace(int_mapping, return_dtype=pl.Float32) + expected = pl.Series([-1.0, 22.0, None, 44.0, -5.0], dtype=pl.Float32) + assert_series_equal(result, expected) + + +def test_replace_int_to_int5_return_dtype(int_mapping: dict[int, int]) -> None: + s = pl.Series([1, 22, None, 44, -5], dtype=pl.Int16) + result = s.replace(int_mapping, default=9, return_dtype=pl.Float32) + expected = pl.Series([11.0, 9.0, 9.0, 9.0, 9.0], dtype=pl.Float32) + assert_series_equal(result, expected) + + +def test_replace_bool_to_int() -> None: + s = pl.Series([True, False, False, None]) + mapping = {True: 1, False: 0} result = s.replace(mapping) - expected = pl.Series(["1", "2"]) + expected = pl.Series([1, 0, 0, None]) + assert_series_equal(result, expected) + + +def test_replace_bool_to_str() -> None: + s = pl.Series([True, False, False, None]) + mapping = {True: "1", False: "0"} + result = s.replace(mapping) + expected = pl.Series(["1", "0", "0", None]) assert_series_equal(result, expected) @@ -189,6 +375,51 @@ def test_replace_str_to_bool_without_default() -> None: assert_series_equal(result, expected) +def test_replace_str_to_bool_with_default() -> None: + s = pl.Series(["True", "False", "False", None]) + mapping = {"True": True, "False": False} + result = s.replace(mapping, default=None) + expected = pl.Series([True, False, False, None]) + assert_series_equal(result, expected) + + +def test_replace_int_to_str() -> None: + s = pl.Series("a", [-1, 2, None, 4, -5]) + mapping = {1: "one", 2: "two", 3: "three", 4: "four", 5: "five"} + + result = s.replace(mapping) + + expected = pl.Series("a", ["-1", "two", None, "four", "-5"]) + assert_series_equal(result, expected) + + +def test_replace_int_to_str_with_default() -> None: + s = pl.Series("a", [1, 2, None, 4, 5]) + mapping = {1: "one", 2: "two", 3: "three", 4: "four", 5: "five"} + + result = s.replace(mapping, default="?") + + expected = pl.Series("a", ["one", "two", "?", "four", "five"]) + assert_series_equal(result, expected) + + +# https://github.com/pola-rs/polars/issues/12728 +def test_replace_str_to_int2() -> None: + s = pl.Series(["a", "b"]) + mapping = {"a": 1, "b": 2} + result = s.replace(mapping) + expected = pl.Series(["1", "2"]) + assert_series_equal(result, expected) + + +def test_replace_str_to_int_with_default() -> None: + s = pl.Series(["a", "b"]) + mapping = {"a": 1, "b": 2} + result = s.replace(mapping, default=None) + expected = pl.Series([1, 2]) + assert_series_equal(result, expected) + + def test_replace_old_new() -> None: s = pl.Series([1, 2, 2, 3]) result = s.replace(2, 9) @@ -238,6 +469,20 @@ def test_replace_fast_path_many_to_one() -> None: assert_frame_equal(result, expected) +def test_replace_fast_path_many_to_one_default() -> None: + lf = pl.LazyFrame({"a": [1, 2, 2, 3]}) + result = lf.select(pl.col("a").replace([2, 3], 100, default=-1)) + expected = pl.LazyFrame({"a": [-1, 100, 100, 100]}, schema={"a": pl.Int64}) + assert_frame_equal(result, expected) + + +def test_replace_fast_path_many_to_one_null() -> None: + lf = pl.LazyFrame({"a": [1, 2, 2, 3]}) + result = lf.select(pl.col("a").replace([2, 3], None, default=-1)) + expected = pl.LazyFrame({"a": [-1, None, None, None]}, schema={"a": pl.Int64}) + assert_frame_equal(result, expected) + + @pytest.mark.parametrize( ("old", "new"), [ @@ -260,3 +505,72 @@ def test_replace_duplicates_new() -> None: result = s.replace([1, 2], [100, 100]) expected = s = pl.Series([100, 100, 3, 100, 3]) assert_series_equal(result, expected) + + +def test_map_dict_deprecated() -> None: + s = pl.Series("a", [1, 2, 3]) + with pytest.deprecated_call(): + result = s.map_dict({2: 100}) + expected = pl.Series("a", [None, 100, None]) + assert_series_equal(result, expected) + + with pytest.deprecated_call(): + result = s.to_frame().select(pl.col("a").map_dict({2: 100})).to_series() + assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + ("context", "dtype"), + [ + (pl.StringCache(), pl.Categorical), + (pytest.warns(CategoricalRemappingWarning), pl.Categorical), + (contextlib.nullcontext(), pl.Enum(["a", "b", "OTHER"])), + ], +) +def test_replace_cat_str( + context: contextlib.AbstractContextManager, # type: ignore[type-arg] + dtype: pl.DataType, +) -> None: + with context: + for old, new, expected in [ + ("a", "c", pl.Series("s", ["c", None], dtype=pl.Utf8)), + (["a", "b"], ["c", "d"], pl.Series("s", ["c", "d"], dtype=pl.Utf8)), + (pl.lit("a", dtype=dtype), "c", pl.Series("s", ["c", None], dtype=pl.Utf8)), + ( + pl.Series(["a", "b"], dtype=dtype), + ["c", "d"], + pl.Series("s", ["c", "d"], dtype=pl.Utf8), + ), + ]: + s = pl.Series("s", ["a", "b"], dtype=dtype) + s_replaced = s.replace(old, new, default=None) # type: ignore[arg-type] + assert_series_equal(s_replaced, expected) + + s = pl.Series("s", ["a", "b"], dtype=dtype) + s_replaced = s.replace(old, new, default="OTHER") # type: ignore[arg-type] + assert_series_equal(s_replaced, expected.fill_null("OTHER")) + + +@pytest.mark.parametrize( + "context", [pl.StringCache(), pytest.warns(CategoricalRemappingWarning)] +) +def test_replace_cat_cat( + context: contextlib.AbstractContextManager, # type: ignore[type-arg] +) -> None: + with context: + dt = pl.Categorical + for old, new, expected in [ + ("a", pl.lit("c", dtype=dt), pl.Series("s", ["c", None], dtype=dt)), + ( + ["a", "b"], + pl.Series(["c", "d"], dtype=dt), + pl.Series("s", ["c", "d"], dtype=dt), + ), + ]: + s = pl.Series("s", ["a", "b"], dtype=dt) + s_replaced = s.replace(old, new, default=None) # type: ignore[arg-type] + assert_series_equal(s_replaced, expected) + + s = pl.Series("s", ["a", "b"], dtype=dt) + s_replaced = s.replace(old, new, default=pl.lit("OTHER", dtype=dt)) # type: ignore[arg-type] + assert_series_equal(s_replaced, expected.fill_null("OTHER")) diff --git a/py-polars/tests/unit/operations/test_replace_all.py b/py-polars/tests/unit/operations/test_replace_all.py deleted file mode 100644 index 844c2488b953..000000000000 --- a/py-polars/tests/unit/operations/test_replace_all.py +++ /dev/null @@ -1,343 +0,0 @@ -from __future__ import annotations - -import contextlib -from typing import Any - -import pytest - -import polars as pl -from polars.exceptions import CategoricalRemappingWarning -from polars.testing import assert_frame_equal, assert_series_equal - - -@pytest.fixture(scope="module") -def str_mapping() -> dict[str | None, str]: - return { - "CA": "Canada", - "DE": "Germany", - "FR": "France", - None: "Not specified", - } - - -def test_replace_all_fast_path_many_to_one_default() -> None: - lf = pl.LazyFrame({"a": [1, 2, 2, 3]}) - result = lf.select(pl.col("a").replace_all([2, 3], 100, default=-1)) - expected = pl.LazyFrame({"a": [-1, 100, 100, 100]}, schema={"a": pl.Int64}) - assert_frame_equal(result, expected) - - -def test_replace_all_fast_path_many_to_one_null() -> None: - lf = pl.LazyFrame({"a": [1, 2, 2, 3]}) - result = lf.select(pl.col("a").replace_all([2, 3], None, default=-1)) - expected = pl.LazyFrame({"a": [-1, None, None, None]}, schema={"a": pl.Int64}) - assert_frame_equal(result, expected) - - -def test_replace_all_str_to_str_default_self( - str_mapping: dict[str | None, str], -) -> None: - df = pl.DataFrame({"country_code": ["FR", None, "ES", "DE"]}) - result = df.select( - replaced=pl.col("country_code").replace_all( - str_mapping, default=pl.col("country_code") - ) - ) - expected = pl.DataFrame({"replaced": ["France", "Not specified", "ES", "Germany"]}) - assert_frame_equal(result, expected) - - -def test_replace_all_str_to_str_default_null( - str_mapping: dict[str | None, str], -) -> None: - df = pl.DataFrame({"country_code": ["FR", None, "ES", "DE"]}) - result = df.select(replaced=pl.col("country_code").replace_all(str_mapping)) - expected = pl.DataFrame({"replaced": ["France", "Not specified", None, "Germany"]}) - assert_frame_equal(result, expected) - - -def test_replace_all_str_to_str_default_other( - str_mapping: dict[str | None, str], -) -> None: - df = pl.DataFrame({"country_code": ["FR", None, "ES", "DE"]}) - - result = df.with_row_index().select( - replaced=pl.col("country_code").replace_all( - str_mapping, default=pl.col("index") - ) - ) - expected = pl.DataFrame({"replaced": ["France", "Not specified", "2", "Germany"]}) - assert_frame_equal(result, expected) - - -def test_replace_str_to_cat() -> None: - s = pl.Series(["a", "b", "c"]) - mapping = {"a": "c", "b": "d"} - result = s.replace_all(mapping, return_dtype=pl.Categorical) - expected = pl.Series(["c", "d", None], dtype=pl.Categorical) - assert_series_equal(result, expected, categorical_as_str=True) - - -def test_replace_all_enum_to_new_enum() -> None: - s = pl.Series(["a", "b", "c"], dtype=pl.Enum(["a", "b", "c", "d"])) - old = ["a", "b"] - - new_dtype = pl.Enum(["a", "b", "c", "d", "e"]) - new = pl.Series(["c", "e"], dtype=new_dtype) - - result = s.replace_all(old, new, return_dtype=new_dtype) - - expected = pl.Series(["c", "e", None], dtype=new_dtype) - assert_series_equal(result, expected) - - -def test_replace_all_int_to_int_null() -> None: - df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16}) - mapping = {3: None} - result = df.select( - replaced=pl.col("int").replace_all(mapping, default=pl.lit(6).cast(pl.Int16)) - ) - expected = pl.DataFrame( - {"replaced": [6, 6, 6, None]}, schema={"replaced": pl.Int16} - ) - assert_frame_equal(result, expected) - - -def test_replace_all_int_to_int_null_default_null() -> None: - df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16}) - mapping = {3: None} - result = df.select(replaced=pl.col("int").replace_all(mapping)) - expected = pl.DataFrame( - {"replaced": [None, None, None, None]}, schema={"replaced": pl.Null} - ) - assert_frame_equal(result, expected) - - -def test_replace_all_int_to_int_null_return_dtype() -> None: - df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16}) - mapping = {3: None} - - result = df.select( - replaced=pl.col("int").replace_all(mapping, default=6, return_dtype=pl.Int32) - ) - - expected = pl.DataFrame( - {"replaced": [6, 6, 6, None]}, schema={"replaced": pl.Int32} - ) - assert_frame_equal(result, expected) - - -def test_replace_all_empty_mapping_default() -> None: - df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16}) - mapping: dict[Any, Any] = {} - result = df.select(pl.col("int").replace_all(mapping, default=pl.lit("A"))) - expected = pl.DataFrame({"int": ["A", "A", "A", "A"]}) - assert_frame_equal(result, expected) - - -def test_replace_all_int_to_int_df() -> None: - lf = pl.LazyFrame({"a": [1, 2, 3]}, schema={"a": pl.UInt8}) - mapping = {1: 11, 2: 22} - - result = lf.select( - pl.col("a").replace_all( - old=pl.Series(mapping.keys()), - new=pl.Series(mapping.values(), dtype=pl.UInt8), - default=pl.lit(99).cast(pl.UInt8), - ) - ) - expected = pl.LazyFrame({"a": [11, 22, 99]}, schema_overrides={"a": pl.UInt8}) - assert_frame_equal(result, expected) - - -def test_replace_all_str_to_int_fill_null() -> None: - lf = pl.LazyFrame({"a": ["one", "two"]}) - mapping = {"one": 1} - - result = lf.select( - pl.col("a") - .replace_all(mapping, default=None, return_dtype=pl.UInt32) - .fill_null(999) - ) - - expected = pl.LazyFrame({"a": pl.Series([1, 999], dtype=pl.UInt32)}) - assert_frame_equal(result, expected) - - -def test_replace_mix() -> None: - df = pl.DataFrame( - [ - pl.Series("float_to_boolean", [1.0, None]), - pl.Series("boolean_to_int", [True, False]), - pl.Series("boolean_to_str", [True, False]), - ] - ) - - result = df.with_columns( - pl.col("float_to_boolean").replace_all({1.0: True}), - pl.col("boolean_to_int").replace_all({True: 1, False: 0}), - pl.col("boolean_to_str").replace_all({True: "1", False: "0"}), - ) - - expected = pl.DataFrame( - [ - pl.Series("float_to_boolean", [True, None], dtype=pl.Boolean), - pl.Series("boolean_to_int", [1, 0], dtype=pl.Int64), - pl.Series("boolean_to_str", ["1", "0"], dtype=pl.String), - ] - ) - assert_frame_equal(result, expected) - - -@pytest.fixture(scope="module") -def int_mapping() -> dict[int, int]: - return {1: 11, 2: 22, 3: 33, 4: 44, 5: 55} - - -def test_replace_all_int_to_int2(int_mapping: dict[int, int]) -> None: - s = pl.Series([1, 22, None, 44, -5]) - result = s.replace_all(int_mapping) - expected = pl.Series([11, None, None, None, None], dtype=pl.Int64) - assert_series_equal(result, expected) - - -def test_replace_all_int_to_int3(int_mapping: dict[int, int]) -> None: - s = pl.Series([1, 22, None, 44, -5], dtype=pl.Int16) - result = s.replace_all(int_mapping, default=9) - expected = pl.Series([11, 9, 9, 9, 9], dtype=pl.Int64) - assert_series_equal(result, expected) - - -def test_replace_all_int_to_int4_return_dtype(int_mapping: dict[int, int]) -> None: - s = pl.Series([-1, 22, None, 44, -5], dtype=pl.Int16) - result = s.replace_all(int_mapping, default=s, return_dtype=pl.Float32) - expected = pl.Series([-1.0, 22.0, None, 44.0, -5.0], dtype=pl.Float32) - assert_series_equal(result, expected) - - -def test_replace_all_int_to_int5_return_dtype(int_mapping: dict[int, int]) -> None: - s = pl.Series([1, 22, None, 44, -5], dtype=pl.Int16) - result = s.replace_all(int_mapping, default=9, return_dtype=pl.Float32) - expected = pl.Series([11.0, 9.0, 9.0, 9.0, 9.0], dtype=pl.Float32) - assert_series_equal(result, expected) - - -def test_replace_all_bool_to_int() -> None: - s = pl.Series([True, False, False, None]) - mapping = {True: 1, False: 0} - result = s.replace_all(mapping) - expected = pl.Series([1, 0, 0, None]) - assert_series_equal(result, expected) - - -def test_replace_bool_to_str() -> None: - s = pl.Series([True, False, False, None]) - mapping = {True: "1", False: "0"} - result = s.replace_all(mapping) - expected = pl.Series(["1", "0", "0", None]) - assert_series_equal(result, expected) - - -def test_replace_str_to_bool_with_default() -> None: - s = pl.Series(["True", "False", "False", None]) - mapping = {"True": True, "False": False} - result = s.replace_all(mapping) - expected = pl.Series([True, False, False, None]) - assert_series_equal(result, expected) - - -def test_replace_int_to_str() -> None: - s = pl.Series("a", [-1, 2, None, 4, -5]) - mapping = {1: "one", 2: "two", 3: "three", 4: "four", 5: "five"} - - result = s.replace_all(mapping) - - expected = pl.Series("a", [None, "two", None, "four", None]) - assert_series_equal(result, expected) - - -def test_replace_int_to_str_with_default() -> None: - s = pl.Series("a", [1, 2, None, 4, 5]) - mapping = {1: "one", 2: "two", 3: "three", 4: "four", 5: "five"} - - result = s.replace_all(mapping, default="?") - - expected = pl.Series("a", ["one", "two", "?", "four", "five"]) - assert_series_equal(result, expected) - - -def test_replace_all_str_to_int() -> None: - s = pl.Series(["a", "b"]) - mapping = {"a": 1, "b": 2} - result = s.replace_all(mapping) - expected = pl.Series([1, 2]) - assert_series_equal(result, expected) - - -@pytest.mark.parametrize( - ("context", "dtype"), - [ - (pl.StringCache(), pl.Categorical), - (pytest.warns(CategoricalRemappingWarning), pl.Categorical), - (contextlib.nullcontext(), pl.Enum(["a", "b", "OTHER"])), - ], -) -def test_replace_cat_str( - context: contextlib.AbstractContextManager, # type: ignore[type-arg] - dtype: pl.DataType, -) -> None: - with context: - for old, new, expected in [ - ("a", "c", pl.Series("s", ["c", None], dtype=pl.Utf8)), - (["a", "b"], ["c", "d"], pl.Series("s", ["c", "d"], dtype=pl.Utf8)), - (pl.lit("a", dtype=dtype), "c", pl.Series("s", ["c", None], dtype=pl.Utf8)), - ( - pl.Series(["a", "b"], dtype=dtype), - ["c", "d"], - pl.Series("s", ["c", "d"], dtype=pl.Utf8), - ), - ]: - s = pl.Series("s", ["a", "b"], dtype=dtype) - s_replaced = s.replace_all(old, new) # type: ignore[arg-type] - assert_series_equal(s_replaced, expected) - - s = pl.Series("s", ["a", "b"], dtype=dtype) - s_replaced = s.replace_all(old, new, default="OTHER") # type: ignore[arg-type] - assert_series_equal(s_replaced, expected.fill_null("OTHER")) - - -@pytest.mark.parametrize( - "context", [pl.StringCache(), pytest.warns(CategoricalRemappingWarning)] -) -def test_replace_cat_cat( - context: contextlib.AbstractContextManager, # type: ignore[type-arg] -) -> None: - with context: - dt = pl.Categorical - for old, new, expected in [ - ("a", pl.lit("c", dtype=dt), pl.Series("s", ["c", None], dtype=dt)), - ( - ["a", "b"], - pl.Series(["c", "d"], dtype=dt), - pl.Series("s", ["c", "d"], dtype=dt), - ), - ]: - s = pl.Series("s", ["a", "b"], dtype=dt) - s_replaced = s.replace_all(old, new) # type: ignore[arg-type] - assert_series_equal(s_replaced, expected) - - s = pl.Series("s", ["a", "b"], dtype=dt) - s_replaced = s.replace_all(old, new, default=pl.lit("OTHER", dtype=dt)) # type: ignore[arg-type] - assert_series_equal(s_replaced, expected.fill_null("OTHER")) - - -def test_map_dict_deprecated() -> None: - s = pl.Series("a", [1, 2, 3]) - with pytest.deprecated_call(): - result = s.map_dict({2: 100}) - expected = pl.Series("a", [None, 100, None]) - assert_series_equal(result, expected) - - with pytest.deprecated_call(): - result = s.to_frame().select(pl.col("a").map_dict({2: 100})).to_series() - assert_series_equal(result, expected)