Skip to content

Commit

Permalink
fix: allow placeholders to be substituted when coercible
Browse files Browse the repository at this point in the history
  • Loading branch information
erratic-pattern committed Jan 24, 2024
1 parent ee7ab0b commit 287594d
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 25 deletions.
24 changes: 2 additions & 22 deletions datafusion/common/src/param_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use crate::error::{_internal_err, _plan_err};
use crate::error::_plan_err;
use crate::{DataFusionError, Result, ScalarValue};
use arrow_schema::DataType;
use std::collections::HashMap;
Expand Down Expand Up @@ -65,11 +65,7 @@ impl ParamValues {
}
}

pub fn get_placeholders_with_values(
&self,
id: &str,
data_type: Option<&DataType>,
) -> Result<ScalarValue> {
pub fn get_placeholders_with_values(&self, id: &str) -> Result<ScalarValue> {
match self {
ParamValues::List(list) => {
if id.is_empty() {
Expand All @@ -90,14 +86,6 @@ impl ParamValues {
"No value found for placeholder with id {id}"
))
})?;
// check if the data type of the value matches the data type of the placeholder
if Some(&value.data_type()) != data_type {
return _internal_err!(
"Placeholder value type mismatch: expected {:?}, got {:?}",
data_type,
value.data_type()
);
}
Ok(value.clone())
}
ParamValues::Map(map) => {
Expand All @@ -109,14 +97,6 @@ impl ParamValues {
"No value found for placeholder with name {id}"
))
})?;
// check if the data type of the value matches the data type of the placeholder
if Some(&value.data_type()) != data_type {
return _internal_err!(
"Placeholder value type mismatch: expected {:?}, got {:?}",
data_type,
value.data_type()
);
}
Ok(value.clone())
}
}
Expand Down
137 changes: 137 additions & 0 deletions datafusion/core/tests/sql/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,89 @@ async fn test_prepare_statement() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn prepared_statement_type_coercion() -> Result<()> {
let ctx = SessionContext::new();
let signed_ints: Int32Array = vec![-1, 0, 1].into();
let unsigned_ints: UInt64Array = vec![1, 2, 3].into();
let batch = RecordBatch::try_from_iter(vec![
("signed", Arc::new(signed_ints) as ArrayRef),
("unsigned", Arc::new(unsigned_ints) as ArrayRef),
])?;
ctx.register_batch("test", batch)?;
let results = ctx.sql("PREPARE my_plan(BIGINT, INT, TEXT) AS SELECT signed, unsigned FROM test WHERE $1 >= signed AND signed <= $2 AND unsigned = $3")
.await?
.with_param_values(vec![
ScalarValue::from(1_i64),
ScalarValue::from(-1_i32),
ScalarValue::from("1"),
])?
.collect()
.await?;
let expected = vec![
"+--------+----------+",
"| signed | unsigned |",
"+--------+----------+",
"| -1 | 1 |",
"+--------+----------+",
];
assert_batches_sorted_eq!(expected, &results);
Ok(())
}

#[tokio::test]
async fn prepared_statement_invalid_types() -> Result<()> {
let ctx = SessionContext::new();
let signed_ints: Int32Array = vec![-1, 0, 1].into();
let unsigned_ints: UInt64Array = vec![1, 2, 3].into();
let batch = RecordBatch::try_from_iter(vec![
("signed", Arc::new(signed_ints) as ArrayRef),
("unsigned", Arc::new(unsigned_ints) as ArrayRef),
])?;
ctx.register_batch("test", batch)?;
let results = ctx
.sql("PREPARE my_plan(INT) AS SELECT signed FROM test WHERE signed = $1")
.await?
.with_param_values(vec![ScalarValue::from("1")]);
assert_eq!(
results.unwrap_err().strip_backtrace(),
"Error during planning: Expected parameter of type Int32, got Utf8 at index 0"
);
Ok(())
}

/// Binds a single positional parameter (`$1`) to a plain `SELECT` (no `PREPARE`)
/// and checks the rows returned for `c1 = 3`.
#[tokio::test]
async fn test_list_query_parameters() -> Result<()> {
    let tmp_dir = TempDir::new()?;
    let partition_count = 4;
    // The `test` table is set up by this helper, which is defined elsewhere.
    let ctx = create_ctx_with_partition(&tmp_dir, partition_count).await?;

    let unbound = ctx.sql("SELECT * FROM test WHERE c1 = $1").await?;
    let bound = unbound.with_param_values(vec![ScalarValue::from(3i32)])?;
    let results = bound.collect().await?;

    let expected = vec![
        "+----+----+-------+",
        "| c1 | c2 | c3 |",
        "+----+----+-------+",
        "| 3 | 1 | false |",
        "| 3 | 10 | true |",
        "| 3 | 2 | true |",
        "| 3 | 3 | false |",
        "| 3 | 4 | true |",
        "| 3 | 5 | false |",
        "| 3 | 6 | true |",
        "| 3 | 7 | false |",
        "| 3 | 8 | true |",
        "| 3 | 9 | false |",
        "+----+----+-------+",
    ];
    assert_batches_sorted_eq!(expected, &results);
    Ok(())
}

#[tokio::test]
async fn test_named_query_parameters() -> Result<()> {
let tmp_dir = TempDir::new()?;
Expand Down Expand Up @@ -572,6 +655,60 @@ async fn test_named_query_parameters() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn test_parameter_type_coercion() -> Result<()> {
let ctx = SessionContext::new();
let signed_ints: Int32Array = vec![-1, 0, 1].into();
let unsigned_ints: UInt64Array = vec![1, 2, 3].into();
let batch = RecordBatch::try_from_iter(vec![
("signed", Arc::new(signed_ints) as ArrayRef),
("unsigned", Arc::new(unsigned_ints) as ArrayRef),
])?;
ctx.register_batch("test", batch)?;
let results = ctx.sql("SELECT signed, unsigned FROM test WHERE $foo >= signed AND signed <= $bar AND unsigned <= $baz AND unsigned = $str")
.await?
.with_param_values(vec![
("foo", ScalarValue::from(1_u64)),
("bar", ScalarValue::from(-1_i64)),
("baz", ScalarValue::from(2_i32)),
("str", ScalarValue::from("1")),
])?
.collect().await?;
let expected = vec![
"+--------+----------+",
"| signed | unsigned |",
"+--------+----------+",
"| -1 | 1 |",
"+--------+----------+",
];
assert_batches_sorted_eq!(expected, &results);
Ok(())
}

/// Binds an `Int32` scalar to a placeholder that is compared against a list
/// column. Binding is expected to succeed (it is propagated with `?`); the
/// failure under test is the Arrow comparison error returned by `collect`.
#[tokio::test]
async fn test_parameter_invalid_types() -> Result<()> {
    let ctx = SessionContext::new();

    // Single-row table whose only column is a list of Int32 values: [1, 2, 3].
    let row = Some(vec![Some(1), Some(2), Some(3)]);
    let list_col = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![row]);
    let batch =
        RecordBatch::try_from_iter(vec![("list", Arc::new(list_col) as ArrayRef)])?;
    ctx.register_batch("test", batch)?;

    let unbound = ctx.sql("SELECT list FROM test WHERE list = $1").await?;
    let bound = unbound.with_param_values(vec![ScalarValue::from(4_i32)])?;

    // Deliberately not propagated with `?`: the execution error is under test.
    let outcome = bound.collect().await;
    let err = outcome.unwrap_err();
    assert_eq!(
        err.strip_backtrace(),
        "Arrow error: Invalid argument error: Invalid comparison operation: List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) == List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })"
    );
    Ok(())
}

#[tokio::test]
async fn parallel_query_with_filter() -> Result<()> {
let tmp_dir = TempDir::new()?;
Expand Down
5 changes: 2 additions & 3 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1243,9 +1243,8 @@ impl LogicalPlan {
) -> Result<Expr> {
expr.transform(&|expr| {
match &expr {
Expr::Placeholder(Placeholder { id, data_type }) => {
let value = param_values
.get_placeholders_with_values(id, data_type.as_ref())?;
Expr::Placeholder(Placeholder { id, .. }) => {
let value = param_values.get_placeholders_with_values(id)?;
// Replace the placeholder with the value
Ok(Transformed::Yes(Expr::Literal(value)))
}
Expand Down

0 comments on commit 287594d

Please sign in to comment.