
Interested in adding a section on using Rust classes in the kwargs? #75

@thomasfrederikhoeck


I have a use case where I want to run some logic in Rust to create a custom class and then pass it as a kwarg to my plugin.

It took me some time to figure out how to wire everything together, so I thought it might be useful to make this available to others (I have discussed it with at least one person on the Polars Discord, so there is some interest). Since this is the de facto guide on plugins, it seemed like the right place.

Below is an example of how it can be used. I use bincode for the serialization.

The impl serde::Serialize and __setstate__ are not needed but are kept for completeness.

use std::collections::HashMap;

use polars::prelude::*;
use pyo3::prelude::*;
use pyo3_polars::derive::polars_expr;
use serde::{Deserialize, Serialize};

/// Creates a sample lookup table for testing purposes.
///
/// # Returns
/// A `MyClass` instance containing a HashMap where each key `i` in `0..len`
/// maps to the vector `[i*3, i*3+1, i*3+2]`.
#[pyfunction]
pub fn get_lookup_table(len: usize) -> MyClass {
    let lookup_table: HashMap<i64, Vec<i64>> = (0..len as i64)
        .map(|i| {
            let start = i * 3;
            (i, vec![start, start + 1, start + 2])
        })
        .collect();
    MyClass { lookup_table }
}

/// A lookup table class exposed to Python via PyO3.
///
/// This class wraps a `HashMap<i64, Vec<i64>>` and is designed to be:
/// 1. Picklable in Python (via `__getstate__`/`__setstate__`)
/// 2. Passable through Polars plugin kwargs (via custom serde implementation)
///
/// # Serialization Strategy
///
/// Polars plugins use `serde_pickle` to deserialize kwargs on the Rust side.
/// When Python pickles the kwargs dict, objects with `__getstate__` return bytes.
/// The challenge is that `serde_pickle` expects to deserialize these bytes
/// back into the original struct.
///
/// To solve this, `MyClass` implements custom `Serialize`/`Deserialize` that:
/// - **Serialize**: Encodes the data as raw bytes using `serializer.serialize_bytes()`
/// - **Deserialize**: Expects raw bytes via `deserializer.deserialize_bytes()`
///
/// This matches how Polars' `DataFrame` handles serde (see `polars-core/src/serde/df.rs`),
/// allowing custom types to pass through the plugin kwargs system.
///
/// In the original project an `impl_polars_kwarg_class!` macro generates this
/// boilerplate; the macro is not shown here, so the impls below spell it out by hand.
#[pyclass(module = "mylib._internal")]
#[derive(Clone)]
pub struct MyClass {
    pub(crate) lookup_table: HashMap<i64, Vec<i64>>,
}

// Serializes the lookup table as a single bincode-encoded byte string.
impl serde::Serialize for MyClass {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let bytes = bincode::serialize(&self.lookup_table).map_err(serde::ser::Error::custom)?;
        serializer.serialize_bytes(&bytes)
    }
}

// Deserializes from a byte string holding bincode-encoded lookup-table data.
impl<'de> serde::Deserialize<'de> for MyClass {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        struct BytesVisitor;

        impl<'de> serde::de::Visitor<'de> for BytesVisitor {
            type Value = MyClass;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("a byte array containing bincode-serialized MyClass data")
            }

            // Called when the deserializer lends out a borrowed byte slice.
            fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                let lookup_table: HashMap<i64, Vec<i64>> =
                    bincode::deserialize(v).map_err(serde::de::Error::custom)?;
                Ok(MyClass { lookup_table })
            }

            // Called when the deserializer hands over an owned buffer.
            fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                self.visit_bytes(&v)
            }
        }

        deserializer.deserialize_bytes(BytesVisitor)
    }
}
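#[pyo3::pymethods]
impl MyClass {
    /// Python constructor. `#[new]` is required: without it PyO3 rejects a
    /// method that takes no `self`, and Python's unpickling needs a
    /// constructor to call with the arguments from `__getnewargs__`.
    #[new]
    pub fn new(lookup_table: HashMap<i64, Vec<i64>>) -> Self {
        MyClass { lookup_table }
    }

    /// Restores the table from the bincode bytes produced by `__getstate__`.
    pub fn __setstate__(&mut self, state: &Bound<'_, pyo3::types::PyBytes>) -> PyResult<()> {
        self.lookup_table = bincode::deserialize(state.as_bytes())
            .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
        Ok(())
    }

    /// Pickled state: the lookup table encoded with bincode.
    pub fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, pyo3::types::PyBytes>> {
        let bytes = bincode::serialize(&self.lookup_table)
            .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
        Ok(pyo3::types::PyBytes::new(py, &bytes))
    }

    /// Arguments passed to `MyClass.__new__` when the object is unpickled.
    pub fn __getnewargs__(&self) -> PyResult<(HashMap<i64, Vec<i64>>,)> {
        Ok((self.lookup_table.clone(),))
    }
}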

/// Kwargs struct for `my_class_function` Polars plugin.
///
/// This struct defines the keyword arguments that can be passed to the
/// `my_class_function` expression plugin. Polars serializes kwargs as pickle
/// on the Python side and deserializes them via serde_pickle on the Rust side.
///
/// The field names must match the Python kwargs exactly.
#[derive(Deserialize, Serialize)]
pub struct KwargsMyClass {
    pub my_class: MyClass,
}


/// Polars expression plugin that uses a `MyClass` lookup table.
///
/// This function demonstrates passing a custom Rust struct through Polars
/// plugin kwargs. It retrieves a value from the lookup table and adds it
/// to each element of the input column.
///
/// # Arguments
/// * `inputs` - Array of input Series (expects one i64 column)
/// * `kwargs` - Contains the `MyClass` instance with the lookup table
///
/// # Returns
/// A new i64 Series where each value is `input[i] + lookup_table[0][0]`
#[polars_expr(output_type=Int64)]
fn my_class_function(inputs: &[Series], kwargs: KwargsMyClass) -> PolarsResult<Series> {
    let input: &Int64Chunked = inputs[0].i64()?;
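    // Offset to add: the first value stored under key 0. A missing key or an
    // empty vector becomes a PolarsError instead of a panic.
    let variable_0 = kwargs
        .my_class
        .lookup_table
        .get(&0)
        .and_then(|values| values.first())
        .copied()
        .ok_or_else(|| {
            PolarsError::ComputeError("Key 0 missing or empty in lookup_table.".into())
        })?;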

    let value: ChunkedArray<Int64Type> = unary_elementwise_values(input, |k| k + variable_0);

    Ok(value.into_series())
}
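For a quick sanity check of what my_class_function computes, its core logic can be modelled in plain Rust without PyO3 or Polars. This is a sketch, not the plugin itself: `build_lookup_table` and `add_first_entry` are hypothetical helpers, `Option<i64>` stands in for a nullable Int64 column (nulls are left as null, as with `unary_elementwise_values`), and a `String` error stands in for the `PolarsError`.

```rust
use std::collections::HashMap;

/// Same construction as `get_lookup_table`: key `i` maps to `[3i, 3i + 1, 3i + 2]`.
fn build_lookup_table(len: usize) -> HashMap<i64, Vec<i64>> {
    (0..len as i64)
        .map(|i| {
            let start = i * 3;
            (i, vec![start, start + 1, start + 2])
        })
        .collect()
}

/// Hypothetical plain-Rust model of `my_class_function`: adds `lookup_table[0][0]`
/// to every non-null value. `Option<i64>` stands in for a nullable Int64 column.
fn add_first_entry(
    input: &[Option<i64>],
    lookup_table: &HashMap<i64, Vec<i64>>,
) -> Result<Vec<Option<i64>>, String> {
    let offset = lookup_table
        .get(&0)
        .and_then(|values| values.first())
        .copied()
        .ok_or_else(|| "key 0 missing or empty in lookup table".to_string())?;
    // Nulls stay null; only present values are shifted.
    Ok(input.iter().map(|&v| v.map(|k| k + offset)).collect())
}

fn main() {
    let table = build_lookup_table(3);
    assert_eq!(table[&1], vec![3, 4, 5]);

    let input: [Option<i64>; 3] = [Some(10), None, Some(-2)];
    // table[0] is [0, 1, 2], so the offset is 0 and the column comes back unchanged.
    assert_eq!(add_first_entry(&input, &table).unwrap(), input.to_vec());

    let mut custom: HashMap<i64, Vec<i64>> = HashMap::new();
    custom.insert(0, vec![100]);
    assert_eq!(
        add_first_entry(&input, &custom).unwrap(),
        vec![Some(110), None, Some(98)]
    );

    // A missing key 0 is reported as an error instead of panicking.
    assert!(add_first_entry(&input, &HashMap::new()).is_err());
}
```

One consequence worth spelling out: with the table from get_lookup_table, key 0 maps to [0, 1, 2], so lookup_table[0][0] is 0 and the plugin returns its input unchanged. A table with a non-zero first entry, like `custom` above, is needed to see the offset.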

The class and the get_lookup_table function are then registered in src/lib.rs:

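use pyo3::prelude::*;

// MyClass and get_lookup_table must be in scope here (imported from the
// module that defines them). The module name `_internal` has to match the
// `module = "mylib._internal"` given to #[pyclass] so pickle can locate MyClass.
#[pymodule]
fn _internal(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
    m.add("__version__", env!("CARGO_PKG_VERSION"))?;
    m.add_class::<MyClass>()?;
    m.add_function(wrap_pyfunction!(get_lookup_table, m)?)?;
    Ok(())
}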
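The serialization strategy described in the MyClass doc comment boils down to flattening the lookup table into one byte string and rebuilding it from that string. To make the byte-level idea concrete without depending on bincode, here is a hand-rolled encoder and decoder for `HashMap<i64, Vec<i64>>`. It is an illustration only: `encode`, `decode`, `take8`, `read_u64` and `read_i64` are hypothetical names, and the layout is an assumption (not verified against bincode) meant to resemble what bincode 1.x's `serialize` produces for this type.

```rust
use std::collections::HashMap;

/// Hypothetical std-only stand-in for `bincode::serialize(&lookup_table)`.
/// Assumed layout: a u64 entry count, then per entry an i64 key, a u64 vector
/// length and the i64 values, all little-endian.
fn encode(table: &HashMap<i64, Vec<i64>>) -> Vec<u8> {
    let mut out = Vec::new();
    out.extend_from_slice(&(table.len() as u64).to_le_bytes());
    for (key, values) in table {
        out.extend_from_slice(&key.to_le_bytes());
        out.extend_from_slice(&(values.len() as u64).to_le_bytes());
        for value in values {
            out.extend_from_slice(&value.to_le_bytes());
        }
    }
    out
}

/// Splits the next 8 bytes off the front of `bytes`.
fn take8<'a>(bytes: &mut &'a [u8]) -> Result<[u8; 8], String> {
    let current: &'a [u8] = *bytes;
    if current.len() < 8 {
        return Err("unexpected end of input".to_string());
    }
    let (head, rest) = current.split_at(8);
    *bytes = rest;
    let mut buf = [0u8; 8];
    buf.copy_from_slice(head);
    Ok(buf)
}

fn read_u64(bytes: &mut &[u8]) -> Result<u64, String> {
    Ok(u64::from_le_bytes(take8(bytes)?))
}

fn read_i64(bytes: &mut &[u8]) -> Result<i64, String> {
    Ok(i64::from_le_bytes(take8(bytes)?))
}

/// Inverse of `encode`, playing the role of `bincode::deserialize` in `visit_bytes`.
fn decode(mut bytes: &[u8]) -> Result<HashMap<i64, Vec<i64>>, String> {
    let entries = read_u64(&mut bytes)?;
    let mut table = HashMap::new();
    for _ in 0..entries {
        let key = read_i64(&mut bytes)?;
        let len = read_u64(&mut bytes)?;
        // Grow the Vec as values arrive instead of trusting `len` for capacity.
        let mut values = Vec::new();
        for _ in 0..len {
            values.push(read_i64(&mut bytes)?);
        }
        table.insert(key, values);
    }
    if !bytes.is_empty() {
        return Err(format!("{} trailing bytes", bytes.len()));
    }
    Ok(table)
}

fn main() {
    // Same table as `get_lookup_table(3)`: 0 -> [0, 1, 2], 1 -> [3, 4, 5], 2 -> [6, 7, 8].
    let table: HashMap<i64, Vec<i64>> = (0..3i64)
        .map(|i| (i, vec![i * 3, i * 3 + 1, i * 3 + 2]))
        .collect();

    let bytes = encode(&table);
    // 8-byte count + 3 entries * (8 key + 8 length + 3 * 8 values) = 128 bytes.
    assert_eq!(bytes.len(), 128);
    assert_eq!(decode(&bytes).unwrap(), table);

    // A truncated buffer is rejected rather than silently misread.
    assert!(decode(&bytes[..bytes.len() - 1]).is_err());
}
```

Because HashMap iteration order is unspecified, the encoded bytes for the same table can differ between runs; the round trip is therefore checked by comparing the decoded maps rather than the raw bytes.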
