Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rust/benchmarks/src/bin/tpch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1636,7 +1636,7 @@ mod tests {
.file_extension(".out");
let df = ctx.read_csv(&format!("{}/answers/q{}.out", path, n), options)?;
let df = df.select(
&get_answer_schema(n)
get_answer_schema(n)
.fields()
.iter()
.map(|field| {
Expand Down
2 changes: 1 addition & 1 deletion rust/datafusion/examples/simple_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ async fn main() -> Result<()> {
let expr1 = pow.call(vec![col("a"), col("b")]);

// equivalent to `'SELECT pow(a, b), pow(a, b) AS pow1 FROM t'`
let df = df.select(&[
let df = df.select(vec![
expr,
// alias so that they have different column names
expr1.alias("pow1"),
Expand Down
6 changes: 3 additions & 3 deletions rust/datafusion/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,11 @@ pub trait DataFrame: Send + Sync {
/// # fn main() -> Result<()> {
/// let mut ctx = ExecutionContext::new();
/// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new())?;
/// let df = df.select(&[col("a") * col("b"), col("c")])?;
/// let df = df.select(vec![col("a") * col("b"), col("c")])?;
/// # Ok(())
/// # }
/// ```
fn select(&self, expr: &[Expr]) -> Result<Arc<dyn DataFrame>>;
fn select(&self, expr: Vec<Expr>) -> Result<Arc<dyn DataFrame>>;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is the API change -- most of the rest of this PR is updating all the call sites. Note that DataFrame::filter takes in an owned Expr so taking an owned Vec is not that large of a departure


/// Filter a DataFrame to only include rows that match the specified filter expression.
///
Expand Down Expand Up @@ -157,7 +157,7 @@ pub trait DataFrame: Send + Sync {
/// let mut ctx = ExecutionContext::new();
/// let left = ctx.read_csv("tests/example.csv", CsvReadOptions::new())?;
/// let right = ctx.read_csv("tests/example.csv", CsvReadOptions::new())?
/// .select(&[
/// .select(vec![
/// col("a").alias("a2"),
/// col("b").alias("b2"),
/// col("c").alias("c2")])?;
Expand Down
10 changes: 5 additions & 5 deletions rust/datafusion/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -834,7 +834,7 @@ mod tests {

let table = ctx.table("test")?;
let logical_plan = LogicalPlanBuilder::from(&table.to_logical_plan())
.project(&[col("c2")])?
.project(vec![col("c2")])?
.build()?;

let optimized_plan = ctx.optimize(&logical_plan)?;
Expand Down Expand Up @@ -886,7 +886,7 @@ mod tests {
assert_eq!(schema.field_with_name("c1")?.is_nullable(), false);

let plan = LogicalPlanBuilder::scan_empty("", schema.as_ref(), None)?
.project(&[col("c1")])?
.project(vec![col("c1")])?
.build()?;

let plan = ctx.optimize(&plan)?;
Expand Down Expand Up @@ -917,7 +917,7 @@ mod tests {
)?]];

let plan = LogicalPlanBuilder::scan_memory(partitions, schema, None)?
.project(&[col("b")])?
.project(vec![col("b")])?
.build()?;
assert_fields_eq(&plan, vec!["b"]);

Expand Down Expand Up @@ -1548,7 +1548,7 @@ mod tests {

let plan = LogicalPlanBuilder::scan_empty("", schema.as_ref(), None)?
.aggregate(&[col("c1")], &[sum(col("c2"))])?
.project(&[col("c1"), col("SUM(c2)").alias("total_salary")])?
.project(vec![col("c1"), col("SUM(c2)").alias("total_salary")])?
.build()?;

let plan = ctx.optimize(&plan)?;
Expand Down Expand Up @@ -1773,7 +1773,7 @@ mod tests {
let t = ctx.table("t")?;

let plan = LogicalPlanBuilder::from(&t.to_logical_plan())
.project(&[
.project(vec![
col("a"),
col("b"),
ctx.udf("my_add")?.call(vec![col("a"), col("b")]),
Expand Down
8 changes: 4 additions & 4 deletions rust/datafusion/src/execution/dataframe_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,11 @@ impl DataFrame for DataFrameImpl {
.map(|name| self.plan.schema().field_with_unqualified_name(name))
.collect::<Result<Vec<_>>>()?;
let expr: Vec<Expr> = fields.iter().map(|f| col(f.name())).collect();
self.select(&expr)
self.select(expr)
}

/// Create a projection based on arbitrary expressions
fn select(&self, expr_list: &[Expr]) -> Result<Arc<dyn DataFrame>> {
fn select(&self, expr_list: Vec<Expr>) -> Result<Arc<dyn DataFrame>> {
let plan = LogicalPlanBuilder::from(&self.plan)
.project(expr_list)?
.build()?;
Expand Down Expand Up @@ -197,7 +197,7 @@ mod tests {
fn select_expr() -> Result<()> {
// build plan using Table API
let t = test_table()?;
let t2 = t.select(&[col("c1"), col("c2"), col("c11")])?;
let t2 = t.select(vec![col("c1"), col("c2"), col("c11")])?;
let plan = t2.to_logical_plan();

// build query using SQL
Expand Down Expand Up @@ -315,7 +315,7 @@ mod tests {

let f = df.registry();

let df = df.select(&[f.udf("my_fn")?.call(vec![col("c12")])])?;
let df = df.select(vec![f.udf("my_fn")?.call(vec![col("c12")])])?;
let plan = df.to_logical_plan();

// build query using SQL
Expand Down
25 changes: 14 additions & 11 deletions rust/datafusion/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,16 +132,19 @@ impl LogicalPlanBuilder {
/// This function errors under any of the following conditions:
/// * Two or more expressions have the same name
/// * An invalid expression is used (e.g. a `sort` expression)
pub fn project(&self, expr: &[Expr]) -> Result<Self> {
pub fn project(&self, expr: Vec<Expr>) -> Result<Self> {
let input_schema = self.plan.schema();
let mut projected_expr = vec![];
(0..expr.len()).for_each(|i| match &expr[i] {
Expr::Wildcard => {
(0..input_schema.fields().len())
.for_each(|i| projected_expr.push(col(input_schema.field(i).name())));
}
_ => projected_expr.push(expr[i].clone()),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is where the Expr::clone occurs and no longer does after the changes contemplated in this PR

});
for e in expr {
match e {
Expr::Wildcard => {
(0..input_schema.fields().len()).for_each(|i| {
projected_expr.push(col(input_schema.field(i).name()))
});
}
_ => projected_expr.push(e),
};
}

validate_unique_names("Projections", &projected_expr, input_schema)?;

Expand Down Expand Up @@ -352,7 +355,7 @@ mod tests {
Some(vec![0, 3]),
)?
.filter(col("state").eq(lit("CO")))?
.project(&[col("id")])?
.project(vec![col("id")])?
.build()?;

let expected = "Projection: #id\
Expand All @@ -372,7 +375,7 @@ mod tests {
Some(vec![3, 4]),
)?
.aggregate(&[col("state")], &[sum(col("salary")).alias("total_salary")])?
.project(&[col("state"), col("total_salary")])?
.project(vec![col("state"), col("total_salary")])?
.build()?;

let expected = "Projection: #state, #total_salary\
Expand Down Expand Up @@ -421,7 +424,7 @@ mod tests {
Some(vec![0, 3]),
)?
// two columns with the same name => error
.project(&[col("id"), col("first_name").alias("id")]);
.project(vec![col("id"), col("first_name").alias("id")]);

match plan {
Err(DataFusionError::Plan(e)) => {
Expand Down
4 changes: 2 additions & 2 deletions rust/datafusion/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,7 @@ mod tests {
.unwrap()
.filter(col("state").eq(lit("CO")))
.unwrap()
.project(&[col("id")])
.project(vec![col("id")])
.unwrap()
.build()
.unwrap()
Expand Down Expand Up @@ -1063,7 +1063,7 @@ mod tests {
.unwrap()
.filter(col("state").eq(lit("CO")))
.unwrap()
.project(&[col("id")])
.project(vec![col("id")])
.unwrap()
.build()
.unwrap()
Expand Down
14 changes: 7 additions & 7 deletions rust/datafusion/src/optimizer/constant_folding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ mod tests {
let plan = LogicalPlanBuilder::from(&table_scan)
.filter(col("b").eq(lit(true)))?
.filter(col("c").eq(lit(false)))?
.project(&[col("a")])?
.project(vec![col("a")])?
.build()?;

let expected = "\
Expand All @@ -488,7 +488,7 @@ mod tests {
.filter(col("b").not_eq(lit(true)))?
.filter(col("c").not_eq(lit(false)))?
.limit(1)?
.project(&[col("a")])?
.project(vec![col("a")])?
.build()?;

let expected = "\
Expand All @@ -507,7 +507,7 @@ mod tests {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(&table_scan)
.filter(col("b").not_eq(lit(true)).and(col("c").eq(lit(true))))?
.project(&[col("a")])?
.project(vec![col("a")])?
.build()?;

let expected = "\
Expand All @@ -524,7 +524,7 @@ mod tests {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(&table_scan)
.filter(col("b").not_eq(lit(true)).or(col("c").eq(lit(false))))?
.project(&[col("a")])?
.project(vec![col("a")])?
.build()?;

let expected = "\
Expand All @@ -541,7 +541,7 @@ mod tests {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(&table_scan)
.filter(col("b").eq(lit(false)).not())?
.project(&[col("a")])?
.project(vec![col("a")])?
.build()?;

let expected = "\
Expand All @@ -557,7 +557,7 @@ mod tests {
fn optimize_plan_support_projection() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(&table_scan)
.project(&[col("a"), col("d"), col("b").eq(lit(false))])?
.project(vec![col("a"), col("d"), col("b").eq(lit(false))])?
.build()?;

let expected = "\
Expand All @@ -572,7 +572,7 @@ mod tests {
fn optimize_plan_support_aggregate() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(&table_scan)
.project(&[col("a"), col("c"), col("b")])?
.project(vec![col("a"), col("c"), col("b")])?
.aggregate(
&[col("a"), col("c")],
&[max(col("b").eq(lit(true))), min(col("b"))],
Expand Down
Loading