From cf04021574171eff6e6df257edf12e9282606653 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Thu, 12 Aug 2021 23:50:39 +0200 Subject: [PATCH] strenght_reduce division/remainder (#275) --- .gitignore | 1 + Cargo.toml | 9 ++- benches/arithmetic_kernels.rs | 52 +++++++++++++++ src/compute/arithmetics/basic/div.rs | 95 +++++++++++++++++++++++++-- src/compute/arithmetics/basic/rem.rs | 96 ++++++++++++++++++++++++++-- src/compute/arithmetics/mod.rs | 5 +- 6 files changed, 247 insertions(+), 11 deletions(-) create mode 100644 benches/arithmetic_kernels.rs diff --git a/.gitignore b/.gitignore index c6d763755bf..39564fe2b50 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ Cargo.lock fixtures settings.json dev/ +.idea/ diff --git a/Cargo.toml b/Cargo.toml index 253e4ad7a77..9c132bb2838 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,6 +60,9 @@ ahash = { version = "0.7", optional = true } parquet2 = { version = "0.3", optional = true, default_features = false, features = ["stream"] } +# for division/remainder optimization at runtime +strength_reduce = { version = "0.2", optional = true } + [dev-dependencies] rand = "0.8" criterion = "0.3" @@ -98,7 +101,7 @@ io_parquet_compression = [ io_json_integration = ["io_json", "hex"] io_print = ["comfy-table"] # the compute kernels. Disabling this significantly reduces compile time. -compute = [] +compute = ["strength_reduce"] # base64 + io_ipc because arrow schemas are stored as base64-encoded ipc format. io_parquet = ["parquet2", "io_ipc", "base64", "futures"] benchmarks = ["rand"] @@ -167,3 +170,7 @@ harness = false [[bench]] name = "write_ipc" harness = false + +[[bench]] +name = "arithmetic_kernels" +harness = false diff --git a/benches/arithmetic_kernels.rs b/benches/arithmetic_kernels.rs new file mode 100644 index 00000000000..f2ef5a9450b --- /dev/null +++ b/benches/arithmetic_kernels.rs @@ -0,0 +1,52 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[macro_use] +extern crate criterion; +use criterion::Criterion; + +use arrow2::array::*; +use arrow2::util::bench_util::*; +use arrow2::{ + compute::arithmetics::basic::div::div_scalar, datatypes::DataType, types::NativeType, +}; +use num::NumCast; +use std::ops::Div; + +fn bench_div_scalar(lhs: &PrimitiveArray, rhs: &T) +where + T: NativeType + Div + NumCast, +{ + criterion::black_box(div_scalar(lhs, rhs)); +} + +fn add_benchmark(c: &mut Criterion) { + let size = 65536; + let arr = create_primitive_array_with_seed::(size, DataType::UInt64, 0.0, 43); + + c.bench_function("divide_scalar 4", |b| { + // 4 is a very fast optimizable divisor + b.iter(|| bench_div_scalar(&arr, &4)) + }); + c.bench_function("divide_scalar prime", |b| { + // large prime number that is probably harder to simplify + b.iter(|| bench_div_scalar(&arr, &524287)) + }); +} + +criterion_group!(benches, add_benchmark); +criterion_main!(benches); diff --git a/src/compute/arithmetics/basic/div.rs b/src/compute/arithmetics/basic/div.rs index 8556c995a9a..8e83bb1953e 100644 --- a/src/compute/arithmetics/basic/div.rs +++ b/src/compute/arithmetics/basic/div.rs @@ -1,8 +1,9 @@ //! Definition of basic div operations with primitive arrays use std::ops::Div; -use num::{CheckedDiv, Zero}; +use num::{CheckedDiv, NumCast, Zero}; +use crate::datatypes::DataType; use crate::{ array::{Array, PrimitiveArray}, compute::{ @@ -12,6 +13,9 @@ use crate::{ error::{ArrowError, Result}, types::NativeType, }; +use strength_reduce::{ + StrengthReducedU16, StrengthReducedU32, StrengthReducedU64, StrengthReducedU8, +}; /// Divides two primitive arrays with the same type. /// Panics if the divisor is zero of one pair of values overflows. @@ -109,10 +113,72 @@ where /// ``` pub fn div_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray where - T: NativeType + Div, + T: NativeType + Div + NumCast, { let rhs = *rhs; - unary(lhs, |a| a / rhs, lhs.data_type().clone()) + match T::DATA_TYPE { + DataType::UInt64 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.to_u64().unwrap(); + + let reduced_div = StrengthReducedU64::new(rhs); + // Safety: we just proved that `lhs` is `PrimitiveArray` which means that + // T = u64 + unsafe { + std::mem::transmute::, PrimitiveArray>(unary( + lhs, + |a| a / reduced_div, + lhs.data_type().clone(), + )) + } + } + DataType::UInt32 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.to_u32().unwrap(); + + let reduced_div = StrengthReducedU32::new(rhs); + // Safety: we just proved that `lhs` is `PrimitiveArray` which means that + // T = u32 + unsafe { + std::mem::transmute::, PrimitiveArray>(unary( + lhs, + |a| a / reduced_div, + lhs.data_type().clone(), + )) + } + } + DataType::UInt16 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.to_u16().unwrap(); + + let reduced_div = StrengthReducedU16::new(rhs); + // Safety: we just proved that `lhs` is `PrimitiveArray` which means that + // T = u16 + unsafe { + std::mem::transmute::, PrimitiveArray>(unary( + lhs, + |a| a / reduced_div, + lhs.data_type().clone(), + )) + } + } + DataType::UInt8 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.to_u8().unwrap(); + + let reduced_div = StrengthReducedU8::new(rhs); + // Safety: we just proved that `lhs` is `PrimitiveArray` which means that + // T = u8 + unsafe { + std::mem::transmute::, PrimitiveArray>(unary( + lhs, + |a| a / reduced_div, + lhs.data_type().clone(), + )) + } + } + _ => unary(lhs, |a| a / rhs, lhs.data_type().clone()), + } } /// Checked division of a primitive array of type T by a scalar T. If the @@ -141,7 +207,7 @@ where // Implementation of ArrayDiv trait for PrimitiveArrays with a scalar impl ArrayDiv for PrimitiveArray where - T: NativeType + Div + NotI128, + T: NativeType + Div + NotI128 + NumCast, { type Output = Self; @@ -226,6 +292,27 @@ mod tests { // Trait testing let result = a.div(&1i32).unwrap(); assert_eq!(result, expected); + + // check the strength reduced branches + let a = UInt64Array::from(&[None, Some(6), None, Some(6)]); + let result = div_scalar(&a, &1u64); + let expected = UInt64Array::from(&[None, Some(6), None, Some(6)]); + assert_eq!(result, expected); + + let a = UInt32Array::from(&[None, Some(6), None, Some(6)]); + let result = div_scalar(&a, &1u32); + let expected = UInt32Array::from(&[None, Some(6), None, Some(6)]); + assert_eq!(result, expected); + + let a = UInt16Array::from(&[None, Some(6), None, Some(6)]); + let result = div_scalar(&a, &1u16); + let expected = UInt16Array::from(&[None, Some(6), None, Some(6)]); + assert_eq!(result, expected); + + let a = UInt8Array::from(&[None, Some(6), None, Some(6)]); + let result = div_scalar(&a, &1u8); + let expected = UInt8Array::from(&[None, Some(6), None, Some(6)]); + assert_eq!(result, expected); } #[test] diff --git a/src/compute/arithmetics/basic/rem.rs b/src/compute/arithmetics/basic/rem.rs index 8faa1426b75..68f2088a75c 100644 --- a/src/compute/arithmetics/basic/rem.rs +++ b/src/compute/arithmetics/basic/rem.rs @@ -1,7 +1,8 @@ use std::ops::Rem; -use num::{traits::CheckedRem, Zero}; +use num::{traits::CheckedRem, NumCast, Zero}; +use crate::datatypes::DataType; use crate::{ array::{Array, PrimitiveArray}, compute::{ @@ -11,6 +12,9 @@ use crate::{ error::{ArrowError, Result}, types::NativeType, }; +use strength_reduce::{ + StrengthReducedU16, StrengthReducedU32, StrengthReducedU64, StrengthReducedU8, +}; /// Remainder of two primitive arrays with the same type. /// Panics if the divisor is zero of one pair of values overflows. @@ -106,10 +110,73 @@ where /// ``` pub fn rem_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray where - T: NativeType + Rem, + T: NativeType + Rem + NumCast, { let rhs = *rhs; - unary(lhs, |a| a % rhs, lhs.data_type().clone()) + + match T::DATA_TYPE { + DataType::UInt64 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.to_u64().unwrap(); + + let reduced_rem = StrengthReducedU64::new(rhs); + // Safety: we just proved that `lhs` is `PrimitiveArray` which means that + // T = u64 + unsafe { + std::mem::transmute::, PrimitiveArray>(unary( + lhs, + |a| a % reduced_rem, + lhs.data_type().clone(), + )) + } + } + DataType::UInt32 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.to_u32().unwrap(); + + let reduced_rem = StrengthReducedU32::new(rhs); + // Safety: we just proved that `lhs` is `PrimitiveArray` which means that + // T = u32 + unsafe { + std::mem::transmute::, PrimitiveArray>(unary( + lhs, + |a| a % reduced_rem, + lhs.data_type().clone(), + )) + } + } + DataType::UInt16 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.to_u16().unwrap(); + + let reduced_rem = StrengthReducedU16::new(rhs); + // Safety: we just proved that `lhs` is `PrimitiveArray` which means that + // T = u16 + unsafe { + std::mem::transmute::, PrimitiveArray>(unary( + lhs, + |a| a % reduced_rem, + lhs.data_type().clone(), + )) + } + } + DataType::UInt8 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.to_u8().unwrap(); + + let reduced_rem = StrengthReducedU8::new(rhs); + // Safety: we just proved that `lhs` is `PrimitiveArray` which means that + // T = u8 + unsafe { + std::mem::transmute::, PrimitiveArray>(unary( + lhs, + |a| a % reduced_rem, + lhs.data_type().clone(), + )) + } + } + _ => unary(lhs, |a| a % rhs, lhs.data_type().clone()), + } } /// Checked remainder of a primitive array of type T by a scalar T. If the @@ -137,7 +204,7 @@ where impl ArrayRem for PrimitiveArray where - T: NativeType + Rem + NotI128, + T: NativeType + Rem + NotI128 + NumCast, { type Output = Self; @@ -221,6 +288,27 @@ mod tests { // Trait testing let result = a.rem(&2i32).unwrap(); assert_eq!(result, expected); + + // check the strength reduced branches + let a = UInt64Array::from(&[None, Some(6), None, Some(5)]); + let result = rem_scalar(&a, &2u64); + let expected = UInt64Array::from(&[None, Some(0), None, Some(1)]); + assert_eq!(result, expected); + + let a = UInt32Array::from(&[None, Some(6), None, Some(5)]); + let result = rem_scalar(&a, &2u32); + let expected = UInt32Array::from(&[None, Some(0), None, Some(1)]); + assert_eq!(result, expected); + + let a = UInt16Array::from(&[None, Some(6), None, Some(5)]); + let result = rem_scalar(&a, &2u16); + let expected = UInt16Array::from(&[None, Some(0), None, Some(1)]); + assert_eq!(result, expected); + + let a = UInt8Array::from(&[None, Some(6), None, Some(5)]); + let result = rem_scalar(&a, &2u8); + let expected = UInt8Array::from(&[None, Some(0), None, Some(1)]); + assert_eq!(result, expected); } #[test] diff --git a/src/compute/arithmetics/mod.rs b/src/compute/arithmetics/mod.rs index 496f02d72f3..ad7eb846c80 100644 --- a/src/compute/arithmetics/mod.rs +++ b/src/compute/arithmetics/mod.rs @@ -61,7 +61,7 @@ pub mod time; use std::ops::{Add, Div, Mul, Neg, Rem, Sub}; -use num::Zero; +use num::{NumCast, Zero}; use crate::datatypes::{DataType, TimeUnit}; use crate::error::{ArrowError, Result}; @@ -265,7 +265,8 @@ where + Add + Sub + Mul - + Rem, + + Rem + + NumCast, { match op { Operator::Add => Ok(basic::add::add_scalar(lhs, rhs)),