Skip to content

Commit

Permalink
feat: support kwargs for field functions (#62)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jan 24, 2024
1 parent e39357c commit 870da1e
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 21 deletions.
2 changes: 1 addition & 1 deletion example/derive_expression/expression_lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ name = "expression_lib"
crate-type = ["cdylib"]

[dependencies]
polars = { workspace = true, features = ["fmt", "dtype-date"], default-features = false }
polars = { workspace = true, features = ["fmt", "dtype-date", "timezones"], default-features = false }
pyo3-polars = { version = "*", path = "../../../pyo3-polars", features = ["derive"] }
serde = { version = "1", features = ["derive"] }
rayon = "1.7.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,14 @@ def is_leap_year(self) -> pl.Expr:
is_elementwise=True,
)

# Note that this already exists in Polars. It is just for explanatory
# purposes.
def change_time_zone(self, tz: str = "Europe/Amsterdam") -> pl.Expr:
return self._expr.register_plugin(
lib=lib, symbol="change_time_zone", is_elementwise=True, kwargs={"tz": tz}
)


@pl.api.register_expr_namespace("panic")
class Panic:
def __init__(self, expr: pl.Expr):
Expand Down
27 changes: 26 additions & 1 deletion example/derive_expression/expression_lib/src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ fn pig_latinnify_with_paralellism(
})
.collect();

Ok(StringChunked::from_chunk_iter(ca.name(), chunks.into_iter().flatten()).into_series())
Ok(
StringChunked::from_chunk_iter(ca.name(), chunks.into_iter().flatten())
.into_series(),
)
})
}
}
Expand Down Expand Up @@ -180,3 +183,25 @@ fn is_leap_year(input: &[Series]) -> PolarsResult<Series> {
fn panic(_input: &[Series]) -> PolarsResult<Series> {
todo!()
}

#[derive(Deserialize)]
struct TimeZone {
tz: String,
}

fn convert_timezone(input_fields: &[Field], kwargs: TimeZone) -> PolarsResult<Field> {
FieldsMapper::new(input_fields).try_map_dtype(|dtype| match dtype {
DataType::Datetime(tu, _) => Ok(DataType::Datetime(*tu, Some(kwargs.tz.clone()))),
_ => polars_bail!(ComputeError: "expected datetime"),
})
}

#[polars_expr(output_type_func_with_kwargs=convert_timezone)]
fn change_time_zone(input: &[Series], kwargs: TimeZone) -> PolarsResult<Series> {
let input = &input[0];
let ca = input.datetime()?;

ca.clone()
.convert_time_zone(kwargs.tz)
.map(|ca| ca.into_series())
}
10 changes: 5 additions & 5 deletions example/derive_expression/run.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import polars as pl
from expression_lib import *
from datetime import date
from datetime import date, datetime, timezone

df = pl.DataFrame(
{
"names": ["Richard", "Alice", "Bob"],
"moons": ["full", "half", "red"],
"dates": [date(2023, 1, 1), date(2024, 1, 1), date(2025, 1, 1)],
"datetime": [datetime.now(tz=timezone.utc)] * 3,
"dist_a": [[12, 32, 1], [], [1, -2]],
"dist_b": [[-12, 1], [43], [876, -45, 9]],
"floats": [5.6, -1245.8, 242.224],
Expand All @@ -22,6 +23,7 @@
jaccard_sim=pl.col("dist_a").dist.jaccard_similarity("dist_b"),
haversine=pl.col("floats").dist.haversine("floats", "floats", "floats", "floats"),
leap_year=pl.col("dates").date_util.is_leap_year(),
new_tz=pl.col("datetime").date_util.change_time_zone(),
appended_args=pl.col("names").language.append_args(
float_arg=11.234,
integer_arg=93,
Expand All @@ -48,10 +50,8 @@


try:
out.with_columns(
pl.col("names").panic.panic()
)
out.with_columns(pl.col("names").panic.panic())
except pl.ComputeError as e:
assert "the plugin panicked" in str(e)

print("finished")
print("finished")
6 changes: 6 additions & 0 deletions pyo3-polars-derive/src/attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@ impl<K: Parse, V: Parse> Parse for KeyWordAttribute<K, V> {

pub type OutputAttribute = KeyWordAttribute<keywords::output_type, Ident>;
pub type OutputFuncAttribute = KeyWordAttribute<keywords::output_type_func, Ident>;
pub type OutputFuncAttributeWithKwargs =
KeyWordAttribute<keywords::output_type_func_with_kwargs, Ident>;

#[derive(Default, Debug)]
pub struct ExprsFunctionOptions {
pub output_dtype: Option<Ident>,
pub output_type_fn: Option<Ident>,
pub output_type_fn_kwargs: Option<Ident>,
}

impl Parse for ExprsFunctionOptions {
Expand All @@ -41,6 +44,9 @@ impl Parse for ExprsFunctionOptions {
} else if lookahead.peek(keywords::output_type_func) {
let attr = input.parse::<OutputFuncAttribute>()?;
options.output_type_fn = Some(attr.value)
} else if lookahead.peek(keywords::output_type_func_with_kwargs) {
let attr = input.parse::<OutputFuncAttributeWithKwargs>()?;
options.output_type_fn_kwargs = Some(attr.value)
} else {
panic!("didn't recognize attribute")
}
Expand Down
1 change: 1 addition & 0 deletions pyo3-polars-derive/src/keywords.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
syn::custom_keyword!(output_type);
syn::custom_keyword!(output_type_func);
syn::custom_keyword!(output_type_func_with_kwargs);
54 changes: 40 additions & 14 deletions pyo3-polars-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,26 @@ fn insert_error_function() -> proc_macro2::TokenStream {
}
}

fn quote_call_kwargs(ast: &syn::ItemFn, fn_name: &syn::Ident) -> proc_macro2::TokenStream {
fn quote_get_kwargs() -> proc_macro2::TokenStream {
quote!(
let kwargs = std::slice::from_raw_parts(kwargs_ptr, kwargs_len);
let kwargs = std::slice::from_raw_parts(kwargs_ptr, kwargs_len);

let kwargs = match pyo3_polars::derive::_parse_kwargs(kwargs) {
Ok(value) => value,
Err(err) => {
pyo3_polars::derive::_update_last_error(err);
return;
}
};
let kwargs = match pyo3_polars::derive::_parse_kwargs(kwargs) {
Ok(value) => value,
Err(err) => {
pyo3_polars::derive::_update_last_error(err);
return;
}
};

)
}

fn quote_call_kwargs(ast: &syn::ItemFn, fn_name: &syn::Ident) -> proc_macro2::TokenStream {
let kwargs = quote_get_kwargs();
quote!(
// parse the kwargs and assign to `let kwargs`
#kwargs

// define the function
#ast
Expand Down Expand Up @@ -184,7 +193,7 @@ fn get_expression_function_name(fn_name: &syn::Ident) -> syn::Ident {
syn::Ident::new(&format!("_polars_plugin_{}", fn_name), fn_name.span())
}

fn get_inputs() -> proc_macro2::TokenStream {
fn quote_get_inputs() -> proc_macro2::TokenStream {
quote!(
let inputs = std::slice::from_raw_parts(field, len);
let inputs = inputs.iter().map(|field| {
Expand All @@ -198,21 +207,36 @@ fn get_inputs() -> proc_macro2::TokenStream {
fn create_field_function(
fn_name: &syn::Ident,
dtype_fn_name: &syn::Ident,
kwargs: bool,
) -> proc_macro2::TokenStream {
let map_field_name = get_field_function_name(fn_name);
let inputs = get_inputs();
let inputs = quote_get_inputs();

let call_fn = if kwargs {
let kwargs = quote_get_kwargs();
quote! (
#kwargs
let result = #dtype_fn_name(&inputs, kwargs);
)
} else {
quote!(
let result = #dtype_fn_name(&inputs);
)
};

quote! (
#[no_mangle]
pub unsafe extern "C" fn #map_field_name(
field: *mut polars_core::export::arrow::ffi::ArrowSchema,
len: usize,
return_value: *mut polars_core::export::arrow::ffi::ArrowSchema,
kwargs_ptr: *const u8,
kwargs_len: usize,
) {
let panic_result = std::panic::catch_unwind(move || {
#inputs;

let result = #dtype_fn_name(&inputs);
#call_fn;

match result {
Ok(out) => {
Expand All @@ -239,7 +263,7 @@ fn create_field_function_from_with_dtype(
dtype: syn::Ident,
) -> proc_macro2::TokenStream {
let map_field_name = get_field_function_name(fn_name);
let inputs = get_inputs();
let inputs = quote_get_inputs();

quote! (
#[no_mangle]
Expand All @@ -265,7 +289,9 @@ pub fn polars_expr(attr: TokenStream, input: TokenStream) -> TokenStream {

let options = parse_macro_input!(attr as attr::ExprsFunctionOptions);
let expanded_field_fn = if let Some(fn_name) = options.output_type_fn {
create_field_function(&ast.sig.ident, &fn_name)
create_field_function(&ast.sig.ident, &fn_name, false)
} else if let Some(fn_name) = options.output_type_fn_kwargs {
create_field_function(&ast.sig.ident, &fn_name, true)
} else if let Some(dtype) = options.output_dtype {
create_field_function_from_with_dtype(&ast.sig.ident, dtype)
} else {
Expand Down

0 comments on commit 870da1e

Please sign in to comment.