I have a use case where I want to use some logic in Rust to create a custom class and use it as a kwarg in my plugin.
It took me some time to figure out how to wire it all together, and I thought it might be useful to make this available to others (I have talked with at least one person in the Polars Discord about it, so there is some interest). Given that this is the de facto guide on plugins, this seemed like the right place.
Below is an example of how it can be used; I used bincode for the serialization.
The `impl serde::Serialize` and `__setstate__` are not needed for the plugin itself but are kept for completeness.
```rust
use std::collections::HashMap;

use polars::prelude::*;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::PyBytes;
use pyo3_polars::derive::polars_expr;
use serde::{Deserialize, Serialize};

/// Creates a sample lookup table for testing purposes.
///
/// # Returns
/// A `MyClass` instance containing a HashMap where each key `i` maps to
/// a vector `[i*3, i*3+1, i*3+2]`
#[pyfunction]
pub fn get_lookup_table(len: usize) -> MyClass {
    let lookup_table: HashMap<i64, Vec<i64>> = (0..len as i64)
        .map(|i| {
            let start = i * 3;
            (i, vec![start, start + 1, start + 2])
        })
        .collect();
    MyClass { lookup_table }
}

/// A lookup table class exposed to Python via PyO3.
///
/// This class wraps a `HashMap<i64, Vec<i64>>` and is designed to be:
/// 1. Picklable in Python (via `__getnewargs__`, `__getstate__` and `__setstate__`)
/// 2. Passable through Polars plugin kwargs (via a custom serde implementation)
///
/// # Serialization Strategy
///
/// Polars plugins use `serde_pickle` to deserialize kwargs on the Rust side.
/// When Python pickles the kwargs dict, objects with `__getstate__` return bytes.
/// The challenge is that `serde_pickle` expects to deserialize these bytes
/// back into the original struct.
///
/// To solve this, `MyClass` implements custom `Serialize`/`Deserialize` that:
/// - **Serialize**: Encodes the data as raw bytes using `serializer.serialize_bytes()`
/// - **Deserialize**: Expects raw bytes via `deserializer.deserialize_bytes()`
///
/// This matches how Polars' `DataFrame` handles serde (see `polars-core/src/serde/df.rs`),
/// allowing custom types to pass through the plugin kwargs system.
///
/// The `impl_polars_kwarg_class!` macro (not included in this example) generates
/// this boilerplate; the impls are written out explicitly below.
#[pyclass(module = "mylib._internal")]
#[derive(Clone)]
pub struct MyClass {
    pub(crate) lookup_table: HashMap<i64, Vec<i64>>,
}

impl Serialize for MyClass {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // Encode the table with bincode (1.x API) and emit it as one opaque byte string.
        let bytes = bincode::serialize(&self.lookup_table).map_err(serde::ser::Error::custom)?;
        serializer.serialize_bytes(&bytes)
    }
}

impl<'de> Deserialize<'de> for MyClass {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        struct BytesVisitor;

        impl<'de> serde::de::Visitor<'de> for BytesVisitor {
            type Value = MyClass;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("a byte array containing bincode-serialized MyClass data")
            }

            // Receives the bytes produced by `__getstate__` (or by `serialize` above),
            // so the encoding here must match both.
            fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                let lookup_table: HashMap<i64, Vec<i64>> =
                    bincode::deserialize(v).map_err(serde::de::Error::custom)?;
                Ok(MyClass { lookup_table })
            }

            fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                self.visit_bytes(&v)
            }
        }

        deserializer.deserialize_bytes(BytesVisitor)
    }
}

#[pymethods]
impl MyClass {
    // Python constructor; pickle calls it with the tuple returned by `__getnewargs__`.
    #[new]
    pub fn new(lookup_table: HashMap<i64, Vec<i64>>) -> Self {
        MyClass { lookup_table }
    }

    // Restores the table from the bytes written by `__getstate__` (Python-side unpickling).
    pub fn __setstate__(&mut self, state: &Bound<'_, PyBytes>) -> PyResult<()> {
        self.lookup_table = bincode::deserialize(state.as_bytes())
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok(())
    }

    // These bytes end up in the pickled kwargs that the plugin deserializes on the Rust side.
    pub fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
        let bytes = bincode::serialize(&self.lookup_table)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok(PyBytes::new(py, &bytes))
    }

    pub fn __getnewargs__(&self) -> PyResult<(HashMap<i64, Vec<i64>>,)> {
        Ok((self.lookup_table.clone(),))
    }
}
```
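The approach rests on one contract: the bytes that `__getstate__` produces when Python pickles the kwargs are what `visit_bytes` receives on the Rust side, so both must use the same encoding (bincode in both places above). As a std-only illustration of that contract, the sketch below swaps bincode for a hypothetical hand-rolled little-endian layout; `encode`, `decode` and `read_u64` are illustrative names, not part of bincode, PyO3 or Polars. Whatever format is used, `decode(&encode(t))` must give back `t`, and malformed bytes should produce an error rather than a panic.

```rust
use std::collections::HashMap;

// Hypothetical stand-in for bincode (not its real wire format):
// [n_keys: u64] then, per key, [key: i64][n_values: u64][values: i64...], all little-endian.
// Keys are sorted so the same map always encodes to the same bytes.
fn encode(table: &HashMap<i64, Vec<i64>>) -> Vec<u8> {
    let mut keys: Vec<&i64> = table.keys().collect();
    keys.sort();
    let mut out = Vec::new();
    out.extend_from_slice(&(keys.len() as u64).to_le_bytes());
    for k in keys {
        let values = &table[k];
        out.extend_from_slice(&k.to_le_bytes());
        out.extend_from_slice(&(values.len() as u64).to_le_bytes());
        for v in values {
            out.extend_from_slice(&v.to_le_bytes());
        }
    }
    out
}

// Reads the next 8 bytes as a little-endian u64, advancing `pos`.
fn read_u64(bytes: &[u8], pos: &mut usize) -> Result<u64, String> {
    let chunk = bytes
        .get(*pos..*pos + 8)
        .ok_or("unexpected end of input")?;
    let mut buf = [0u8; 8];
    buf.copy_from_slice(chunk);
    *pos += 8;
    Ok(u64::from_le_bytes(buf))
}

// Exact inverse of `encode`; truncated input or trailing bytes become an Err, never a panic.
fn decode(bytes: &[u8]) -> Result<HashMap<i64, Vec<i64>>, String> {
    let mut pos = 0;
    let n_keys = read_u64(bytes, &mut pos)?;
    let mut table = HashMap::new();
    for _ in 0..n_keys {
        // Casting u64 -> i64 reinterprets the bits, undoing `i64::to_le_bytes`.
        let key = read_u64(bytes, &mut pos)? as i64;
        let n_values = read_u64(bytes, &mut pos)?;
        let mut values = Vec::new();
        for _ in 0..n_values {
            values.push(read_u64(bytes, &mut pos)? as i64);
        }
        table.insert(key, values);
    }
    if pos != bytes.len() {
        return Err("trailing bytes after table".to_string());
    }
    Ok(table)
}
```

In the real class, `__getstate__` and `Serialize::serialize` would both call the same encoder, and `__setstate__` and `visit_bytes` the same decoder, which keeps the Python and Rust sides from drifting apart.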
```rust
/// Kwargs struct for the `my_class_function` Polars plugin.
///
/// This struct defines the keyword arguments that can be passed to the
/// `my_class_function` expression plugin. Polars serializes kwargs as pickle
/// on the Python side and deserializes them via serde_pickle on the Rust side.
///
/// The field names must match the Python kwargs exactly.
#[derive(Deserialize, Serialize)]
pub struct KwargsMyClass {
    pub my_class: MyClass,
}

/// Polars expression plugin that uses a `MyClass` lookup table.
///
/// This function demonstrates passing a custom Rust struct through Polars
/// plugin kwargs. It retrieves a value from the lookup table and adds it
/// to each element of the input column.
///
/// # Arguments
/// * `inputs` - Array of input Series (expects one i64 column)
/// * `kwargs` - Contains the `MyClass` instance with the lookup table
///
/// # Returns
/// A new i64 Series where each value is `input[i] + lookup_table[0][0]` (nulls stay null)
#[polars_expr(output_type=Int64)]
fn my_class_function(inputs: &[Series], kwargs: KwargsMyClass) -> PolarsResult<Series> {
    let input: &Int64Chunked = inputs[0].i64()?;
    // Offset = first element of the vector stored under key 0; error instead of panicking
    // if the key is missing or its vector is empty.
    let offset = kwargs
        .my_class
        .lookup_table
        .get(&0)
        .and_then(|values| values.first())
        .copied()
        .ok_or_else(|| {
            PolarsError::ComputeError("key 0 not found in lookup table (or its vector is empty)".into())
        })?;
    // Applies the closure to non-null values only.
    let out: Int64Chunked = unary_elementwise_values(input, |k| k + offset);
    Ok(out.into_series())
}
```

With the following in `src/lib.rs`:
```rust
#[pymodule]
fn _internal(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
    m.add("__version__", env!("CARGO_PKG_VERSION"))?;
    m.add_class::<MyClass>()?;
    m.add_function(wrap_pyfunction!(get_lookup_table, m)?)?;
    Ok(())
}
```
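To sanity-check what `my_class_function` computes without building the extension, the following std-only sketch mirrors `get_lookup_table` and the expression on plain vectors, with `None` standing in for a null. `build_lookup_table` and `add_first_entry` are hypothetical names used only here. It also makes one detail visible: in the table built by `get_lookup_table`, key `0` maps to `[0, 1, 2]`, so the offset is `0` and the plugin returns its input unchanged; a table with a different first entry is needed to see an effect.

```rust
use std::collections::HashMap;

// Std-only mirror of `get_lookup_table`: key `i` maps to `[i*3, i*3+1, i*3+2]`.
fn build_lookup_table(len: usize) -> HashMap<i64, Vec<i64>> {
    (0..len as i64)
        .map(|i| {
            let start = i * 3;
            (i, vec![start, start + 1, start + 2])
        })
        .collect()
}

// Mirrors `my_class_function`: read `lookup_table[0][0]` with checked access,
// then add it to every non-null value; `None` (null) passes through untouched.
fn add_first_entry(
    lookup_table: &HashMap<i64, Vec<i64>>,
    input: &[Option<i64>],
) -> Result<Vec<Option<i64>>, String> {
    let offset = lookup_table
        .get(&0)
        .and_then(|values| values.first())
        .copied()
        .ok_or_else(|| "key 0 not found in lookup table (or its vector is empty)".to_string())?;
    Ok(input.iter().map(|&x| x.map(|k| k + offset)).collect())
}
```

The checked `get(&0).and_then(|v| v.first())` chain is the same one used in the plugin above, so a missing key and an empty vector both surface as an error rather than a panic.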