Skip to content

Commit 2a4f94e

Browse files
authored
update python crate (#768)
1 parent 31d16d2 commit 2a4f94e

File tree

3 files changed

+57
-44
lines changed

3 files changed

+57
-44
lines changed

python/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ libc = "0.2"
3131
tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] }
3232
rand = "0.7"
3333
pyo3 = { version = "0.14.1", features = ["extension-module"] }
34-
datafusion = { git = "https://github.com/apache/arrow-datafusion.git", rev = "e4df37a4001423909964348289360da66acdd0a3" }
34+
datafusion = { git = "https://github.com/apache/arrow-datafusion.git", rev = "4d61196dee8526998aee7e7bb10ea88422e5f9e1" }
3535

3636
[lib]
3737
name = "datafusion"

python/src/functions.rs

Lines changed: 54 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use crate::udf;
2020
use crate::{expression, types::PyDataType};
2121
use datafusion::arrow::datatypes::DataType;
2222
use datafusion::logical_plan;
23-
use pyo3::{prelude::*, wrap_pyfunction};
23+
use pyo3::{prelude::*, types::PyTuple, wrap_pyfunction};
2424
use std::sync::Arc;
2525

2626
/// Expression representing a column on the existing plan.
@@ -76,11 +76,33 @@ fn now() -> expression::Expression {
7676
#[pyfunction]
7777
fn random() -> expression::Expression {
7878
expression::Expression {
79-
// here lit(0) is a stub for conform to arity
80-
expr: logical_plan::random(logical_plan::lit(0)),
79+
expr: logical_plan::random(),
8180
}
8281
}
8382

83+
/// Concatenates the text representations of all the arguments.
84+
/// NULL arguments are ignored.
85+
#[pyfunction(args = "*")]
86+
fn concat(args: &PyTuple) -> PyResult<expression::Expression> {
87+
let expressions = expression::from_tuple(args)?;
88+
let args = expressions.into_iter().map(|e| e.expr).collect::<Vec<_>>();
89+
Ok(expression::Expression {
90+
expr: logical_plan::concat(&args),
91+
})
92+
}
93+
94+
/// Concatenates all but the first argument, with separators.
95+
/// The first argument is used as the separator string, and should not be NULL.
96+
/// Other NULL arguments are ignored.
97+
#[pyfunction(sep, args = "*")]
98+
fn concat_ws(sep: String, args: &PyTuple) -> PyResult<expression::Expression> {
99+
let expressions = expression::from_tuple(args)?;
100+
let args = expressions.into_iter().map(|e| e.expr).collect::<Vec<_>>();
101+
Ok(expression::Expression {
102+
expr: logical_plan::concat_ws(sep, &args),
103+
})
104+
}
105+
84106
macro_rules! define_unary_function {
85107
($NAME: ident) => {
86108
#[doc = "This function is not documented yet"]
@@ -132,7 +154,6 @@ define_unary_function!(
132154
"Returns number of characters in the string."
133155
);
134156
define_unary_function!(chr, "Returns the character with the given code.");
135-
define_unary_function!(concat_ws, "Concatenates all but the first argument, with separators. The first argument is used as the separator string, and should not be NULL. Other NULL arguments are ignored.");
136157
define_unary_function!(initcap, "Converts the first letter of each word to upper case and the rest to lower case. Words are sequences of alphanumeric characters separated by non-alphanumeric characters.");
137158
define_unary_function!(left, "Returns first n characters in the string, or when n is negative, returns all but last |n| characters.");
138159
define_unary_function!(lower, "Converts the string to all lower case");
@@ -179,15 +200,6 @@ define_unary_function!(min);
179200
define_unary_function!(max);
180201
define_unary_function!(count);
181202

182-
/*
183-
#[pyfunction]
184-
fn concat(value: Vec<expression::Expression>) -> expression::Expression {
185-
expression::Expression {
186-
expr: logical_plan::concat(value.into_iter().map(|e| e.expr)),
187-
}
188-
}
189-
*/
190-
191203
pub(crate) fn create_udf(
192204
fun: PyObject,
193205
input_types: Vec<PyDataType>,
@@ -250,70 +262,69 @@ fn udaf(
250262
}
251263

252264
pub fn init(module: &PyModule) -> PyResult<()> {
253-
module.add_function(wrap_pyfunction!(col, module)?)?;
254-
module.add_function(wrap_pyfunction!(lit, module)?)?;
255-
// see https://github.com/apache/arrow-datafusion/issues/226
256-
//module.add_function(wrap_pyfunction!(concat, module)?)?;
257-
module.add_function(wrap_pyfunction!(udf, module)?)?;
265+
module.add_function(wrap_pyfunction!(abs, module)?)?;
266+
module.add_function(wrap_pyfunction!(acos, module)?)?;
258267
module.add_function(wrap_pyfunction!(array, module)?)?;
259268
module.add_function(wrap_pyfunction!(ascii, module)?)?;
269+
module.add_function(wrap_pyfunction!(asin, module)?)?;
270+
module.add_function(wrap_pyfunction!(atan, module)?)?;
271+
module.add_function(wrap_pyfunction!(avg, module)?)?;
260272
module.add_function(wrap_pyfunction!(bit_length, module)?)?;
273+
module.add_function(wrap_pyfunction!(btrim, module)?)?;
274+
module.add_function(wrap_pyfunction!(ceil, module)?)?;
261275
module.add_function(wrap_pyfunction!(character_length, module)?)?;
262276
module.add_function(wrap_pyfunction!(chr, module)?)?;
263-
module.add_function(wrap_pyfunction!(btrim, module)?)?;
277+
module.add_function(wrap_pyfunction!(col, module)?)?;
264278
module.add_function(wrap_pyfunction!(concat_ws, module)?)?;
279+
module.add_function(wrap_pyfunction!(concat, module)?)?;
280+
module.add_function(wrap_pyfunction!(cos, module)?)?;
281+
module.add_function(wrap_pyfunction!(count, module)?)?;
282+
module.add_function(wrap_pyfunction!(exp, module)?)?;
283+
module.add_function(wrap_pyfunction!(floor, module)?)?;
265284
module.add_function(wrap_pyfunction!(in_list, module)?)?;
266285
module.add_function(wrap_pyfunction!(initcap, module)?)?;
267286
module.add_function(wrap_pyfunction!(left, module)?)?;
287+
module.add_function(wrap_pyfunction!(lit, module)?)?;
288+
module.add_function(wrap_pyfunction!(ln, module)?)?;
289+
module.add_function(wrap_pyfunction!(log10, module)?)?;
290+
module.add_function(wrap_pyfunction!(log2, module)?)?;
268291
module.add_function(wrap_pyfunction!(lower, module)?)?;
269292
module.add_function(wrap_pyfunction!(lpad, module)?)?;
293+
module.add_function(wrap_pyfunction!(ltrim, module)?)?;
294+
module.add_function(wrap_pyfunction!(max, module)?)?;
270295
module.add_function(wrap_pyfunction!(md5, module)?)?;
296+
module.add_function(wrap_pyfunction!(min, module)?)?;
271297
module.add_function(wrap_pyfunction!(now, module)?)?;
272-
module.add_function(wrap_pyfunction!(ltrim, module)?)?;
273298
module.add_function(wrap_pyfunction!(octet_length, module)?)?;
274299
module.add_function(wrap_pyfunction!(random, module)?)?;
275300
module.add_function(wrap_pyfunction!(regexp_replace, module)?)?;
276301
module.add_function(wrap_pyfunction!(repeat, module)?)?;
277302
module.add_function(wrap_pyfunction!(replace, module)?)?;
278303
module.add_function(wrap_pyfunction!(reverse, module)?)?;
279304
module.add_function(wrap_pyfunction!(right, module)?)?;
305+
module.add_function(wrap_pyfunction!(round, module)?)?;
280306
module.add_function(wrap_pyfunction!(rpad, module)?)?;
281307
module.add_function(wrap_pyfunction!(rtrim, module)?)?;
282308
module.add_function(wrap_pyfunction!(sha224, module)?)?;
283309
module.add_function(wrap_pyfunction!(sha256, module)?)?;
284310
module.add_function(wrap_pyfunction!(sha384, module)?)?;
285311
module.add_function(wrap_pyfunction!(sha512, module)?)?;
312+
module.add_function(wrap_pyfunction!(signum, module)?)?;
313+
module.add_function(wrap_pyfunction!(sin, module)?)?;
286314
module.add_function(wrap_pyfunction!(split_part, module)?)?;
315+
module.add_function(wrap_pyfunction!(sqrt, module)?)?;
287316
module.add_function(wrap_pyfunction!(starts_with, module)?)?;
288317
module.add_function(wrap_pyfunction!(strpos, module)?)?;
289318
module.add_function(wrap_pyfunction!(substr, module)?)?;
319+
module.add_function(wrap_pyfunction!(sum, module)?)?;
320+
module.add_function(wrap_pyfunction!(tan, module)?)?;
290321
module.add_function(wrap_pyfunction!(to_hex, module)?)?;
291322
module.add_function(wrap_pyfunction!(translate, module)?)?;
292323
module.add_function(wrap_pyfunction!(trim, module)?)?;
293-
module.add_function(wrap_pyfunction!(upper, module)?)?;
294-
module.add_function(wrap_pyfunction!(sum, module)?)?;
295-
module.add_function(wrap_pyfunction!(count, module)?)?;
296-
module.add_function(wrap_pyfunction!(min, module)?)?;
297-
module.add_function(wrap_pyfunction!(max, module)?)?;
298-
module.add_function(wrap_pyfunction!(avg, module)?)?;
299-
module.add_function(wrap_pyfunction!(udaf, module)?)?;
300-
module.add_function(wrap_pyfunction!(sqrt, module)?)?;
301-
module.add_function(wrap_pyfunction!(sin, module)?)?;
302-
module.add_function(wrap_pyfunction!(cos, module)?)?;
303-
module.add_function(wrap_pyfunction!(tan, module)?)?;
304-
module.add_function(wrap_pyfunction!(asin, module)?)?;
305-
module.add_function(wrap_pyfunction!(acos, module)?)?;
306-
module.add_function(wrap_pyfunction!(atan, module)?)?;
307-
module.add_function(wrap_pyfunction!(floor, module)?)?;
308-
module.add_function(wrap_pyfunction!(ceil, module)?)?;
309-
module.add_function(wrap_pyfunction!(round, module)?)?;
310324
module.add_function(wrap_pyfunction!(trunc, module)?)?;
311-
module.add_function(wrap_pyfunction!(abs, module)?)?;
312-
module.add_function(wrap_pyfunction!(signum, module)?)?;
313-
module.add_function(wrap_pyfunction!(exp, module)?)?;
314-
module.add_function(wrap_pyfunction!(ln, module)?)?;
315-
module.add_function(wrap_pyfunction!(log2, module)?)?;
316-
module.add_function(wrap_pyfunction!(log10, module)?)?;
325+
module.add_function(wrap_pyfunction!(udaf, module)?)?;
326+
module.add_function(wrap_pyfunction!(udf, module)?)?;
327+
module.add_function(wrap_pyfunction!(upper, module)?)?;
317328

318329
Ok(())
319330
}

python/tests/test_math_functions.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def test_math_functions(df):
4444
f.ln(col_v + f.lit(1)),
4545
f.log2(col_v + f.lit(1)),
4646
f.log10(col_v + f.lit(1)),
47+
f.random(),
4748
)
4849
result = df.collect()
4950
assert len(result) == 1
@@ -58,3 +59,4 @@ def test_math_functions(df):
5859
np.testing.assert_array_almost_equal(result.column(7), np.log(values + 1.0))
5960
np.testing.assert_array_almost_equal(result.column(8), np.log2(values + 1.0))
6061
np.testing.assert_array_almost_equal(result.column(9), np.log10(values + 1.0))
62+
np.testing.assert_array_less(result.column(10), np.ones_like(values))

0 commit comments

Comments
 (0)