Skip to content

Commit

Permalink
Add interface to directly serialize Substrait plans to Python Bytes. (#…
Browse files Browse the repository at this point in the history
…344)

Clean up existing substrait bindings to return bytes instead of List[int].
  • Loading branch information
kylebrooks-8451 authored Apr 25, 2023
1 parent 14f4840 commit b6c2115
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 5 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ datafusion-expr = "23.0.0"
datafusion-optimizer = "23.0.0"
datafusion-sql = "23.0.0"
datafusion-substrait = "23.0.0"
prost = "0.11"
prost-types = "0.11"
uuid = { version = "1.2", features = ["v4"] }
mimalloc = { version = "0.1", optional = true, default-features = false }
async-trait = "0.1"
Expand Down
2 changes: 2 additions & 0 deletions datafusion/tests/test_substrait.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def test_substrait_serialization(ctx):
substrait_plan = ss.substrait.serde.serialize_to_plan(
"SELECT * FROM t", ctx
)
substrait_bytes = substrait_plan.encode()
assert type(substrait_bytes) is bytes
substrait_bytes = ss.substrait.serde.serialize_bytes(
"SELECT * FROM t", ctx
)
Expand Down
7 changes: 6 additions & 1 deletion examples/substrait.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,13 @@
)
# type(substrait_plan) -> <class 'datafusion.substrait.plan'>

# Encode it to bytes
substrait_bytes = substrait_plan.encode()
# type(substrait_bytes) -> <class 'bytes'>, at this point the bytes can be distributed to file, network, etc safely
# where they could subsequently be deserialized on the receiving end.

# Alternative serialization approaches
# type(substrait_bytes) -> <class 'list'>, at this point the bytes can be distributed to file, network, etc safely
# type(substrait_bytes) -> <class 'bytes'>, at this point the bytes can be distributed to file, network, etc safely
# where they could subsequently be deserialized on the receiving end.
substrait_bytes = ss.substrait.serde.serialize_bytes(
"SELECT * FROM aggregate_test_data", ctx
Expand Down
3 changes: 3 additions & 0 deletions src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use std::fmt::Debug;

use datafusion::arrow::error::ArrowError;
use datafusion::error::DataFusionError as InnerDataFusionError;
use prost::EncodeError;
use pyo3::{exceptions::PyException, PyErr};

pub type Result<T> = std::result::Result<T, DataFusionError>;
Expand All @@ -31,6 +32,7 @@ pub enum DataFusionError {
ArrowError(ArrowError),
Common(String),
PythonError(PyErr),
EncodeError(EncodeError),
}

impl fmt::Display for DataFusionError {
Expand All @@ -40,6 +42,7 @@ impl fmt::Display for DataFusionError {
DataFusionError::ArrowError(e) => write!(f, "Arrow error: {e:?}"),
DataFusionError::PythonError(e) => write!(f, "Python error {e:?}"),
DataFusionError::Common(e) => write!(f, "{e}"),
DataFusionError::EncodeError(e) => write!(f, "Failed to encode substrait plan: {e}"),
}
}
}
Expand Down
24 changes: 20 additions & 4 deletions src/substrait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use pyo3::prelude::*;
use pyo3::{prelude::*, types::PyBytes};

use crate::context::PySessionContext;
use crate::errors::{py_datafusion_err, DataFusionError};
Expand All @@ -25,13 +25,25 @@ use crate::utils::wait_for_future;
use datafusion_substrait::logical_plan::{consumer, producer};
use datafusion_substrait::serializer;
use datafusion_substrait::substrait::proto::Plan;
use prost::Message;

/// Python-facing wrapper around a Substrait [`Plan`] protobuf message.
///
/// Registered with Python under the name `plan` in the `datafusion.substrait`
/// module (see the `pyclass` attribute below).
#[pyclass(name = "plan", module = "datafusion.substrait", subclass, unsendable)]
#[derive(Debug, Clone)]
pub(crate) struct PyPlan {
/// The wrapped Substrait plan message.
pub(crate) plan: Plan,
}

#[pymethods]
impl PyPlan {
    /// Encodes the wrapped Substrait plan with `prost` and returns the result
    /// as a Python `bytes` object.
    ///
    /// # Errors
    ///
    /// If `prost` reports an encoding failure, it is wrapped in
    /// `DataFusionError::EncodeError` and propagated to the caller.
    fn encode(&self, py: Python) -> PyResult<PyObject> {
        let mut proto_bytes = Vec::<u8>::new();
        self.plan
            .encode(&mut proto_bytes)
            // The enum variant is itself a function from `EncodeError`, so no
            // wrapping closure is needed.
            .map_err(DataFusionError::EncodeError)?;
        Ok(PyBytes::new(py, &proto_bytes).into())
    }
}

impl From<PyPlan> for Plan {
fn from(plan: PyPlan) -> Plan {
plan.plan
Expand Down Expand Up @@ -63,16 +75,19 @@ impl PySubstraitSerializer {
/// Builds a Substrait plan for `sql` using `ctx` and returns it as a `PyPlan`.
///
/// The query is first serialized with `serialize_bytes`; the resulting bytes
/// are then decoded again through `deserialize_bytes`. A serialization failure
/// is passed through `py_datafusion_err` before being returned.
#[staticmethod]
pub fn serialize_to_plan(sql: &str, ctx: PySessionContext, py: Python) -> PyResult<PyPlan> {
match PySubstraitSerializer::serialize_bytes(sql, ctx, py) {
// NOTE(review): the one-line arm below and the block arm after it look like the
// before/after versions of the same arm in this diff view — confirm only one applies.
Ok(proto_bytes) => PySubstraitSerializer::deserialize_bytes(proto_bytes, py),
Ok(proto_bytes) => {
// `serialize_bytes` builds its return value with `PyBytes::new`, so this downcast is
// expected to succeed; the `unwrap` would panic otherwise.
let proto_bytes: &PyBytes = proto_bytes.as_ref(py).downcast().unwrap();
// Copy the bytes into an owned vector before handing them to `deserialize_bytes`.
PySubstraitSerializer::deserialize_bytes(proto_bytes.as_bytes().to_vec(), py)
}
Err(e) => Err(py_datafusion_err(e)),
}
}

/// Serializes the Substrait plan for `sql`, using the session context in `ctx`, to bytes.
///
/// Blocks on `serializer::serialize_bytes` through `wait_for_future`; a failure is
/// converted with `DataFusionError::from` and propagated to the caller.
#[staticmethod]
// NOTE(review): two signatures and two `Ok(...)` lines appear below; they look like the
// before (`Vec<u8>`) and after (`PyObject` built from `PyBytes`) versions shown by this
// diff view — confirm which pair is current.
pub fn serialize_bytes(sql: &str, ctx: PySessionContext, py: Python) -> PyResult<Vec<u8>> {
pub fn serialize_bytes(sql: &str, ctx: PySessionContext, py: Python) -> PyResult<PyObject> {
let proto_bytes: Vec<u8> = wait_for_future(py, serializer::serialize_bytes(sql, &ctx.ctx))
.map_err(DataFusionError::from)?;
Ok(proto_bytes)
// Copies the serialized plan into a new Python `bytes` object.
Ok(PyBytes::new(py, &proto_bytes).into())
}

#[staticmethod]
Expand Down Expand Up @@ -136,6 +151,7 @@ impl PySubstraitConsumer {
}

pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
m.add_class::<PyPlan>()?;
m.add_class::<PySubstraitConsumer>()?;
m.add_class::<PySubstraitProducer>()?;
m.add_class::<PySubstraitSerializer>()?;
Expand Down

0 comments on commit b6c2115

Please sign in to comment.