Skip to content

Commit

Permalink
feat!: add kwargs that are deserializable by user defined structs (#29)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Oct 17, 2023
1 parent 79cf87c commit f3548ec
Show file tree
Hide file tree
Showing 17 changed files with 269 additions and 86 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ Cargo.lock
.idea/
venv/
target/
rust-toolchain.toml
10 changes: 5 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ members = [
]

[workspace.dependencies]
polars = { version = "0.33.2", default-features = false }
polars-core = { version = "0.33.2", default-features = false }
polars-ffi = { version = "0.33.2", default-features = false }
polars-plan = { version = "0.33.2", default-feautres = false }
polars-lazy = { version = "0.33.2", default-features = false }
# All polars crates are pinned to the same upstream git revision.
polars = { git = "https://github.com/pola-rs/polars", rev = "5d48cc800bc9c71fe6d4ff97b96d7fed4601793b", version = "0.33.2", default-features = false }
polars-core = { git = "https://github.com/pola-rs/polars", rev = "5d48cc800bc9c71fe6d4ff97b96d7fed4601793b", version = "0.33.2", default-features = false }
polars-ffi = { git = "https://github.com/pola-rs/polars", rev = "5d48cc800bc9c71fe6d4ff97b96d7fed4601793b", version = "0.33.2", default-features = false }
polars-plan = { git = "https://github.com/pola-rs/polars", rev = "5d48cc800bc9c71fe6d4ff97b96d7fed4601793b", version = "0.33.2", default-features = false }
polars-lazy = { git = "https://github.com/pola-rs/polars", rev = "5d48cc800bc9c71fe6d4ff97b96d7fed4601793b", version = "0.33.2", default-features = false }
2 changes: 2 additions & 0 deletions example/derive_expression/expression_lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ name = "expression_lib"
crate-type = ["cdylib"]

[dependencies]
jemallocator = { version = "0.5", features = ["disable_initial_exec_tls"] }
polars = { workspace = true, features = ["fmt"], default-features = false }
polars-plan = { workspace = true, default-features = false }
pyo3 = { version = "0.20.0", features = ["extension-module"] }
pyo3-polars = { version = "*", path = "../../../pyo3-polars", features = ["derive"] }
serde = { version = "1", features = ["derive"] }
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,30 @@ def pig_latinnify(self) -> pl.Expr:
is_elementwise=True,
)

def append_args(
    self,
    float_arg: float,
    integer_arg: int,
    string_arg: str,
    boolean_arg: bool,
) -> pl.Expr:
    """
    Register the `append_kwargs` plugin symbol with plain Python values.

    Demonstrates that a plugin can receive arguments that are not `Series`:
    the four values are forwarded through `kwargs`, while `args` stays empty.
    """
    # The keys must match what the plugin expects to find in its kwargs.
    plugin_kwargs = {
        "float_arg": float_arg,
        "integer_arg": integer_arg,
        "string_arg": string_arg,
        "boolean_arg": boolean_arg,
    }
    expr = self._expr
    return expr._register_plugin(
        lib=lib,
        args=[],
        kwargs=plugin_kwargs,
        symbol="append_kwargs",
        is_elementwise=True,
    )


@pl.api.register_expr_namespace("dist")
class Distance:
def __init__(self, expr: pl.Expr):
Expand All @@ -38,11 +62,17 @@ def jaccard_similarity(self, other: IntoExpr) -> pl.Expr:
is_elementwise=True,
)

def haversine(self, start_lat: IntoExpr, start_long: IntoExpr, end_lat: IntoExpr, end_long: IntoExpr) -> pl.Expr:
def haversine(
self,
start_lat: IntoExpr,
start_long: IntoExpr,
end_lat: IntoExpr,
end_long: IntoExpr,
) -> pl.Expr:
return self._expr._register_plugin(
lib=lib,
args=[start_lat, start_long, end_lat, end_long],
symbol="haversine",
is_elementwise=True,
cast_to_supertypes=True
cast_to_supertypes=True,
)
41 changes: 36 additions & 5 deletions example/derive_expression/expression_lib/src/expressions.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use polars::prelude::*;
use polars_plan::dsl::FieldsMapper;
use pyo3_polars::derive::polars_expr;
use pyo3_polars::derive::{polars_expr, DefaultKwargs};
use serde::Deserialize;
use std::fmt::Write;

fn pig_latin_str(value: &str, output: &mut String) {
Expand All @@ -10,21 +11,21 @@ fn pig_latin_str(value: &str, output: &mut String) {
}

#[polars_expr(output_type=Utf8)]
fn pig_latinnify(inputs: &[Series]) -> PolarsResult<Series> {
fn pig_latinnify(inputs: &[Series], _kwargs: Option<DefaultKwargs>) -> PolarsResult<Series> {
let ca = inputs[0].utf8()?;
let out: Utf8Chunked = ca.apply_to_buffer(pig_latin_str);
Ok(out.into_series())
}

#[polars_expr(output_type=Float64)]
fn jaccard_similarity(inputs: &[Series]) -> PolarsResult<Series> {
fn jaccard_similarity(inputs: &[Series], _kwargs: Option<DefaultKwargs>) -> PolarsResult<Series> {
let a = inputs[0].list()?;
let b = inputs[1].list()?;
crate::distances::naive_jaccard_sim(a, b).map(|ca| ca.into_series())
}

#[polars_expr(output_type=Float64)]
fn hamming_distance(inputs: &[Series]) -> PolarsResult<Series> {
fn hamming_distance(inputs: &[Series], _kwargs: Option<DefaultKwargs>) -> PolarsResult<Series> {
let a = inputs[0].utf8()?;
let b = inputs[1].utf8()?;
let out: UInt32Chunked =
Expand All @@ -37,7 +38,7 @@ fn haversine_output(input_fields: &[Field]) -> PolarsResult<Field> {
}

#[polars_expr(type_func=haversine_output)]
fn haversine(inputs: &[Series]) -> PolarsResult<Series> {
fn haversine(inputs: &[Series], _kwargs: Option<DefaultKwargs>) -> PolarsResult<Series> {
let out = match inputs[0].dtype() {
DataType::Float32 => {
let start_lat = inputs[0].f32().unwrap();
Expand All @@ -59,3 +60,33 @@ fn haversine(inputs: &[Series]) -> PolarsResult<Series> {
};
Ok(out)
}

/// The `DefaultKwargs` isn't very ergonomic as it doesn't validate any schema.
/// Provide your own kwargs struct with the proper schema and accept that type
/// in your plugin expression.
///
/// The field names are the keys that are looked up when the kwargs are
/// deserialized, so they must match the keys passed from the Python side.
#[derive(Debug, Clone, Deserialize)]
pub struct MyKwargs {
    /// Floating point value appended to every output string.
    float_arg: f64,
    /// Integer value appended to every output string.
    integer_arg: i64,
    /// String value appended to every output string.
    string_arg: String,
    /// Boolean value appended to every output string.
    boolean_arg: bool,
}

#[polars_expr(output_type=Utf8)]
fn append_kwargs(input: &[Series], kwargs: Option<MyKwargs>) -> PolarsResult<Series> {
let input = &input[0];
let kwargs = kwargs.unwrap();
let input = input.cast(&DataType::Utf8)?;
let ca = input.utf8().unwrap();

Ok(ca
.apply_to_buffer(|val, buf| {
write!(
buf,
"{}-{}-{}-{}-{}",
val, kwargs.float_arg, kwargs.integer_arg, kwargs.string_arg, kwargs.boolean_arg
)
.unwrap()
})
.into_series())
}
4 changes: 4 additions & 0 deletions example/derive_expression/expression_lib/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
mod distances;
mod expressions;

#[global_allocator]
#[cfg(target_os = "linux")]
static ALLOC: Jemalloc = Jemalloc;
43 changes: 32 additions & 11 deletions example/derive_expression/run.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,42 @@
import polars as pl
from expression_lib import Language, Distance

df = pl.DataFrame({
"names": ["Richard", "Alice", "Bob"],
"moons": ["full", "half", "red"],
"dist_a": [[12, 32, 1], [], [1, -2]],
"dist_b": [[-12, 1], [43], [876, -45, 9]],
"floats": [5.6, -1245.8, 242.224]
})
df = pl.DataFrame(
{
"names": ["Richard", "Alice", "Bob"],
"moons": ["full", "half", "red"],
"dist_a": [[12, 32, 1], [], [1, -2]],
"dist_b": [[-12, 1], [43], [876, -45, 9]],
"floats": [5.6, -1245.8, 242.224],
}
)


out = df.with_columns(
pig_latin = pl.col("names").language.pig_latinnify(),
pig_latin=pl.col("names").language.pig_latinnify(),
).with_columns(
hamming_dist = pl.col("names").dist.hamming_distance("pig_latin"),
jaccard_sim = pl.col("dist_a").dist.jaccard_similarity("dist_b"),
haversine = pl.col("floats").dist.haversine("floats", "floats", "floats", "floats"),
hamming_dist=pl.col("names").dist.hamming_distance("pig_latin"),
jaccard_sim=pl.col("dist_a").dist.jaccard_similarity("dist_b"),
haversine=pl.col("floats").dist.haversine("floats", "floats", "floats", "floats"),
appended_args=pl.col("names").language.append_args(
float_arg=11.234,
integer_arg=93,
boolean_arg=False,
string_arg="example",
)
)

print(out)


# Tests we can return errors from FFI by passing wrong types.
try:
    out.with_columns(
        appended_args=pl.col("names").language.append_args(
            float_arg=True,
            integer_arg=True,
            boolean_arg=True,
            string_arg="example",
        )
    )
except pl.ComputeError as e:
    assert "the plugin failed with message" in str(e)
else:
    # Without this branch the check would pass silently when no error is raised.
    raise AssertionError("expected the plugin to fail on wrongly typed kwargs")
13 changes: 5 additions & 8 deletions example/extend_polars_python_dispatch/extend_polars/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
mod parallel_jaccard_mod;

use pyo3::prelude::*;
use pyo3_polars::{
PyDataFrame,
PyLazyFrame,
};
use pyo3_polars::error::PyPolarsErr;
use polars::prelude::*;
use polars_lazy::frame::IntoLazy;
use polars_lazy::prelude::LazyFrame;

use pyo3::prelude::*;
use pyo3_polars::error::PyPolarsErr;
use pyo3_polars::{PyDataFrame, PyLazyFrame};

#[pyfunction]
fn parallel_jaccard(pydf: PyDataFrame, col_a: &str, col_b: &str) -> PyResult<PyDataFrame> {
Expand All @@ -21,7 +17,8 @@ fn parallel_jaccard(pydf: PyDataFrame, col_a: &str, col_b: &str) -> PyResult<PyD
#[pyfunction]
fn lazy_parallel_jaccard(pydf: PyLazyFrame, col_a: &str, col_b: &str) -> PyResult<PyLazyFrame> {
let df: LazyFrame = pydf.into();
let df = parallel_jaccard_mod::parallel_jaccard(df.collect().unwrap(), col_a, col_b).map_err(PyPolarsErr::from)?;
let df = parallel_jaccard_mod::parallel_jaccard(df.collect().unwrap(), col_a, col_b)
.map_err(PyPolarsErr::from)?;
Ok(PyLazyFrame(df.lazy()))
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use rayon::prelude::*;
use polars_core::utils::accumulate_dataframes_vertical;
use polars::prelude::*;
use polars_core::utils::accumulate_dataframes_vertical;
use rayon::prelude::*;

/// Create `n` splits so that we can slice a polars data structure
/// and process the chunks in parallel
Expand Down Expand Up @@ -28,42 +28,48 @@ fn compute_jaccard_similarity(sa: &Series, sb: &Series) -> PolarsResult<Series>
let sa = sa.list()?;
let sb = sb.list()?;

let ca = sa.into_iter().zip(sb.into_iter()).map(|(a, b)| {
match (a, b) {
(Some(a), Some(b)) => {
// unpack as i64 series
let a = a.i64()?;
let b = b.i64()?;
let ca = sa
.into_iter()
.zip(sb.into_iter())
.map(|(a, b)| {
match (a, b) {
(Some(a), Some(b)) => {
// unpack as i64 series
let a = a.i64()?;
let b = b.i64()?;

// convert to hashsets over Option<i64>
let s1 = a.into_iter().collect::<PlHashSet<_>>();
let s2 = b.into_iter().collect::<PlHashSet<_>>();
// convert to hashsets over Option<i64>
let s1 = a.into_iter().collect::<PlHashSet<_>>();
let s2 = b.into_iter().collect::<PlHashSet<_>>();

// count the number of intersections
let s3_len = s1.intersection(&s2).count();
// return similarity
Ok(Some(s3_len as f64 / (s1.len() + s2.len() - s3_len) as f64))
},
_ => Ok(None)
}
}).collect::<PolarsResult<Float64Chunked>>()?;
// count the number of intersections
let s3_len = s1.intersection(&s2).count();
// return similarity
Ok(Some(s3_len as f64 / (s1.len() + s2.len() - s3_len) as f64))
}
_ => Ok(None),
}
})
.collect::<PolarsResult<Float64Chunked>>()?;
Ok(ca.into_series())
}

pub(super) fn parallel_jaccard(df: DataFrame, col_a: &str, col_b: &str) -> PolarsResult<DataFrame> {
let offsets = split_offsets(df.height(), rayon::current_num_threads());

let dfs= offsets.par_iter().map(|(offset, len)| {
let sub_df = df.slice(*offset as i64, *len);
let a = sub_df.column(col_a)?;
let b = sub_df.column(col_b)?;
let dfs = offsets
.par_iter()
.map(|(offset, len)| {
let sub_df = df.slice(*offset as i64, *len);
let a = sub_df.column(col_a)?;
let b = sub_df.column(col_b)?;

let out= compute_jaccard_similarity(a, b)?;
let out = compute_jaccard_similarity(a, b)?;

df!(
"jaccard" => out
)
}).collect::<PolarsResult<Vec<_>>>()?;
df!(
"jaccard" => out
)
})
.collect::<PolarsResult<Vec<_>>>()?;
accumulate_dataframes_vertical(dfs)
}

5 changes: 1 addition & 4 deletions example/extend_polars_python_dispatch/run.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import polars as pl
from extend_polars import parallel_jaccard, lazy_parallel_jaccard

df = pl.DataFrame({
"list_a": [[1, 2, 3], [5, 5]],
"list_b": [[1, 2, 3, 8], [5, 1, 1]]
})
df = pl.DataFrame({"list_a": [[1, 2, 3], [5, 5]], "list_b": [[1, 2, 3, 8], [5, 1, 1]]})

print(df)
print(parallel_jaccard(df, "list_a", "list_b"))
Expand Down
Loading

0 comments on commit f3548ec

Please sign in to comment.