Skip to content

Commit 78a4bba

Browse files
shivbhatia10Shiv Bhatia
andauthored
Fix async_udf batch size behaviour (#18819)
## Which issue does this PR close?

<!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. -->

- Closes #18822.

## Rationale for this change

<!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. -->

This PR fixes the bug outlined in the issue: we shouldn't use `ColumnarValue::values_to_arrays` on the batches collected in `async_scalar_function.rs`.

## What changes are included in this PR?

<!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. -->

Added a test to cover this behaviour and fixed the issue in the async scalar function physical expression.

## Are these changes tested?

<!-- We typically require tests for all PRs in order to:
1. Prevent the code from being accidentally broken by subsequent changes
2. Serve as another way to document the expected behavior of the code

If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? -->

Yes, I added a new `user_defined_async_scalar_functions.rs` test file, similar to `user_defined_scalar_functions.rs`, which contains a test that covers this behaviour.

## Are there any user-facing changes?

<!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. -->

Yes

<!-- If there are any breaking changes to public APIs, please add the `api change` label. -->

---------

Co-authored-by: Shiv Bhatia <sbhatia@palantir.com>
1 parent d65fb86 commit 78a4bba

File tree

3 files changed

+157
-10
lines changed

3 files changed

+157
-10
lines changed

datafusion/core/tests/user_defined/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
/// Tests for user defined Async Scalar functions
19+
mod user_defined_async_scalar_functions;
20+
1821
/// Tests for user defined Scalar functions
1922
mod user_defined_scalar_functions;
2023

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::sync::Arc;
19+
20+
use arrow::array::{Int32Array, RecordBatch, StringArray};
21+
use arrow::datatypes::{DataType, Field, Schema};
22+
use async_trait::async_trait;
23+
use datafusion::prelude::*;
24+
use datafusion_common::{assert_batches_eq, Result};
25+
use datafusion_expr::async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl};
26+
use datafusion_expr::{
27+
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
28+
};
29+
30+
// This test checks the case where batch_size doesn't evenly divide
31+
// the number of rows.
32+
#[tokio::test]
33+
async fn test_async_udf_with_non_modular_batch_size() -> Result<()> {
34+
let num_rows = 3;
35+
let batch_size = 2;
36+
37+
let schema = Arc::new(Schema::new(vec![
38+
Field::new("id", DataType::Int32, false),
39+
Field::new("prompt", DataType::Utf8, false),
40+
]));
41+
42+
let batch = RecordBatch::try_new(
43+
schema.clone(),
44+
vec![
45+
Arc::new(Int32Array::from((0..num_rows).collect::<Vec<i32>>())),
46+
Arc::new(StringArray::from(
47+
(0..num_rows)
48+
.map(|i| format!("prompt{i}"))
49+
.collect::<Vec<_>>(),
50+
)),
51+
],
52+
)?;
53+
54+
let ctx = SessionContext::new();
55+
ctx.register_batch("test_table", batch)?;
56+
57+
ctx.register_udf(
58+
AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl::new(batch_size)))
59+
.into_scalar_udf(),
60+
);
61+
62+
let df = ctx
63+
.sql("SELECT id, test_async_udf(prompt) as result FROM test_table")
64+
.await?;
65+
66+
let result = df.collect().await?;
67+
68+
assert_batches_eq!(
69+
&[
70+
"+----+---------+",
71+
"| id | result |",
72+
"+----+---------+",
73+
"| 0 | prompt0 |",
74+
"| 1 | prompt1 |",
75+
"| 2 | prompt2 |",
76+
"+----+---------+"
77+
],
78+
&result
79+
);
80+
81+
Ok(())
82+
}
83+
84+
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
85+
struct TestAsyncUDFImpl {
86+
batch_size: usize,
87+
signature: Signature,
88+
}
89+
90+
impl TestAsyncUDFImpl {
91+
fn new(batch_size: usize) -> Self {
92+
Self {
93+
batch_size,
94+
signature: Signature::exact(vec![DataType::Utf8], Volatility::Volatile),
95+
}
96+
}
97+
}
98+
99+
impl ScalarUDFImpl for TestAsyncUDFImpl {
100+
fn as_any(&self) -> &dyn std::any::Any {
101+
self
102+
}
103+
104+
fn name(&self) -> &str {
105+
"test_async_udf"
106+
}
107+
108+
fn signature(&self) -> &Signature {
109+
&self.signature
110+
}
111+
112+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
113+
Ok(DataType::Utf8)
114+
}
115+
116+
fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
117+
panic!("Call invoke_async_with_args instead")
118+
}
119+
}
120+
121+
#[async_trait]
122+
impl AsyncScalarUDFImpl for TestAsyncUDFImpl {
123+
fn ideal_batch_size(&self) -> Option<usize> {
124+
Some(self.batch_size)
125+
}
126+
async fn invoke_async_with_args(
127+
&self,
128+
args: ScalarFunctionArgs,
129+
) -> Result<ColumnarValue> {
130+
let arg1 = &args.args[0];
131+
let results = call_external_service(arg1.clone()).await?;
132+
Ok(results)
133+
}
134+
}
135+
136+
/// Simulates calling an async external service
137+
async fn call_external_service(arg1: ColumnarValue) -> Result<ColumnarValue> {
138+
Ok(arg1)
139+
}

datafusion/physical-expr/src/async_scalar_function.rs

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
// under the License.
1717

1818
use crate::ScalarFunctionExpr;
19-
use arrow::array::{make_array, MutableArrayData, RecordBatch};
19+
use arrow::array::RecordBatch;
20+
use arrow::compute::concat;
2021
use arrow::datatypes::{DataType, Field, FieldRef, Schema};
2122
use datafusion_common::config::ConfigOptions;
2223
use datafusion_common::Result;
@@ -192,17 +193,21 @@ impl AsyncFuncExpr {
192193
);
193194
}
194195

195-
let datas = ColumnarValue::values_to_arrays(&result_batches)?
196+
let datas = result_batches
197+
.into_iter()
198+
.map(|cv| match cv {
199+
ColumnarValue::Array(arr) => Ok(arr),
200+
ColumnarValue::Scalar(scalar) => Ok(scalar.to_array_of_size(1)?),
201+
})
202+
.collect::<Result<Vec<_>>>()?;
203+
204+
// Get references to the arrays as dyn Array to call concat
205+
let dyn_arrays = datas
196206
.iter()
197-
.map(|b| b.to_data())
207+
.map(|arr| arr as &dyn arrow::array::Array)
198208
.collect::<Vec<_>>();
199-
let total_len = datas.iter().map(|d| d.len()).sum();
200-
let mut mutable = MutableArrayData::new(datas.iter().collect(), false, total_len);
201-
datas.iter().enumerate().for_each(|(i, data)| {
202-
mutable.extend(i, 0, data.len());
203-
});
204-
let array_ref = make_array(mutable.freeze());
205-
Ok(ColumnarValue::Array(array_ref))
209+
let result_array = concat(&dyn_arrays)?;
210+
Ok(ColumnarValue::Array(result_array))
206211
}
207212
}
208213

0 commit comments

Comments
 (0)