Skip to content

Commit b94ed82

Browse files
committed
Add documentation for prepared parameters + make it eaiser to use
1 parent b6f87ed commit b94ed82

File tree

6 files changed

+172
-70
lines changed

6 files changed

+172
-70
lines changed

datafusion/core/src/dataframe.rs

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1218,7 +1218,42 @@ impl DataFrame {
12181218
Ok(DataFrame::new(self.session_state, project_plan))
12191219
}
12201220

1221-
/// Convert a prepare logical plan into its inner logical plan with all params replaced with their corresponding values
1221+
/// Replace all parameters in logical plan with the specified
1222+
/// values, in preparation for execution.
1223+
///
1224+
/// # Example
1225+
///
1226+
/// ```
1227+
/// use datafusion::prelude::*;
1228+
/// # use datafusion::{error::Result, assert_batches_eq};
1229+
/// # #[tokio::main]
1230+
/// # async fn main() -> Result<()> {
1231+
/// # use datafusion_common::ScalarValue;
1232+
/// let mut ctx = SessionContext::new();
1233+
/// # ctx.register_csv("example", "tests/data/example.csv", CsvReadOptions::new()).await?;
1234+
/// let results = ctx
1235+
/// .sql("SELECT a FROM example WHERE b = $1")
1236+
/// .await?
1237+
/// // replace $1 with value 2
1238+
/// .with_param_values(vec![
1239+
/// // value at index 0 --> $1
1240+
/// ScalarValue::from(2i64)
1241+
/// ])?
1242+
/// .collect()
1243+
/// .await?;
1244+
/// assert_batches_eq!(
1245+
/// &[
1246+
/// "+---+",
1247+
/// "| a |",
1248+
/// "+---+",
1249+
/// "| 1 |",
1250+
/// "+---+",
1251+
/// ],
1252+
/// &results
1253+
/// );
1254+
/// # Ok(())
1255+
/// # }
1256+
/// ```
12221257
pub fn with_param_values(self, param_values: Vec<ScalarValue>) -> Result<Self> {
12231258
let plan = self.plan.with_param_values(param_values)?;
12241259
Ok(Self::new(self.session_state, plan))

datafusion/expr/src/expr.rs

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
//! Expr module contains core type definition for `Expr`.
1919
20-
use crate::aggregate_function;
2120
use crate::built_in_function;
2221
use crate::expr_fn::binary_expr;
2322
use crate::logical_plan::Subquery;
@@ -26,8 +25,10 @@ use crate::utils::{expr_to_columns, find_out_reference_exprs};
2625
use crate::window_frame;
2726
use crate::window_function;
2827
use crate::Operator;
28+
use crate::{aggregate_function, ExprSchemable};
2929
use arrow::datatypes::DataType;
30-
use datafusion_common::internal_err;
30+
use datafusion_common::tree_node::{Transformed, TreeNode};
31+
use datafusion_common::{internal_err, DFSchema};
3132
use datafusion_common::{plan_err, Column, DataFusionError, Result, ScalarValue};
3233
use std::collections::HashSet;
3334
use std::fmt;
@@ -599,10 +600,13 @@ impl InSubquery {
599600
}
600601
}
601602

602-
/// Placeholder
603+
/// Placeholder, representing bind parameter values such as `$1`.
604+
///
605+
/// The type of these parameters is inferred using [`Expr::infer_placeholder_types`]
606+
/// or can be specified directly using `PREPARE` statements.
603607
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
604608
pub struct Placeholder {
605-
/// The identifier of the parameter (e.g, $1 or $foo)
609+
/// The identifier of the parameter, including the leading `$` (e.g, `"$1"` or `"$foo'`)
606610
pub id: String,
607611
/// The type the parameter will be filled in with
608612
pub data_type: Option<DataType>,
@@ -1030,6 +1034,49 @@ impl Expr {
10301034
pub fn contains_outer(&self) -> bool {
10311035
!find_out_reference_exprs(self).is_empty()
10321036
}
1037+
1038+
/// Find all [`Expr::Placeholder`] in anthis, and try
1039+
/// to infer their [`DataType`] from the context of their use.
1040+
pub fn infer_placeholder_types(self, schema: &DFSchema) -> Result<Expr> {
1041+
self.transform(&|mut expr| {
1042+
// Default to assuming the arguments are the same type
1043+
if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = &mut expr {
1044+
rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?;
1045+
rewrite_placeholder(right.as_mut(), left.as_ref(), schema)?;
1046+
};
1047+
if let Expr::Between(Between {
1048+
expr,
1049+
negated: _,
1050+
low,
1051+
high,
1052+
}) = &mut expr
1053+
{
1054+
rewrite_placeholder(low.as_mut(), expr.as_ref(), schema)?;
1055+
rewrite_placeholder(high.as_mut(), expr.as_ref(), schema)?;
1056+
}
1057+
Ok(Transformed::Yes(expr))
1058+
})
1059+
}
1060+
}
1061+
1062+
// modifies expr if it is a placeholder with datatype of right
1063+
fn rewrite_placeholder(expr: &mut Expr, other: &Expr, schema: &DFSchema) -> Result<()> {
1064+
if let Expr::Placeholder(Placeholder { id: _, data_type }) = expr {
1065+
if data_type.is_none() {
1066+
let other_dt = other.get_type(schema);
1067+
match other_dt {
1068+
Err(e) => {
1069+
Err(e.context(format!(
1070+
"Can not find type of {other} needed to infer type of {expr}"
1071+
)))?;
1072+
}
1073+
Ok(dt) => {
1074+
*data_type = Some(dt);
1075+
}
1076+
}
1077+
};
1078+
}
1079+
Ok(())
10331080
}
10341081

10351082
#[macro_export]

datafusion/expr/src/expr_fn.rs

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
2020
use crate::expr::{
2121
AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery,
22-
ScalarFunction, TryCast,
22+
Placeholder, ScalarFunction, TryCast,
2323
};
2424
use crate::function::PartitionEvaluatorFactory;
2525
use crate::WindowUDF;
@@ -80,6 +80,24 @@ pub fn ident(name: impl Into<String>) -> Expr {
8080
Expr::Column(Column::from_name(name))
8181
}
8282

83+
/// Create placeholder value that will be filled in (such as `$1`)
84+
///
85+
/// Note the parameter type can be inferred using [`Expr::infer_placeholder_types`]
86+
///
87+
/// # Example
88+
///
89+
/// ```rust
90+
/// # use datafusion_expr::{placeholder};
91+
/// let p = placeholder("$0"); // $0, refers to parameter 1
92+
/// assert_eq!(p.to_string(), "$0")
93+
/// ```
94+
pub fn placeholder(id: impl Into<String>) -> Expr {
95+
Expr::Placeholder(Placeholder {
96+
id: id.into(),
97+
data_type: None,
98+
})
99+
}
100+
83101
/// Return a new expression `left <op> right`
84102
pub fn binary_expr(left: Expr, op: Operator, right: Expr) -> Expr {
85103
Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right)))

datafusion/expr/src/logical_plan/plan.rs

Lines changed: 51 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -928,8 +928,40 @@ impl LogicalPlan {
928928
}
929929
}
930930
}
931-
/// Convert a prepared [`LogicalPlan`] into its inner logical plan
932-
/// with all params replaced with their corresponding values
931+
/// Replaces placeholder param values (like `$1`, `$2`) in [`LogicalPlan`]
932+
/// with the specified `param_values`.
933+
///
934+
/// [`LogicalPlan::Prepare`] are
935+
/// converted to their inner logical plan for execution.
936+
///
937+
/// # Example
938+
/// ```
939+
/// # use arrow::datatypes::{Field, Schema, DataType};
940+
/// use datafusion_common::ScalarValue;
941+
/// # use datafusion_expr::{lit, col, LogicalPlanBuilder, logical_plan::table_scan, placeholder};
942+
/// # let schema = Schema::new(vec![
943+
/// # Field::new("id", DataType::Int32, false),
944+
/// # ]);
945+
/// // Build SELECT * FROM t1 WHRERE id = $1
946+
/// let plan = table_scan(Some("t1"), &schema, None).unwrap()
947+
/// .filter(col("id").eq(placeholder("$1"))).unwrap()
948+
/// .build().unwrap();
949+
///
950+
/// assert_eq!("Filter: t1.id = $1\
951+
/// \n TableScan: t1",
952+
/// plan.display_indent().to_string()
953+
/// );
954+
///
955+
/// // Fill in the parameter $1 with a literal 3
956+
/// let plan = plan.with_param_values(vec![
957+
/// ScalarValue::from(3i32) // value at index 0 --> $1
958+
/// ]).unwrap();
959+
///
960+
/// assert_eq!("Filter: t1.id = Int32(3)\
961+
/// \n TableScan: t1",
962+
/// plan.display_indent().to_string()
963+
/// );
964+
/// ```
933965
pub fn with_param_values(
934966
self,
935967
param_values: Vec<ScalarValue>,
@@ -961,7 +993,7 @@ impl LogicalPlan {
961993
let input_plan = prepare_lp.input;
962994
input_plan.replace_params_with_values(&param_values)
963995
}
964-
_ => Ok(self),
996+
_ => self.replace_params_with_values(&param_values),
965997
}
966998
}
967999

@@ -1060,7 +1092,7 @@ impl LogicalPlan {
10601092
}
10611093

10621094
impl LogicalPlan {
1063-
/// applies collect to any subqueries in the plan
1095+
/// applies `op` to any subqueries in the plan
10641096
pub(crate) fn apply_subqueries<F>(&self, op: &mut F) -> datafusion_common::Result<()>
10651097
where
10661098
F: FnMut(&Self) -> datafusion_common::Result<VisitRecursion>,
@@ -1112,17 +1144,22 @@ impl LogicalPlan {
11121144
Ok(())
11131145
}
11141146

1115-
/// Return a logical plan with all placeholders/params (e.g $1 $2,
1116-
/// ...) replaced with corresponding values provided in the
1117-
/// params_values
1147+
/// Return a `LogicalPlan` with all placeholders (e.g $1 $2,
1148+
/// ...) replaced with corresponding values provided in
1149+
/// `params_values`
1150+
///
1151+
/// See [`Self::with_param_values`] for examples and usage
11181152
pub fn replace_params_with_values(
11191153
&self,
11201154
param_values: &[ScalarValue],
11211155
) -> Result<LogicalPlan> {
11221156
let new_exprs = self
11231157
.expressions()
11241158
.into_iter()
1125-
.map(|e| Self::replace_placeholders_with_values(e, param_values))
1159+
.map(|e| {
1160+
let e = e.infer_placeholder_types(self.schema())?;
1161+
Self::replace_placeholders_with_values(e, param_values)
1162+
})
11261163
.collect::<Result<Vec<_>>>()?;
11271164

11281165
let new_inputs_with_values = self
@@ -1219,7 +1256,9 @@ impl LogicalPlan {
12191256
// Various implementations for printing out LogicalPlans
12201257
impl LogicalPlan {
12211258
/// Return a `format`able structure that produces a single line
1222-
/// per node. For example:
1259+
/// per node.
1260+
///
1261+
/// # Example
12231262
///
12241263
/// ```text
12251264
/// Projection: employee.id
@@ -2321,7 +2360,7 @@ pub struct Unnest {
23212360
mod tests {
23222361
use super::*;
23232362
use crate::logical_plan::table_scan;
2324-
use crate::{col, exists, in_subquery, lit};
2363+
use crate::{col, exists, in_subquery, lit, placeholder};
23252364
use arrow::datatypes::{DataType, Field, Schema};
23262365
use datafusion_common::tree_node::TreeNodeVisitor;
23272366
use datafusion_common::{not_impl_err, DFSchema, TableReference};
@@ -2767,10 +2806,7 @@ digraph {
27672806

27682807
let plan = table_scan(TableReference::none(), &schema, None)
27692808
.unwrap()
2770-
.filter(col("id").eq(Expr::Placeholder(Placeholder::new(
2771-
"".into(),
2772-
Some(DataType::Int32),
2773-
))))
2809+
.filter(col("id").eq(placeholder("")))
27742810
.unwrap()
27752811
.build()
27762812
.unwrap();
@@ -2783,10 +2819,7 @@ digraph {
27832819

27842820
let plan = table_scan(TableReference::none(), &schema, None)
27852821
.unwrap()
2786-
.filter(col("id").eq(Expr::Placeholder(Placeholder::new(
2787-
"$0".into(),
2788-
Some(DataType::Int32),
2789-
))))
2822+
.filter(col("id").eq(placeholder("$0")))
27902823
.unwrap()
27912824
.build()
27922825
.unwrap();

datafusion/sql/src/expr/mod.rs

Lines changed: 2 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,12 @@ mod value;
2929

3030
use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
3131
use arrow_schema::DataType;
32-
use datafusion_common::tree_node::{Transformed, TreeNode};
3332
use datafusion_common::{
3433
internal_err, not_impl_err, plan_err, Column, DFSchema, DataFusionError, Result,
3534
ScalarValue,
3635
};
36+
use datafusion_expr::expr::InList;
3737
use datafusion_expr::expr::ScalarFunction;
38-
use datafusion_expr::expr::{InList, Placeholder};
3938
use datafusion_expr::{
4039
col, expr, lit, AggregateFunction, Between, BinaryExpr, BuiltinScalarFunction, Cast,
4140
Expr, ExprSchemable, GetFieldAccess, GetIndexedField, Like, Operator, TryCast,
@@ -122,7 +121,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
122121
let mut expr = self.sql_expr_to_logical_expr(sql, schema, planner_context)?;
123122
expr = self.rewrite_partial_qualifier(expr, schema);
124123
self.validate_schema_satisfies_exprs(schema, &[expr.clone()])?;
125-
let expr = infer_placeholder_types(expr, schema)?;
124+
let expr = expr.infer_placeholder_types(schema)?;
126125
Ok(expr)
127126
}
128127

@@ -712,49 +711,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
712711
}
713712
}
714713

715-
// modifies expr if it is a placeholder with datatype of right
716-
fn rewrite_placeholder(expr: &mut Expr, other: &Expr, schema: &DFSchema) -> Result<()> {
717-
if let Expr::Placeholder(Placeholder { id: _, data_type }) = expr {
718-
if data_type.is_none() {
719-
let other_dt = other.get_type(schema);
720-
match other_dt {
721-
Err(e) => {
722-
Err(e.context(format!(
723-
"Can not find type of {other} needed to infer type of {expr}"
724-
)))?;
725-
}
726-
Ok(dt) => {
727-
*data_type = Some(dt);
728-
}
729-
}
730-
};
731-
}
732-
Ok(())
733-
}
734-
735-
/// Find all [`Expr::Placeholder`] tokens in a logical plan, and try
736-
/// to infer their [`DataType`] from the context of their use.
737-
fn infer_placeholder_types(expr: Expr, schema: &DFSchema) -> Result<Expr> {
738-
expr.transform(&|mut expr| {
739-
// Default to assuming the arguments are the same type
740-
if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = &mut expr {
741-
rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?;
742-
rewrite_placeholder(right.as_mut(), left.as_ref(), schema)?;
743-
};
744-
if let Expr::Between(Between {
745-
expr,
746-
negated: _,
747-
low,
748-
high,
749-
}) = &mut expr
750-
{
751-
rewrite_placeholder(low.as_mut(), expr.as_ref(), schema)?;
752-
rewrite_placeholder(high.as_mut(), expr.as_ref(), schema)?;
753-
}
754-
Ok(Transformed::Yes(expr))
755-
})
756-
}
757-
758714
#[cfg(test)]
759715
mod tests {
760716
use super::*;

datafusion/sql/tests/sql_integration.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3684,6 +3684,19 @@ fn test_prepare_statement_should_infer_types() {
36843684
assert_eq!(actual_types, expected_types);
36853685
}
36863686

3687+
#[test]
3688+
fn test_non_prepare_statement_should_infer_types() {
3689+
// Non prepared statements (like SELECT) should also have their parameter types inferred
3690+
let sql = "SELECT 1 + $1";
3691+
let plan = logical_plan(sql).unwrap();
3692+
let actual_types = plan.get_parameter_types().unwrap();
3693+
let expected_types = HashMap::from([
3694+
// constant 1 is inferred to be int64
3695+
("$1".to_string(), Some(DataType::Int64)),
3696+
]);
3697+
assert_eq!(actual_types, expected_types);
3698+
}
3699+
36873700
#[test]
36883701
#[should_panic(
36893702
expected = "value: SQL(ParserError(\"Expected [NOT] NULL or TRUE|FALSE or [NOT] DISTINCT FROM after IS, found: $1\""

0 commit comments

Comments
 (0)