Skip to content

Commit

Permalink
add logical type example (#36)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Oct 17, 2023
1 parent f3548ec commit 42c5b0c
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 42 deletions.
16 changes: 11 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,14 @@ members = [
]

[workspace.dependencies]
polars = { git = "https://github.com/pola-rs/polars", rev = "5d48cc800bc9c71fe6d4ff97b96d7fed4601793b", version = "0.33.2", default-features = false }
polars-core = { git = "https://github.com/pola-rs/polars", rev = "5d48cc800bc9c71fe6d4ff97b96d7fed4601793b", version = "0.33.2", default-features = false }
polars-ffi = { git = "https://github.com/pola-rs/polars", rev = "5d48cc800bc9c71fe6d4ff97b96d7fed4601793b", version = "0.33.2", default-features = false }
polars-plan = { git = "https://github.com/pola-rs/polars", rev = "5d48cc800bc9c71fe6d4ff97b96d7fed4601793b", version = "0.33.2", default-feautres = false }
polars-lazy = { git = "https://github.com/pola-rs/polars", rev = "5d48cc800bc9c71fe6d4ff97b96d7fed4601793b", version = "0.33.2", default-features = false }
polars = { git = "https://github.com/pola-rs/polars", rev = "d00a43203b3ade009a5f858f4c698b6a50f5b1e6", version = "0.33.2", default-features = false }
polars-core = { git = "https://github.com/pola-rs/polars", rev = "d00a43203b3ade009a5f858f4c698b6a50f5b1e6", version = "0.33.2", default-features = false }
polars-ffi = { git = "https://github.com/pola-rs/polars", rev = "d00a43203b3ade009a5f858f4c698b6a50f5b1e6", version = "0.33.2", default-features = false }
polars-plan = { git = "https://github.com/pola-rs/polars", rev = "d00a43203b3ade009a5f858f4c698b6a50f5b1e6", version = "0.33.2", default-features = false }
polars-lazy = { git = "https://github.com/pola-rs/polars", rev = "d00a43203b3ade009a5f858f4c698b6a50f5b1e6", version = "0.33.2", default-features = false }

#polars = { path = "../polars/crates/polars", version = "0.33.2", default-features = false }
#polars-core = { path = "../polars/crates/polars-core", version = "0.33.2", default-features = false }
#polars-ffi = { path = "../polars/crates/polars-ffi", version = "0.33.2", default-features = false }
#polars-plan = { path = "../polars/crates/polars-plan", version = "0.33.2", default-features = false }
#polars-lazy = { path = "../polars/crates/polars-lazy", version = "0.33.2", default-features = false }
2 changes: 1 addition & 1 deletion example/derive_expression/expression_lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ crate-type = ["cdylib"]

[dependencies]
jemallocator = { version = "0.5", features = ["disable_initial_exec_tls"] }
polars = { workspace = true, features = ["fmt"], default-features = false }
polars = { workspace = true, features = ["fmt", "dtype-date"], default-features = false }
polars-plan = { workspace = true, default-features = false }
pyo3 = { version = "0.20.0", features = ["extension-module"] }
pyo3-polars = { version = "*", path = "../../../pyo3-polars", features = ["derive"] }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,16 @@ def haversine(
is_elementwise=True,
cast_to_supertypes=True,
)

@pl.api.register_expr_namespace("date_util")
class DateUtil:
    """Expression namespace registered under the name ``date_util``.

    Wraps a ``pl.Expr`` and forwards calls to plugin symbols in ``lib``.
    """

    def __init__(self, expr: pl.Expr):
        # Keep the wrapped expression; every namespace method builds on it.
        self._expr = expr

    def is_leap_year(self) -> pl.Expr:
        """Return an expression that calls the ``is_leap_year`` plugin symbol.

        The plugin is registered as elementwise.
        """
        plugin_options = {
            "lib": lib,
            "symbol": "is_leap_year",
            "is_elementwise": True,
        }
        return self._expr._register_plugin(**plugin_options)
29 changes: 22 additions & 7 deletions example/derive_expression/expression_lib/src/expressions.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use polars::prelude::*;
use polars_plan::dsl::FieldsMapper;
use pyo3_polars::derive::{polars_expr, DefaultKwargs};
use pyo3_polars::derive::polars_expr;
use serde::Deserialize;
use std::fmt::Write;

Expand All @@ -11,21 +11,21 @@ fn pig_latin_str(value: &str, output: &mut String) {
}

#[polars_expr(output_type=Utf8)]
fn pig_latinnify(inputs: &[Series], _kwargs: Option<DefaultKwargs>) -> PolarsResult<Series> {
fn pig_latinnify(inputs: &[Series]) -> PolarsResult<Series> {
let ca = inputs[0].utf8()?;
let out: Utf8Chunked = ca.apply_to_buffer(pig_latin_str);
Ok(out.into_series())
}

#[polars_expr(output_type=Float64)]
fn jaccard_similarity(inputs: &[Series], _kwargs: Option<DefaultKwargs>) -> PolarsResult<Series> {
fn jaccard_similarity(inputs: &[Series]) -> PolarsResult<Series> {
let a = inputs[0].list()?;
let b = inputs[1].list()?;
crate::distances::naive_jaccard_sim(a, b).map(|ca| ca.into_series())
}

#[polars_expr(output_type=Float64)]
fn hamming_distance(inputs: &[Series], _kwargs: Option<DefaultKwargs>) -> PolarsResult<Series> {
fn hamming_distance(inputs: &[Series]) -> PolarsResult<Series> {
let a = inputs[0].utf8()?;
let b = inputs[1].utf8()?;
let out: UInt32Chunked =
Expand All @@ -38,7 +38,7 @@ fn haversine_output(input_fields: &[Field]) -> PolarsResult<Field> {
}

#[polars_expr(type_func=haversine_output)]
fn haversine(inputs: &[Series], _kwargs: Option<DefaultKwargs>) -> PolarsResult<Series> {
fn haversine(inputs: &[Series]) -> PolarsResult<Series> {
let out = match inputs[0].dtype() {
DataType::Float32 => {
let start_lat = inputs[0].f32().unwrap();
Expand Down Expand Up @@ -72,10 +72,12 @@ pub struct MyKwargs {
boolean_arg: bool,
}

/// If you want to accept `kwargs`, define a `kwargs` argument
/// in the second position of your plugin. You can provide any custom struct that is deserializable
/// with the pickle protocol (on the Rust side).
#[polars_expr(output_type=Utf8)]
fn append_kwargs(input: &[Series], kwargs: Option<MyKwargs>) -> PolarsResult<Series> {
fn append_kwargs(input: &[Series], kwargs: MyKwargs) -> PolarsResult<Series> {
let input = &input[0];
let kwargs = kwargs.unwrap();
let input = input.cast(&DataType::Utf8)?;
let ca = input.utf8().unwrap();

Expand All @@ -90,3 +92,16 @@ fn append_kwargs(input: &[Series], kwargs: Option<MyKwargs>) -> PolarsResult<Ser
})
.into_series())
}

/// Plugin expression that maps each date in the first input series to whether
/// its year is a leap year. Null entries stay null.
///
/// # Errors
/// Propagates the error from `Series::date` (presumably raised when the first
/// input is not a `Date` column — confirm against polars).
///
/// # Panics
/// Panics if `input` is empty, because the first series is taken by index.
#[polars_expr(output_type=Boolean)]
fn is_leap_year(input: &[Series]) -> PolarsResult<Series> {
    let dates = input[0].date()?;

    // The result column keeps the name of the input column.
    let leap_flags: BooleanChunked = dates
        .as_date_iter()
        .map(|maybe_date| maybe_date.map(|date| date.leap_year()))
        .collect_ca(dates.name());

    Ok(leap_flags.into_series())
}
3 changes: 3 additions & 0 deletions example/derive_expression/expression_lib/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
mod distances;
mod expressions;

#[cfg(target_os = "linux")]
use jemallocator::Jemalloc;

// Use jemalloc as this crate's global allocator. The item is only compiled
// when targeting Linux; on other targets no custom allocator is declared here.
#[global_allocator]
#[cfg(target_os = "linux")]
static ALLOC: Jemalloc = Jemalloc;
3 changes: 3 additions & 0 deletions example/derive_expression/run.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import polars as pl
from expression_lib import Language, Distance
from datetime import date

df = pl.DataFrame(
{
"names": ["Richard", "Alice", "Bob"],
"moons": ["full", "half", "red"],
"dates": [date(2023, 1, 1), date(2024, 1, 1), date(2025, 1, 1)],
"dist_a": [[12, 32, 1], [], [1, -2]],
"dist_b": [[-12, 1], [43], [876, -45, 9]],
"floats": [5.6, -1245.8, 242.224],
Expand All @@ -18,6 +20,7 @@
hamming_dist=pl.col("names").dist.hamming_distance("pig_latin"),
jaccard_sim=pl.col("dist_a").dist.jaccard_similarity("dist_b"),
haversine=pl.col("floats").dist.haversine("floats", "floats", "floats", "floats"),
leap_year=pl.col("dates").date_util.is_leap_year(),
appended_args=pl.col("names").language.append_args(
float_arg=11.234,
integer_arg=93,
Expand Down
101 changes: 72 additions & 29 deletions pyo3-polars-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ mod keywords;
use proc_macro::TokenStream;
use quote::quote;
use std::sync::atomic::{AtomicBool, Ordering};
use syn::parse_macro_input;
use syn::{parse_macro_input, FnArg};

static INIT: AtomicBool = AtomicBool::new(false);

Expand All @@ -21,10 +21,79 @@ fn insert_error_function() -> proc_macro2::TokenStream {
}
}

/// Builds the token stream that decodes the kwargs buffer, defines the user's
/// function, and calls it as `fn_name(&inputs, kwargs)`.
///
/// The emitted code refers to `kwargs_ptr`, `kwargs_len` and `inputs`, and
/// contains a bare `return;`, so it must be spliced into a unit-returning
/// function body where those names are in scope. The outcome of the call is
/// left in a local named `result`.
fn quote_call_kwargs(ast: &syn::ItemFn, fn_name: &syn::Ident) -> proc_macro2::TokenStream {
    quote! {
        let kwargs = std::slice::from_raw_parts(kwargs_ptr, kwargs_len);

        // On a parse failure, hand the error to `_update_last_error` and leave
        // the wrapper early without calling the user's function.
        let kwargs = match pyo3_polars::derive::_parse_kwargs(kwargs) {
            Ok(value) => value,
            Err(err) => {
                pyo3_polars::derive::_update_last_error(err);
                return;
            }
        };

        // Emit the user's function definition unchanged.
        #ast

        // Invoke it with the imported inputs and the parsed kwargs.
        let result: PolarsResult<polars_core::prelude::Series> = #fn_name(&inputs, kwargs);
    }
}

/// Builds the token stream that defines the user's function and calls it as
/// `fn_name(&inputs)`, for plugins that declare no `kwargs` argument.
///
/// The emitted code expects `inputs` to be in scope where it is spliced and
/// leaves the outcome of the call in a local named `result`.
fn quote_call_no_kwargs(ast: &syn::ItemFn, fn_name: &syn::Ident) -> proc_macro2::TokenStream {
    quote! {
        // Emit the user's function definition unchanged.
        #ast

        // Invoke it with the imported inputs only.
        let result: PolarsResult<polars_core::prelude::Series> = #fn_name(&inputs);
    }
}

/// Builds the token stream that consumes the `result` local produced by the
/// call fragment.
///
/// The emitted code expects `result` and `return_value` to be in scope where
/// it is spliced.
fn quote_process_results() -> proc_macro2::TokenStream {
    quote! {
        match result {
            Ok(out) => {
                // Success: export the series and store it through `return_value`.
                *return_value = polars_ffi::export_series(&out);
            }
            Err(err) => {
                // Failure: hand the error to `_update_last_error`; `return_value`
                // is not written on this path.
                pyo3_polars::derive::_update_last_error(err);
            }
        }
    }
}

fn create_expression_function(ast: syn::ItemFn) -> proc_macro2::TokenStream {
// Count how often the user defines a `kwargs` argument.
let n_kwargs = ast
.sig
.inputs
.iter()
.filter(|fn_arg| {
if let FnArg::Typed(pat) = fn_arg {
if let syn::Pat::Ident(pat) = pat.pat.as_ref() {
pat.ident.to_string() == "kwargs"
} else {
false
}
} else {
true
}
})
.count();

let fn_name = &ast.sig.ident;
let error_msg_fn = insert_error_function();

let quote_call = match n_kwargs {
0 => quote_call_no_kwargs(&ast, fn_name),
1 => quote_call_kwargs(&ast, fn_name),
_ => unreachable!(), // arguments are unique
};
let quote_process_result = quote_process_results();

quote!(
use pyo3_polars::export::*;

Expand All @@ -41,35 +110,9 @@ fn create_expression_function(ast: syn::ItemFn) -> proc_macro2::TokenStream {
) {
let inputs = polars_ffi::import_series_buffer(e, input_len).unwrap();

let kwargs = std::slice::from_raw_parts(kwargs_ptr, kwargs_len);

let kwargs = if kwargs.is_empty() {
::std::option::Option::None
} else {
match pyo3_polars::derive::_parse_kwargs(kwargs) {
Ok(value) => Some(value),
Err(err) => {
pyo3_polars::derive::_update_last_error(err);
return;
}
}
};

// define the function
#ast
#quote_call

// call the function
let result: PolarsResult<polars_core::prelude::Series> = #fn_name(&inputs, kwargs);
match result {
Ok(out) => {
// Update return value.
*return_value = polars_ffi::export_series(&out);
},
Err(err) => {
// Set latest error, but leave return value in empty state.
pyo3_polars::derive::_update_last_error(err);
}
}
#quote_process_result
}
)
}
Expand Down

0 comments on commit 42c5b0c

Please sign in to comment.