Skip to content

Commit 20d35fa

Browse files
author
Jiayu Liu
committed
fix 305
1 parent ce089f4 commit 20d35fa

File tree

5 files changed

+51
-3
lines changed

5 files changed

+51
-3
lines changed

datafusion/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,9 @@ unicode-segmentation = { version = "^1.7.1", optional = true }
6565
regex = { version = "^1.4.3", optional = true }
6666
lazy_static = { version = "^1.4.0", optional = true }
6767
smallvec = { version = "1.6", features = ["union"] }
68+
rand = "0.8"
6869

6970
[dev-dependencies]
70-
rand = "0.8"
7171
criterion = "0.3"
7272
tempfile = "3"
7373
doc-comment = "0.3"

datafusion/src/physical_plan/functions.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,8 @@ pub enum BuiltinScalarFunction {
169169
NullIf,
170170
/// octet_length
171171
OctetLength,
172+
/// random
173+
Random,
172174
/// regexp_replace
173175
RegexpReplace,
174176
/// repeat
@@ -267,6 +269,7 @@ impl FromStr for BuiltinScalarFunction {
267269
"md5" => BuiltinScalarFunction::MD5,
268270
"nullif" => BuiltinScalarFunction::NullIf,
269271
"octet_length" => BuiltinScalarFunction::OctetLength,
272+
"random" => BuiltinScalarFunction::Random,
270273
"regexp_replace" => BuiltinScalarFunction::RegexpReplace,
271274
"repeat" => BuiltinScalarFunction::Repeat,
272275
"replace" => BuiltinScalarFunction::Replace,
@@ -430,6 +433,7 @@ pub fn return_type(
430433
));
431434
}
432435
}),
436+
BuiltinScalarFunction::Random => Ok(DataType::Float64),
433437
BuiltinScalarFunction::RegexpReplace => Ok(match arg_types[0] {
434438
DataType::LargeUtf8 => DataType::LargeUtf8,
435439
DataType::Utf8 => DataType::Utf8,
@@ -734,6 +738,7 @@ pub fn create_physical_expr(
734738
BuiltinScalarFunction::Ln => math_expressions::ln,
735739
BuiltinScalarFunction::Log10 => math_expressions::log10,
736740
BuiltinScalarFunction::Log2 => math_expressions::log2,
741+
BuiltinScalarFunction::Random => math_expressions::random,
737742
BuiltinScalarFunction::Round => math_expressions::round,
738743
BuiltinScalarFunction::Signum => math_expressions::signum,
739744
BuiltinScalarFunction::Sin => math_expressions::sin,
@@ -1299,6 +1304,7 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature {
12991304
Signature::Exact(vec![DataType::Utf8, DataType::Utf8, DataType::Utf8]),
13001305
Signature::Exact(vec![DataType::LargeUtf8, DataType::Utf8, DataType::Utf8]),
13011306
]),
1307+
BuiltinScalarFunction::Random => Signature::Exact(vec![]),
13021308
// math expressions expect 1 argument of type f64 or f32
13031309
// priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we
13041310
// return the best approximation for it (in f64).

datafusion/src/physical_plan/math_expressions.rs

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@
1616
// under the License.
1717

1818
//! Math expressions
19-
2019
use super::{ColumnarValue, ScalarValue};
2120
use crate::error::{DataFusionError, Result};
2221
use arrow::array::{Float32Array, Float64Array};
2322
use arrow::datatypes::DataType;
23+
use rand::{thread_rng, Rng};
24+
use std::iter;
2425
use std::sync::Arc;
2526

2627
macro_rules! downcast_compute_op {
@@ -100,3 +101,36 @@ math_unary_function!("exp", exp);
100101
math_unary_function!("ln", ln);
101102
math_unary_function!("log2", log2);
102103
math_unary_function!("log10", log10);
104+
105+
/// random SQL function
106+
pub fn random(args: &[ColumnarValue]) -> Result<ColumnarValue> {
107+
let len: usize = match &args[0] {
108+
ColumnarValue::Array(array) => array.len(),
109+
_ => {
110+
return Err(DataFusionError::Internal(
111+
"Expect random function to take no param".to_string(),
112+
))
113+
}
114+
};
115+
let mut rng = thread_rng();
116+
let values = iter::repeat_with(|| rng.gen_range(0.0..1.0)).take(len);
117+
let array = Float64Array::from_iter_values(values);
118+
Ok(ColumnarValue::Array(Arc::new(array)))
119+
}
120+
121+
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float64Array, NullArray};

    /// A one-row placeholder input must yield exactly one value in `[0.0, 1.0)`.
    #[test]
    fn test_random_expression() {
        // The null array only conveys the desired row count to `random`.
        let placeholder = ColumnarValue::Array(Arc::new(NullArray::new(1)));
        let input = vec![placeholder];

        let output = random(&input).expect("fail").into_array(1);
        let values = output
            .as_any()
            .downcast_ref::<Float64Array>()
            .expect("fail");

        assert_eq!(values.len(), 1);
        let sample = values.value(0);
        assert!(0.0 <= sample && sample < 1.0);
    }
}

datafusion/src/physical_plan/type_coercion.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ pub fn data_types(
7575
if current_types.is_empty() {
7676
return Ok(vec![]);
7777
}
78-
7978
let valid_types = get_valid_types(signature, current_types)?;
8079

8180
if valid_types

datafusion/tests/sql.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2906,7 +2906,16 @@ async fn test_current_timestamp_expressions_non_optimized() -> Result<()> {
29062906
let t2 = t2_naive.timestamp();
29072907
assert!(t1 <= t2 && t2 <= t3);
29082908
assert_eq!(res2, res1);
2909+
}
29092910

2911+
#[tokio::test]
2912+
async fn test_random_expression() -> Result<()> {
2913+
let mut ctx = create_ctx()?;
2914+
let sql = format!("SELECT random() r1");
2915+
let actual = execute(&mut ctx, sql.as_str()).await;
2916+
let r1 = actual[0][0].parse::<f64>().unwrap();
2917+
assert!(0.0 <= r1);
2918+
assert!(r1 < 1.0);
29102919
Ok(())
29112920
}
29122921

0 commit comments

Comments
 (0)