Skip to content

Commit

Permalink
Implements Series.split_into/3 (#873)
Browse files Browse the repository at this point in the history
  • Loading branch information
ryancurtin authored Mar 5, 2024
1 parent 8bb587f commit ad43e04
Show file tree
Hide file tree
Showing 11 changed files with 121 additions and 8 deletions.
8 changes: 8 additions & 0 deletions lib/explorer/backend/lazy_series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ defmodule Explorer.Backend.LazySeries do
downcase: 1,
substring: 3,
split: 2,
split_into: 3,
json_decode: 2,
json_path_match: 2,
# Float round
Expand Down Expand Up @@ -1053,6 +1054,13 @@ defmodule Explorer.Backend.LazySeries do
Backend.Series.new(data, {:list, :string})
end

@impl true
def split_into(series, by, fields) do
data = new(:split_into, [lazy_series!(series), by, fields], :string)

Backend.Series.new(data, {:struct, Enum.map(fields, &{&1, :string})})
end

@impl true
def round(series, decimals) when is_integer(decimals) and decimals >= 0 do
data = new(:round, [lazy_series!(series), decimals], {:f, 64})
Expand Down
1 change: 1 addition & 0 deletions lib/explorer/backend/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ defmodule Explorer.Backend.Series do
@callback rstrip(s, String.t() | nil) :: s
@callback substring(s, integer(), non_neg_integer() | nil) :: s
@callback split(s, String.t()) :: s
@callback split_into(s, String.t(), list(String.t() | atom())) :: s
@callback json_decode(s, dtype()) :: s
@callback json_path_match(s, String.t()) :: s

Expand Down
1 change: 1 addition & 0 deletions lib/explorer/polars_backend/expression.ex
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ defmodule Explorer.PolarsBackend.Expression do
upcase: 1,
substring: 3,
split: 2,
split_into: 3,
json_decode: 2,
json_path_match: 2,

Expand Down
1 change: 1 addition & 0 deletions lib/explorer/polars_backend/native.ex
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ defmodule Explorer.PolarsBackend.Native do
def s_cut(_s, _bins, _labels, _break_point_label, _category_label), do: err()
def s_substring(_s, _offset, _length), do: err()
def s_split(_s, _by), do: err()
def s_split_into(_s, _by, _num_fields), do: err()

def s_qcut(_s, _quantiles, _labels, _break_point_label, _category_label),
do: err()
Expand Down
4 changes: 4 additions & 0 deletions lib/explorer/polars_backend/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,10 @@ defmodule Explorer.PolarsBackend.Series do
def split(series, by),
do: Shared.apply_series(series, :s_split, [by])

@impl true
def split_into(series, by, fields),
do: Shared.apply_series(series, :s_split_into, [by, fields])

# Float round
@impl true
def round(series, decimals),
Expand Down
26 changes: 26 additions & 0 deletions lib/explorer/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -5651,6 +5651,32 @@ defmodule Explorer.Series do
def split(%Series{dtype: dtype}, _by),
do: dtype_error("split/2", dtype, [:string])

@doc """
Split a string Series into a struct of string `fields`.
The length of the field names list determines how many times the
string will be split at most. If the string cannot be split into that
many separate strings, null values will be provided for the
remaining fields.
## Examples
iex> s = Series.from_list(["Smith, John", "Jones, Jane"])
iex> Series.split_into(s, ", ", ["Last Name", "First Name"])
#Explorer.Series<
Polars[2]
struct[2] [%{"First Name" => "John", "Last Name" => "Smith"}, %{"First Name" => "Jane", "Last Name" => "Jones"}]
>
"""
@doc type: :string_wise
@spec split_into(Series.t(), String.t(), list(String.t() | atom())) :: Series.t()
def split_into(%Series{dtype: :string} = series, by, [_ | _] = fields) when is_binary(by),
do: apply_series(series, :split_into, [by, Enum.map(fields, &to_string/1)])

def split_into(%Series{dtype: dtype}, by, [_ | _]) when is_binary(by),
do: dtype_error("split_into/3", dtype, [:string])

# Float

@doc """
Expand Down
12 changes: 12 additions & 0 deletions native/explorer/src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1121,6 +1121,18 @@ pub fn expr_json_path_match(expr: ExExpr, json_path: &str) -> ExExpr {
ExExpr::new(expr)
}

#[rustler::nif]
pub fn expr_split_into(expr: ExExpr, by: String, names: Vec<String>) -> ExExpr {
let expr = expr
.clone_inner()
.str()
.splitn(by.lit(), names.len())
.struct_()
.rename_fields(names);

ExExpr::new(expr)
}

#[rustler::nif]
pub fn expr_struct(ex_exprs: Vec<ExExpr>) -> ExExpr {
let exprs = ex_exprs.iter().map(|e| e.clone_inner()).collect();
Expand Down
2 changes: 2 additions & 0 deletions native/explorer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ rustler::init!(
expr_split,
expr_replace,
expr_json_path_match,
expr_split_into,
// float round expressions
expr_round,
expr_floor,
Expand Down Expand Up @@ -457,6 +458,7 @@ rustler::init!(
s_strip,
s_substring,
s_split,
s_split_into,
s_subtract,
s_sum,
s_tail,
Expand Down
19 changes: 19 additions & 0 deletions native/explorer/src/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1567,6 +1567,25 @@ pub fn s_split(s1: ExSeries, by: &str) -> Result<ExSeries, ExplorerError> {
Ok(ExSeries::new(s2))
}

#[rustler::nif(schedule = "DirtyCpu")]
pub fn s_split_into(s1: ExSeries, by: &str, names: Vec<String>) -> Result<ExSeries, ExplorerError> {
let s2 = s1
.clone_inner()
.into_frame()
.lazy()
.select([col(s1.name())
.str()
.splitn(by.lit(), names.len())
.struct_()
.rename_fields(names)
.alias(s1.name())])
.collect()?
.column(s1.name())?
.clone();

Ok(ExSeries::new(s2))
}

#[rustler::nif(schedule = "DirtyCpu")]
pub fn s_round(s: ExSeries, decimals: u32) -> Result<ExSeries, ExplorerError> {
Ok(ExSeries::new(s.round(decimals)?.into_series()))
Expand Down
33 changes: 25 additions & 8 deletions test/explorer/data_frame_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -1993,6 +1993,27 @@ defmodule Explorer.DataFrameTest do
member?: [true, false]
}
end

test "splits a string column into multiple new columns" do
new_column_names = ["Last Name", "First Name"]
df = DF.new(%{names: ["Smith, John", "Jones, Jane"]})

df =
DF.mutate_with(df, fn ldf ->
%{names: Series.split_into(ldf[:names], ", ", new_column_names)}
end)
|> DF.unnest(:names)

assert DF.dtypes(df) == %{
"Last Name" => :string,
"First Name" => :string
}

assert DF.to_columns(df) == %{
"Last Name" => ["Smith", "Jones"],
"First Name" => ["John", "Jane"]
}
end
end

describe "sort_by/3" do
Expand Down Expand Up @@ -2629,17 +2650,13 @@ defmodule Explorer.DataFrameTest do
end

test "mixing nulls, signed, unsigned integers, and floats" do
df1 =
DF.new(x: Series.from_list([1, 2], dtype: :u16), y: Series.from_list(["a", "b"]))
df1 = DF.new(x: Series.from_list([1, 2], dtype: :u16), y: Series.from_list(["a", "b"]))

df2 =
DF.new(x: Series.from_list([3.0, 4.0], dtype: :f32), y: Series.from_list(["c", "d"]))
df2 = DF.new(x: Series.from_list([3.0, 4.0], dtype: :f32), y: Series.from_list(["c", "d"]))

df3 =
DF.new(x: [nil, nil], y: [nil, nil])
df3 = DF.new(x: [nil, nil], y: [nil, nil])

df4 =
DF.new(x: Series.from_list([5, 6], dtype: :s16), y: Series.from_list(["e", "f"]))
df4 = DF.new(x: Series.from_list([5, 6], dtype: :s16), y: Series.from_list(["e", "f"]))

df5 = DF.concat_rows([df1, df2, df3, df4])

Expand Down
22 changes: 22 additions & 0 deletions test/explorer/series_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -5320,6 +5320,28 @@ defmodule Explorer.SeriesTest do
end
end

describe "split_into" do
test "split_into/3 produces the correct number of fields in a struct" do
series = Series.from_list(["Smith, John", "Jones, Jane"])
split_series = series |> Series.split_into(", ", ["Last Name", "First Name"])

assert Series.to_list(split_series) == [
%{"First Name" => "John", "Last Name" => "Smith"},
%{"First Name" => "Jane", "Last Name" => "Jones"}
]
end

test "split_into/3 produces a nil field when string cannot be split for every field" do
series = Series.from_list(["Smith-John", "Jones-Jane"])
split_series = series |> Series.split_into("-", ["Last Name", "First Name", "Middle Name"])

assert Series.to_list(split_series) == [
%{"First Name" => "John", "Last Name" => "Smith", "Middle Name" => nil},
%{"First Name" => "Jane", "Last Name" => "Jones", "Middle Name" => nil}
]
end
end

describe "strptime/2 and strftime/2" do
test "parse datetime from string" do
series = Series.from_list(["2023-01-05 12:34:56", "XYZ", nil])
Expand Down

0 comments on commit ad43e04

Please sign in to comment.