Commit 8b4d2a1

preliminary plumbing for unambiguous join
1 parent 963d60d commit 8b4d2a1

File tree

8 files changed: +367 -161 lines

Cargo.lock

Lines changed: 106 additions & 154 deletions
Some generated files are not rendered by default.

Cargo.toml

Lines changed: 4 additions & 4 deletions
@@ -39,10 +39,10 @@ pyo3 = { version = "0.24", features = ["extension-module", "abi3", "abi3-py39"] }
 pyo3-async-runtimes = { version = "0.24", features = ["tokio-runtime"]}
 pyo3-log = "0.12.4"
 arrow = { version = "55.1.0", features = ["pyarrow"] }
-datafusion = { version = "48.0.0", features = ["avro", "unicode_expressions"] }
-datafusion-substrait = { version = "48.0.0", optional = true }
-datafusion-proto = { version = "48.0.0" }
-datafusion-ffi = { version = "48.0.0" }
+datafusion = { path="../datafusion/datafusion/core", features = ["avro", "unicode_expressions"] }
+datafusion-substrait = { path="../datafusion/datafusion/substrait", optional = true }
+datafusion-proto = { path="../datafusion/datafusion/proto" }
+datafusion-ffi = { path="../datafusion/datafusion/ffi" }
 prost = "0.13.1" # keep in line with `datafusion-substrait`
 uuid = { version = "1.16", features = ["v4"] }
 mimalloc = { version = "0.1", optional = true, default-features = false, features = ["local_dynamic_tls"] }

python/datafusion/common.py

Lines changed: 2 additions & 0 deletions
@@ -34,6 +34,7 @@
 SqlType = common_internal.SqlType
 SqlView = common_internal.SqlView
 TableType = common_internal.TableType
+TableReference = common_internal.TableReference
 TableSource = common_internal.TableSource
 Constraints = common_internal.Constraints

@@ -51,6 +52,7 @@
     "SqlTable",
     "SqlType",
     "SqlView",
+    "TableReference",
     "TableSource",
     "TableType",
 ]
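
With the re-export in place, the wrapper is importable as datafusion.common.TableReference. A minimal usage sketch, assuming a rebuilt extension module; the constructors come from the pyclass added in src/common/table_reference.rs below, and the table/schema/catalog names here are purely illustrative:

    from datafusion.common import TableReference

    # Three levels of qualification, mirroring DataFusion's TableReference enum.
    bare = TableReference.bare("users")                    # table
    partial = TableReference.partial("public", "users")    # schema.table
    full = TableReference.full("hive", "public", "users")  # catalog.schema.table

    assert full.catalog() == "hive" and full.schema() == "public"
    assert full.table() == "users"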

python/datafusion/dataframe.py

Lines changed: 23 additions & 0 deletions
@@ -40,6 +40,7 @@
 from datafusion._internal import DataFrame as DataFrameInternal
 from datafusion._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
 from datafusion._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
+from datafusion.common import TableReference
 from datafusion.expr import Expr, SortExpr, sort_or_default
 from datafusion.plan import ExecutionPlan, LogicalPlan
 from datafusion.record_batch import RecordBatchStream

@@ -1156,3 +1157,25 @@ def fill_null(self, value: Any, subset: list[str] | None = None) -> DataFrame:
         - For columns not in subset, the original column is kept unchanged
         """
         return DataFrame(self.df.fill_null(value, subset))
+
+    def fully_qualified_name(self, col: str) -> tuple[TableReference, pa.Field]:
+        """Get the fully qualified name of a column in the DataFrame.
+
+        Args:
+            col: The column name to get the fully qualified name for.
+
+        Returns:
+            Tuple of the qualifying ``TableReference`` and the column's Arrow field.
+        """
+        return self.df.fully_qualified_name(col)
+
+    def _drop(self, *qualified_columns: tuple[TableReference, pa.Field]) -> DataFrame:
+        """Drop columns from the DataFrame using fully qualified names.
+
+        Args:
+            qualified_columns: Fully qualified column names to drop.
+
+        Returns:
+            DataFrame with specified columns removed.
+        """
+        return DataFrame(self.df._drop(*qualified_columns))
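
These two methods are the Python surface for the unambiguous-join plumbing: fully_qualified_name resolves a bare column name to a (TableReference, pa.Field) pair, and _drop removes columns by that qualified identity. A hedged sketch of the intended flow, mirroring the new test below; it assumes the in-progress drop_qualified_columns support on the Rust side, and the column names are illustrative:

    import pyarrow as pa
    from datafusion import SessionContext

    ctx = SessionContext()
    left = ctx.from_arrow(pa.RecordBatch.from_pydict({"a": [1, 2], "b": [3, 4]}))
    right = ctx.from_arrow(pa.RecordBatch.from_pydict({"a": [1, 2], "c": [5, 6]}))

    # After the join both inputs contribute an "a" column, so a plain drop("a")
    # would be ambiguous; the qualified name pins down exactly one of them.
    joined = left.join(right, on="a")
    deduped = joined._drop(left.fully_qualified_name("a"))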

python/tests/test_dataframe.py

Lines changed: 15 additions & 0 deletions
@@ -272,6 +272,21 @@ def test_drop(df):
     assert result.column(1) == pa.array([4, 5, 6])


+def test__drop(df):
+    # TODO: implement deep copy for DataFrame?
+    ctx = SessionContext()
+
+    # create a RecordBatch and a new DataFrame from it
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2, 3]), pa.array([4, 5, 6]), pa.array([8, 5, 8])],
+        names=["a", "d", "e"],
+    )
+
+    other_df = ctx.from_arrow(batch)
+    df = df.join(other_df, on="a")._drop(df.fully_qualified_name("a"))
+    assert len(df.schema().names) == len(other_df.schema().names) * 2 - 1
+
+
 def test_limit(df):
     df = df.limit(1)

src/common.rs

Lines changed: 2 additions & 0 deletions
@@ -21,6 +21,7 @@ pub mod data_type;
 pub mod df_schema;
 pub mod function;
 pub mod schema;
+pub mod table_reference;

 /// Initializes the `common` module to match the pattern of `datafusion-common` https://docs.rs/datafusion-common/18.0.0/datafusion_common/index.html
 pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {

@@ -39,5 +40,6 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::<schema::PyTableType>()?;
     m.add_class::<schema::PyTableSource>()?;
     m.add_class::<schema::PyConstraints>()?;
+    m.add_class::<table_reference::PyTableReference>()?;
     Ok(())
 }

src/common/table_reference.rs

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use datafusion::common::TableReference;
19+
use pyo3::prelude::*;
20+
21+
/// PyO3 requires that objects passed between Rust and Python implement the trait `PyClass`
22+
/// Since `TableReference` exists in another package we cannot make that happen here so we wrap
23+
/// `TableReference` as `PyTableReference` This exists solely to satisfy those constraints.
24+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
25+
#[pyclass(name = "TableReference", module = "datafusion.common", subclass)]
26+
pub struct PyTableReference {
27+
pub table_reference: TableReference,
28+
}
29+
30+
impl PyTableReference {
31+
pub fn new(table_reference: TableReference) -> Self {
32+
Self { table_reference }
33+
}
34+
}
35+
36+
impl From<PyTableReference> for TableReference {
37+
fn from(py_table_reference: PyTableReference) -> TableReference {
38+
py_table_reference.table_reference
39+
}
40+
}
41+
42+
impl From<TableReference> for PyTableReference {
43+
fn from(table_reference: TableReference) -> PyTableReference {
44+
PyTableReference { table_reference }
45+
}
46+
}
47+
48+
impl std::fmt::Display for PyTableReference {
49+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50+
self.table_reference.fmt(f)
51+
}
52+
}
53+
54+
#[pymethods]
55+
impl PyTableReference {
56+
/// Create a bare (unqualified) table reference
57+
#[staticmethod]
58+
pub fn bare(table: &str) -> Self {
59+
Self::new(TableReference::bare(table))
60+
}
61+
62+
/// Create a partial (schema.table) table reference
63+
#[staticmethod]
64+
pub fn partial(schema: &str, table: &str) -> Self {
65+
Self::new(TableReference::partial(schema, table))
66+
}
67+
68+
/// Create a full (catalog.schema.table) table reference
69+
#[staticmethod]
70+
pub fn full(catalog: &str, schema: &str, table: &str) -> Self {
71+
Self::new(TableReference::full(catalog, schema, table))
72+
}
73+
74+
/// Get the table name
75+
pub fn table(&self) -> &str {
76+
self.table_reference.table()
77+
}
78+
79+
/// Get the schema name if present
80+
pub fn schema(&self) -> Option<&str> {
81+
self.table_reference.schema()
82+
}
83+
84+
/// Get the catalog name if present
85+
pub fn catalog(&self) -> Option<&str> {
86+
self.table_reference.catalog()
87+
}
88+
89+
/// Check if this is a bare table reference
90+
pub fn is_bare(&self) -> bool {
91+
matches!(self.table_reference, TableReference::Bare { .. })
92+
}
93+
94+
/// Check if this is a partial table reference
95+
pub fn is_partial(&self) -> bool {
96+
matches!(self.table_reference, TableReference::Partial { .. })
97+
}
98+
99+
/// Check if this is a full table reference
100+
pub fn is_full(&self) -> bool {
101+
matches!(self.table_reference, TableReference::Full { .. })
102+
}
103+
104+
fn __str__(&self) -> String {
105+
self.to_string()
106+
}
107+
108+
fn __repr__(&self) -> String {
109+
format!("TableReference('{self}')")
110+
}
111+
}
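
A quick Python-side check of the accessors and predicates (a sketch against the rebuilt extension; the dotted rendering in __repr__ comes from delegating Display to DataFusion's TableReference):

    from datafusion.common import TableReference

    ref = TableReference.partial("public", "users")
    assert ref.is_partial() and not ref.is_bare() and not ref.is_full()
    assert ref.schema() == "public" and ref.table() == "users"
    assert ref.catalog() is None  # no catalog on a partial reference
    print(repr(ref))              # TableReference('public.users')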

src/dataframe.rs

Lines changed: 104 additions & 3 deletions
@@ -21,14 +21,14 @@ use std::sync::Arc;

 use arrow::array::{new_null_array, RecordBatch, RecordBatchIterator, RecordBatchReader};
 use arrow::compute::can_cast_types;
+use arrow::datatypes::Field;
 use arrow::error::ArrowError;
 use arrow::ffi::FFI_ArrowSchema;
 use arrow::ffi_stream::FFI_ArrowArrayStream;
-use arrow::pyarrow::FromPyArrow;
 use datafusion::arrow::datatypes::Schema;
-use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
+use datafusion::arrow::pyarrow::{FromPyArrow, PyArrowType, ToPyArrow};
 use datafusion::arrow::util::pretty;
-use datafusion::common::UnnestOptions;
+use datafusion::common::{TableReference, UnnestOptions};
 use datafusion::config::{CsvOptions, ParquetColumnOptions, ParquetOptions, TableParquetOptions};
 use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
 use datafusion::datasource::TableProvider;

@@ -45,6 +45,7 @@ use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
 use tokio::task::JoinHandle;

 use crate::catalog::PyTable;
+use crate::common::table_reference::PyTableReference;
 use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError};
 use crate::expr::sort_expr::to_sort_expressions;
 use crate::physical_plan::PyExecutionPlan;

@@ -58,6 +59,81 @@ use crate::{
     expr::{sort_expr::PySortExpr, PyExpr},
 };

+/// A wrapper around Arc<str> that implements PyO3 traits for easier Python interop
+#[derive(Clone, Debug, PartialEq, Eq, Hash)]
+pub struct PyArcStr(Arc<str>);
+
+impl PyArcStr {
+    pub fn new(s: &str) -> Self {
+        Self(Arc::from(s))
+    }
+
+    pub fn as_str(&self) -> &str {
+        &self.0
+    }
+
+    pub fn into_arc(self) -> Arc<str> {
+        self.0
+    }
+}
+
+impl From<Arc<str>> for PyArcStr {
+    fn from(arc: Arc<str>) -> Self {
+        Self(arc)
+    }
+}
+
+impl From<PyArcStr> for Arc<str> {
+    fn from(py_arc: PyArcStr) -> Self {
+        py_arc.0
+    }
+}
+
+impl From<&str> for PyArcStr {
+    fn from(s: &str) -> Self {
+        Self::new(s)
+    }
+}
+
+impl From<String> for PyArcStr {
+    fn from(s: String) -> Self {
+        Self(Arc::from(s))
+    }
+}
+
+impl std::fmt::Display for PyArcStr {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        self.0.fmt(f)
+    }
+}
+
+impl<'py> pyo3::IntoPyObject<'py> for PyArcStr {
+    type Target = pyo3::types::PyString;
+    type Output = pyo3::Bound<'py, Self::Target>;
+    type Error = std::convert::Infallible;
+
+    fn into_pyobject(self, py: pyo3::Python<'py>) -> Result<Self::Output, Self::Error> {
+        Ok(pyo3::types::PyString::new(py, &self.0))
+    }
+}
+
+impl<'py> pyo3::IntoPyObject<'py> for &PyArcStr {
+    type Target = pyo3::types::PyString;
+    type Output = pyo3::Bound<'py, Self::Target>;
+    type Error = std::convert::Infallible;
+
+    fn into_pyobject(self, py: pyo3::Python<'py>) -> Result<Self::Output, Self::Error> {
+        Ok(pyo3::types::PyString::new(py, &self.0))
+    }
+}
+
+impl<'py> pyo3::FromPyObject<'py> for PyArcStr {
+    fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult<Self> {
+        let s: String = ob.extract()?;
+        Ok(Self::new(&s))
+    }
+}
+
 // https://github.com/apache/datafusion-python/pull/1016#discussion_r1983239116
 // - we have not decided on the table_provider approach yet
 // this is an interim implementation

@@ -428,6 +504,17 @@ impl PyDataFrame {
         PyArrowType(self.df.schema().into())
     }

+    fn fully_qualified_name(&self, col: &str) -> PyResult<(PyTableReference, PyArrowType<Field>)> {
+        let result = self.df.schema().qualified_field_with_unqualified_name(col);
+        match result {
+            Ok(parts) => Ok((parts.0.unwrap().clone().into(), parts.1.clone().into())),
+            Err(err) => Err(PyValueError::new_err(format!(
+                "Error: {:?}",
+                err.to_string()
+            ))),
+        }
+    }
+
     /// Convert this DataFrame into a Table that can be used in register_table
     /// By convention, into_... methods consume self and return the new object.
     /// Disabling the clippy lint, so we can use &self

@@ -467,6 +554,20 @@
         Ok(Self::new(df))
     }

+    #[pyo3(signature = (*_args))]
+    fn _drop(
+        &self,
+        _args: Vec<(PyTableReference, PyArrowType<Field>)>,
+    ) -> PyDataFusionResult<Self> {
+        // TODO need to finish plumbing through
+        let cols = _args
+            .iter()
+            .map(|(table, s)| (Some(table.clone().into()), s.0.clone()))
+            .collect::<Vec<(Option<TableReference>, Field)>>();
+        let df = self.df.as_ref().clone().drop_qualified_columns(&cols)?;
+        Ok(Self::new(df))
+    }
+
     fn filter(&self, predicate: PyExpr) -> PyDataFusionResult<Self> {
         let df = self.df.as_ref().clone().filter(predicate.into())?;
         Ok(Self::new(df))
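
Across the boundary, PyTableReference surfaces as datafusion.common.TableReference and PyArrowType<Field> converts to a pyarrow.Field, so the pair returned by fully_qualified_name feeds straight into _drop. A hedged sketch; note the Rust side currently unwrap()s the qualifier, so this assumes the column really is table-qualified, as it is for frames created through a SessionContext:

    import pyarrow as pa
    from datafusion import SessionContext

    ctx = SessionContext()
    df = ctx.from_arrow(pa.RecordBatch.from_pydict({"a": [1, 2], "b": [3, 4]}))

    qualifier, field = df.fully_qualified_name("a")  # (TableReference, pa.Field)
    assert isinstance(field, pa.Field) and field.name == "a"
    print(qualifier)  # whatever qualifier the context assigned to this in-memory table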
