Skip to content

Commit 6c050b8

Browse files
authored
add random SQL function (#303)
* fix 305 * add supports_zero_argument * fix unit test
1 parent 26b78c6 commit 6c050b8

File tree

5 files changed

+57
-4
lines changed

5 files changed

+57
-4
lines changed

datafusion/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,9 @@ unicode-segmentation = { version = "^1.7.1", optional = true }
6565
regex = { version = "^1.4.3", optional = true }
6666
lazy_static = { version = "^1.4.0", optional = true }
6767
smallvec = { version = "1.6", features = ["union"] }
68+
rand = "0.8"
6869

6970
[dev-dependencies]
70-
rand = "0.8"
7171
criterion = "0.3"
7272
tempfile = "3"
7373
doc-comment = "0.3"

datafusion/src/physical_plan/functions.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,8 @@ pub enum BuiltinScalarFunction {
169169
NullIf,
170170
/// octet_length
171171
OctetLength,
172+
/// random
173+
Random,
172174
/// regexp_replace
173175
RegexpReplace,
174176
/// repeat
@@ -219,7 +221,10 @@ impl BuiltinScalarFunction {
219221
/// an allowlist of functions to take zero arguments, so that they will get special treatment
220222
/// while executing.
221223
fn supports_zero_argument(&self) -> bool {
222-
matches!(self, BuiltinScalarFunction::Now)
224+
matches!(
225+
self,
226+
BuiltinScalarFunction::Random | BuiltinScalarFunction::Now
227+
)
223228
}
224229
}
225230

@@ -275,6 +280,7 @@ impl FromStr for BuiltinScalarFunction {
275280
"md5" => BuiltinScalarFunction::MD5,
276281
"nullif" => BuiltinScalarFunction::NullIf,
277282
"octet_length" => BuiltinScalarFunction::OctetLength,
283+
"random" => BuiltinScalarFunction::Random,
278284
"regexp_replace" => BuiltinScalarFunction::RegexpReplace,
279285
"repeat" => BuiltinScalarFunction::Repeat,
280286
"replace" => BuiltinScalarFunction::Replace,
@@ -438,6 +444,7 @@ pub fn return_type(
438444
));
439445
}
440446
}),
447+
BuiltinScalarFunction::Random => Ok(DataType::Float64),
441448
BuiltinScalarFunction::RegexpReplace => Ok(match arg_types[0] {
442449
DataType::LargeUtf8 => DataType::LargeUtf8,
443450
DataType::Utf8 => DataType::Utf8,
@@ -742,6 +749,7 @@ pub fn create_physical_expr(
742749
BuiltinScalarFunction::Ln => math_expressions::ln,
743750
BuiltinScalarFunction::Log10 => math_expressions::log10,
744751
BuiltinScalarFunction::Log2 => math_expressions::log2,
752+
BuiltinScalarFunction::Random => math_expressions::random,
745753
BuiltinScalarFunction::Round => math_expressions::round,
746754
BuiltinScalarFunction::Signum => math_expressions::signum,
747755
BuiltinScalarFunction::Sin => math_expressions::sin,
@@ -1307,6 +1315,7 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature {
13071315
Signature::Exact(vec![DataType::Utf8, DataType::Utf8, DataType::Utf8]),
13081316
Signature::Exact(vec![DataType::LargeUtf8, DataType::Utf8, DataType::Utf8]),
13091317
]),
1318+
BuiltinScalarFunction::Random => Signature::Exact(vec![]),
13101319
// math expressions expect 1 argument of type f64 or f32
13111320
// priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we
13121321
// return the best approximation for it (in f64).

datafusion/src/physical_plan/math_expressions.rs

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@
1616
// under the License.
1717

1818
//! Math expressions
19-
2019
use super::{ColumnarValue, ScalarValue};
2120
use crate::error::{DataFusionError, Result};
2221
use arrow::array::{Float32Array, Float64Array};
2322
use arrow::datatypes::DataType;
23+
use rand::{thread_rng, Rng};
24+
use std::iter;
2425
use std::sync::Arc;
2526

2627
macro_rules! downcast_compute_op {
@@ -100,3 +101,36 @@ math_unary_function!("exp", exp);
100101
math_unary_function!("ln", ln);
101102
math_unary_function!("log2", log2);
102103
math_unary_function!("log10", log10);
104+
105+
/// random SQL function
106+
pub fn random(args: &[ColumnarValue]) -> Result<ColumnarValue> {
107+
let len: usize = match &args[0] {
108+
ColumnarValue::Array(array) => array.len(),
109+
_ => {
110+
return Err(DataFusionError::Internal(
111+
"Expect random function to take no param".to_string(),
112+
))
113+
}
114+
};
115+
let mut rng = thread_rng();
116+
let values = iter::repeat_with(|| rng.gen_range(0.0..1.0)).take(len);
117+
let array = Float64Array::from_iter_values(values);
118+
Ok(ColumnarValue::Array(Arc::new(array)))
119+
}
120+
121+
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float64Array, NullArray};

    /// `random` yields exactly one `Float64` value in `[0, 1)` for a one-row input.
    #[test]
    fn test_random_expression() {
        // A one-element NullArray is the only argument; `random` reads just its length.
        let input = [ColumnarValue::Array(Arc::new(NullArray::new(1)))];
        let output = random(&input).expect("fail").into_array(1);
        let values = output
            .as_any()
            .downcast_ref::<Float64Array>()
            .expect("fail");

        assert_eq!(values.len(), 1);
        let sample = values.value(0);
        assert!((0.0..1.0).contains(&sample));
    }
}

datafusion/src/physical_plan/type_coercion.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ pub fn data_types(
7575
if current_types.is_empty() {
7676
return Ok(vec![]);
7777
}
78-
7978
let valid_types = get_valid_types(signature, current_types)?;
8079

8180
if valid_types

datafusion/tests/sql.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2910,6 +2910,17 @@ async fn test_current_timestamp_expressions_non_optimized() -> Result<()> {
29102910
Ok(())
29112911
}
29122912

2913+
#[tokio::test]
2914+
async fn test_random_expression() -> Result<()> {
2915+
let mut ctx = create_ctx()?;
2916+
let sql = "SELECT random() r1";
2917+
let actual = execute(&mut ctx, sql).await;
2918+
let r1 = actual[0][0].parse::<f64>().unwrap();
2919+
assert!(0.0 <= r1);
2920+
assert!(r1 < 1.0);
2921+
Ok(())
2922+
}
2923+
29132924
#[tokio::test]
29142925
async fn test_cast_expressions_error() -> Result<()> {
29152926
// sin(utf8) should error

0 commit comments

Comments
 (0)