|
17 | 17 |
|
18 | 18 | //! Defines primitive computations on arrays, e.g. addition, equality, boolean logic. |
19 | 19 |
|
| 20 | +use std::cmp; |
20 | 21 | use std::ops::Add; |
| 22 | +use std::sync::Arc; |
21 | 23 |
|
22 | | -use crate::array::{Array, BooleanArray, PrimitiveArray}; |
23 | | -use crate::datatypes::ArrowNumericType; |
| 24 | +use crate::array::{ |
| 25 | + Array, ArrayRef, BinaryArray, BooleanArray, Float32Array, Float64Array, Int16Array, |
| 26 | + Int32Array, Int64Array, Int8Array, PrimitiveArray, UInt16Array, UInt32Array, |
| 27 | + UInt64Array, UInt8Array, |
| 28 | +}; |
| 29 | +use crate::datatypes::{ArrowNumericType, DataType}; |
24 | 30 | use crate::error::{ArrowError, Result}; |
25 | 31 |
|
26 | 32 | /// Returns the minimum value in the array, according to the natural order. |
@@ -204,6 +210,101 @@ where |
204 | 210 | Ok(b.finish()) |
205 | 211 | } |
206 | 212 |
|
/// Expands to an expression that builds a new `$array_type` from the elements
/// of `$array` whose slot in `$filter` is `true`.
///
/// `$array` is downcast to `$array_type` and the downcast is unwrapped, so the
/// caller must have already matched on the array's data type. Selected slots
/// that are null in the input are appended as nulls. The expansion uses `?`,
/// so it must be used inside a function returning a compatible `Result`.
macro_rules! filter_array {
    ($array:expr, $filter:expr, $array_type:ident) => {{
        let source = $array.as_any().downcast_ref::<$array_type>().unwrap();
        // Sized for the worst case where every element is selected.
        let mut out = $array_type::builder(source.len());
        for idx in 0..source.len() {
            // Skip elements the mask does not select.
            if !$filter.value(idx) {
                continue;
            }
            if source.is_null(idx) {
                out.append_null()?;
            } else {
                out.append_value(source.value(idx))?;
            }
        }
        Ok(Arc::new(out.finish()))
    }};
}
| 229 | + |
| 230 | +pub fn filter(array: &Array, filter: &BooleanArray) -> Result<ArrayRef> { |
| 231 | + match array.data_type() { |
| 232 | + DataType::UInt8 => filter_array!(array, filter, UInt8Array), |
| 233 | + DataType::UInt16 => filter_array!(array, filter, UInt16Array), |
| 234 | + DataType::UInt32 => filter_array!(array, filter, UInt32Array), |
| 235 | + DataType::UInt64 => filter_array!(array, filter, UInt64Array), |
| 236 | + DataType::Int8 => filter_array!(array, filter, Int8Array), |
| 237 | + DataType::Int16 => filter_array!(array, filter, Int16Array), |
| 238 | + DataType::Int32 => filter_array!(array, filter, Int32Array), |
| 239 | + DataType::Int64 => filter_array!(array, filter, Int64Array), |
| 240 | + DataType::Float32 => filter_array!(array, filter, Float32Array), |
| 241 | + DataType::Float64 => filter_array!(array, filter, Float64Array), |
| 242 | + DataType::Boolean => filter_array!(array, filter, BooleanArray), |
| 243 | + DataType::Utf8 => { |
| 244 | + let b = array.as_any().downcast_ref::<BinaryArray>().unwrap(); |
| 245 | + let mut values: Vec<&[u8]> = Vec::with_capacity(b.len()); |
| 246 | + for i in 0..b.len() { |
| 247 | + if filter.value(i) { |
| 248 | + values.push(b.value(i)); |
| 249 | + } |
| 250 | + } |
| 251 | + Ok(Arc::new(BinaryArray::from(values))) |
| 252 | + } |
| 253 | + other => Err(ArrowError::ComputeError(format!( |
| 254 | + "filter not supported for {:?}", |
| 255 | + other |
| 256 | + ))), |
| 257 | + } |
| 258 | +} |
| 259 | + |
/// Expands to an expression that builds a new `$array_type` holding the first
/// `$num_elements` elements of `$array`.
///
/// `$array` is downcast to `$array_type` and the downcast is unwrapped, so the
/// caller must have already matched on the array's data type. Null slots are
/// carried over as nulls. No bounds check is performed here: the caller is
/// responsible for passing a count that does not exceed the array length. The
/// expansion uses `?`, so it must be used inside a function returning a
/// compatible `Result`.
macro_rules! limit_array {
    ($array:expr, $num_elements:expr, $array_type:ident) => {{
        let typed = $array.as_any().downcast_ref::<$array_type>().unwrap();
        let mut out = $array_type::builder($num_elements);
        for slot in 0..$num_elements {
            if !typed.is_null(slot) {
                out.append_value(typed.value(slot))?;
            } else {
                out.append_null()?;
            }
        }
        Ok(Arc::new(out.finish()))
    }};
}
| 274 | + |
| 275 | +/// Returns the array, taking only the number of elements specified |
| 276 | +/// |
| 277 | +/// Returns the whole array if the number of elements specified is larger than the length of the array |
| 278 | +pub fn limit(array: &Array, num_elements: usize) -> Result<ArrayRef> { |
| 279 | + let num_elements_safe: usize = cmp::min(array.len(), num_elements); |
| 280 | + |
| 281 | + match array.data_type() { |
| 282 | + DataType::UInt8 => limit_array!(array, num_elements_safe, UInt8Array), |
| 283 | + DataType::UInt16 => limit_array!(array, num_elements_safe, UInt16Array), |
| 284 | + DataType::UInt32 => limit_array!(array, num_elements_safe, UInt32Array), |
| 285 | + DataType::UInt64 => limit_array!(array, num_elements_safe, UInt64Array), |
| 286 | + DataType::Int8 => limit_array!(array, num_elements_safe, Int8Array), |
| 287 | + DataType::Int16 => limit_array!(array, num_elements_safe, Int16Array), |
| 288 | + DataType::Int32 => limit_array!(array, num_elements_safe, Int32Array), |
| 289 | + DataType::Int64 => limit_array!(array, num_elements_safe, Int64Array), |
| 290 | + DataType::Float32 => limit_array!(array, num_elements_safe, Float32Array), |
| 291 | + DataType::Float64 => limit_array!(array, num_elements_safe, Float64Array), |
| 292 | + DataType::Boolean => limit_array!(array, num_elements_safe, BooleanArray), |
| 293 | + DataType::Utf8 => { |
| 294 | + let b = array.as_any().downcast_ref::<BinaryArray>().unwrap(); |
| 295 | + let mut values: Vec<&[u8]> = Vec::with_capacity(num_elements_safe); |
| 296 | + for i in 0..num_elements_safe { |
| 297 | + values.push(b.value(i)); |
| 298 | + } |
| 299 | + Ok(Arc::new(BinaryArray::from(values))) |
| 300 | + } |
| 301 | + other => Err(ArrowError::ComputeError(format!( |
| 302 | + "limit not supported for {:?}", |
| 303 | + other |
| 304 | + ))), |
| 305 | + } |
| 306 | +} |
| 307 | + |
207 | 308 | #[cfg(test)] |
208 | 309 | mod tests { |
209 | 310 | use super::*; |
@@ -358,4 +459,80 @@ mod tests { |
358 | 459 | assert_eq!(5, min(&a).unwrap()); |
359 | 460 | assert_eq!(9, max(&a).unwrap()); |
360 | 461 | } |
| 462 | + |
#[test]
fn test_filter_array() {
    // The mask keeps the first and fourth elements only.
    let values = Int32Array::from(vec![5, 6, 7, 8, 9]);
    let mask = BooleanArray::from(vec![true, false, false, true, false]);

    let result = filter(&values, &mask).unwrap();
    let ints = result
        .as_ref()
        .as_any()
        .downcast_ref::<Int32Array>()
        .unwrap();

    assert_eq!(2, ints.len());
    assert_eq!(5, ints.value(0));
    assert_eq!(8, ints.value(1));
}
| 473 | + |
#[test]
fn test_filter_binary_array() {
    // The mask keeps the first and third strings only.
    let words = BinaryArray::from(vec!["hello", " ", "world", "!"]);
    let mask = BooleanArray::from(vec![true, false, true, false]);

    let result = filter(&words, &mask).unwrap();
    let kept = result
        .as_ref()
        .as_any()
        .downcast_ref::<BinaryArray>()
        .unwrap();

    assert_eq!(2, kept.len());
    assert_eq!("hello", kept.get_string(0));
    assert_eq!("world", kept.get_string(1));
}
| 484 | + |
#[test]
fn test_filter_array_with_null() {
    // Only the null element is selected; it must come through as null.
    let a = Int32Array::from(vec![Some(5), None]);
    let b = BooleanArray::from(vec![false, true]);
    let c = filter(&a, &b).unwrap();
    let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
    assert_eq!(1, d.len());
    // Assert the boolean directly rather than comparing it against `true`.
    assert!(d.is_null(0));
}
| 494 | + |
#[test]
fn test_limit_array() {
    // Taking three of five elements yields the leading three values.
    let values = Int32Array::from(vec![5, 6, 7, 8, 9]);

    let result = limit(&values, 3).unwrap();
    let ints = result
        .as_ref()
        .as_any()
        .downcast_ref::<Int32Array>()
        .unwrap();

    assert_eq!(3, ints.len());
    assert_eq!(5, ints.value(0));
    assert_eq!(6, ints.value(1));
    assert_eq!(7, ints.value(2));
}
| 505 | + |
#[test]
fn test_limit_binary_array() {
    // Taking two of four strings yields the leading two values.
    let words = BinaryArray::from(vec!["hello", " ", "world", "!"]);

    let result = limit(&words, 2).unwrap();
    let kept = result
        .as_ref()
        .as_any()
        .downcast_ref::<BinaryArray>()
        .unwrap();

    assert_eq!(2, kept.len());
    assert_eq!("hello", kept.get_string(0));
    assert_eq!(" ", kept.get_string(1));
}
| 515 | + |
#[test]
fn test_limit_array_with_null() {
    // The single retained element is null and must stay null.
    let a = Int32Array::from(vec![None, Some(5)]);
    let b = limit(&a, 1).unwrap();
    let c = b.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
    assert_eq!(1, c.len());
    // Assert the boolean directly rather than comparing it against `true`.
    assert!(c.is_null(0));
}
| 524 | + |
#[test]
fn test_limit_array_with_limit_too_large() {
    // A limit beyond the array length returns every element unchanged.
    let a = Int32Array::from(vec![5, 6, 7, 8, 9]);

    let limited = limit(&a, 6).unwrap();
    let c = limited
        .as_ref()
        .as_any()
        .downcast_ref::<Int32Array>()
        .unwrap();

    assert_eq!(5, c.len());
    for i in 0..5 {
        assert_eq!(a.value(i), c.value(i));
    }
}
361 | 538 | } |
0 commit comments