Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
strength_reduce division/remainder (#275)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 12, 2021
1 parent c7e1a4d commit cf04021
Show file tree
Hide file tree
Showing 6 changed files with 247 additions and 11 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ Cargo.lock
fixtures
settings.json
dev/
.idea/
9 changes: 8 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ ahash = { version = "0.7", optional = true }

parquet2 = { version = "0.3", optional = true, default_features = false, features = ["stream"] }

# for division/remainder optimization at runtime
strength_reduce = { version = "0.2", optional = true }

[dev-dependencies]
rand = "0.8"
criterion = "0.3"
Expand Down Expand Up @@ -98,7 +101,7 @@ io_parquet_compression = [
io_json_integration = ["io_json", "hex"]
io_print = ["comfy-table"]
# the compute kernels. Disabling this significantly reduces compile time.
compute = []
compute = ["strength_reduce"]
# base64 + io_ipc because arrow schemas are stored as base64-encoded ipc format.
io_parquet = ["parquet2", "io_ipc", "base64", "futures"]
benchmarks = ["rand"]
Expand Down Expand Up @@ -167,3 +170,7 @@ harness = false
[[bench]]
name = "write_ipc"
harness = false

[[bench]]
name = "arithmetic_kernels"
harness = false
52 changes: 52 additions & 0 deletions benches/arithmetic_kernels.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#[macro_use]
extern crate criterion;
use criterion::Criterion;

use arrow2::array::*;
use arrow2::util::bench_util::*;
use arrow2::{
compute::arithmetics::basic::div::div_scalar, datatypes::DataType, types::NativeType,
};
use num::NumCast;
use std::ops::Div;

/// Performs one `div_scalar` call on `lhs` with divisor `rhs` and hands the
/// resulting array to `criterion::black_box`, so the compiler cannot discard
/// the work being measured.
fn bench_div_scalar<T>(lhs: &PrimitiveArray<T>, rhs: &T)
where
    T: NativeType + Div<Output = T> + NumCast,
{
    let quotient = div_scalar(lhs, rhs);
    criterion::black_box(quotient);
}

/// Registers the scalar-division benchmarks on the given Criterion instance.
///
/// Both benchmarks divide the same 65536-element `u64` array; only the divisor
/// differs.
fn add_benchmark(c: &mut Criterion) {
    let len = 65536;
    // NOTE(review): `0.0` and `43` are presumably the null density and the RNG
    // seed of the helper - confirm against `bench_util`.
    let input = create_primitive_array_with_seed::<u64>(len, DataType::UInt64, 0.0, 43);

    // (benchmark id, divisor), registered in this order:
    //   * 4      - a divisor that is very cheap to optimize,
    //   * 524287 - a large prime that is probably harder to simplify.
    let cases: [(&str, u64); 2] = [("divide_scalar 4", 4), ("divide_scalar prime", 524287)];

    for &(id, divisor) in cases.iter() {
        c.bench_function(id, |b| b.iter(|| bench_div_scalar(&input, &divisor)));
    }
}

criterion_group!(benches, add_benchmark);
criterion_main!(benches);
95 changes: 91 additions & 4 deletions src/compute/arithmetics/basic/div.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
//! Definition of basic div operations with primitive arrays
use std::ops::Div;

use num::{CheckedDiv, Zero};
use num::{CheckedDiv, NumCast, Zero};

use crate::datatypes::DataType;
use crate::{
array::{Array, PrimitiveArray},
compute::{
Expand All @@ -12,6 +13,9 @@ use crate::{
error::{ArrowError, Result},
types::NativeType,
};
use strength_reduce::{
StrengthReducedU16, StrengthReducedU32, StrengthReducedU64, StrengthReducedU8,
};

/// Divides two primitive arrays with the same type.
/// Panics if the divisor is zero or one pair of values overflows.
Expand Down Expand Up @@ -109,10 +113,72 @@ where
/// ```
pub fn div_scalar<T>(lhs: &PrimitiveArray<T>, rhs: &T) -> PrimitiveArray<T>
where
// NOTE(review): the bound is listed twice because this view shows both the removed and
// the added side of the change; the second form adds `NumCast`, which the
// `rhs.to_uN()` conversions below presumably rely on - confirm.
T: NativeType + Div<Output = T>,
T: NativeType + Div<Output = T> + NumCast,
{
// Divides `lhs` by the scalar `rhs` element-wise through `unary`. For the unsigned
// integer types the divisor is wrapped once in a `StrengthReducedU*` value and that
// value is used as the right-hand side of every division; all other types use the
// plain `/` operator.
let rhs = *rhs;
// NOTE(review): the next line is the pre-change body (removed side of the diff); the
// `match` that follows is what replaced it.
unary(lhs, |a| a / rhs, lhs.data_type().clone())
// Dispatch on the logical data type associated with `T`.
match T::DATA_TYPE {
DataType::UInt64 => {
// Re-view the generic array as its concrete `u64` type; `unwrap` panics if the
// downcast fails.
let lhs = lhs.as_any().downcast_ref::<PrimitiveArray<u64>>().unwrap();
// Convert the generic scalar to `u64`; `unwrap` panics if the conversion yields `None`.
let rhs = rhs.to_u64().unwrap();

// Built once, outside the per-element closure.
// NOTE(review): the `strength_reduce` docs state that `new` panics on a zero
// divisor - confirm for the pinned version.
let reduced_div = StrengthReducedU64::new(rhs);
// SAFETY: the successful downcast above shows that `lhs` is a
// `PrimitiveArray<u64>`, i.e. `T = u64`, so source and target of the transmute
// are the same type.
unsafe {
std::mem::transmute::<PrimitiveArray<u64>, PrimitiveArray<T>>(unary(
lhs,
|a| a / reduced_div,
lhs.data_type().clone(),
))
}
}
DataType::UInt32 => {
// Same steps as the `UInt64` arm, with `u32` as the concrete type.
let lhs = lhs.as_any().downcast_ref::<PrimitiveArray<u32>>().unwrap();
let rhs = rhs.to_u32().unwrap();

let reduced_div = StrengthReducedU32::new(rhs);
// SAFETY: the successful downcast above shows that `lhs` is a
// `PrimitiveArray<u32>`, i.e. `T = u32`.
unsafe {
std::mem::transmute::<PrimitiveArray<u32>, PrimitiveArray<T>>(unary(
lhs,
|a| a / reduced_div,
lhs.data_type().clone(),
))
}
}
DataType::UInt16 => {
// Same steps as the `UInt64` arm, with `u16` as the concrete type.
let lhs = lhs.as_any().downcast_ref::<PrimitiveArray<u16>>().unwrap();
let rhs = rhs.to_u16().unwrap();

let reduced_div = StrengthReducedU16::new(rhs);
// SAFETY: the successful downcast above shows that `lhs` is a
// `PrimitiveArray<u16>`, i.e. `T = u16`.
unsafe {
std::mem::transmute::<PrimitiveArray<u16>, PrimitiveArray<T>>(unary(
lhs,
|a| a / reduced_div,
lhs.data_type().clone(),
))
}
}
DataType::UInt8 => {
// Same steps as the `UInt64` arm, with `u8` as the concrete type.
let lhs = lhs.as_any().downcast_ref::<PrimitiveArray<u8>>().unwrap();
let rhs = rhs.to_u8().unwrap();

let reduced_div = StrengthReducedU8::new(rhs);
// SAFETY: the successful downcast above shows that `lhs` is a
// `PrimitiveArray<u8>`, i.e. `T = u8`.
unsafe {
std::mem::transmute::<PrimitiveArray<u8>, PrimitiveArray<T>>(unary(
lhs,
|a| a / reduced_div,
lhs.data_type().clone(),
))
}
}
// Every other data type: plain element-wise division with the `/` operator of `T`.
_ => unary(lhs, |a| a / rhs, lhs.data_type().clone()),
}
}

/// Checked division of a primitive array of type T by a scalar T. If the
Expand Down Expand Up @@ -141,7 +207,7 @@ where
// Implementation of ArrayDiv trait for PrimitiveArrays with a scalar
impl<T> ArrayDiv<T> for PrimitiveArray<T>
where
T: NativeType + Div<Output = T> + NotI128,
T: NativeType + Div<Output = T> + NotI128 + NumCast,
{
type Output = Self;

Expand Down Expand Up @@ -226,6 +292,27 @@ mod tests {
// Trait testing
let result = a.div(&1i32).unwrap();
assert_eq!(result, expected);

// check the strength reduced branches
let a = UInt64Array::from(&[None, Some(6), None, Some(6)]);
let result = div_scalar(&a, &1u64);
let expected = UInt64Array::from(&[None, Some(6), None, Some(6)]);
assert_eq!(result, expected);

let a = UInt32Array::from(&[None, Some(6), None, Some(6)]);
let result = div_scalar(&a, &1u32);
let expected = UInt32Array::from(&[None, Some(6), None, Some(6)]);
assert_eq!(result, expected);

let a = UInt16Array::from(&[None, Some(6), None, Some(6)]);
let result = div_scalar(&a, &1u16);
let expected = UInt16Array::from(&[None, Some(6), None, Some(6)]);
assert_eq!(result, expected);

let a = UInt8Array::from(&[None, Some(6), None, Some(6)]);
let result = div_scalar(&a, &1u8);
let expected = UInt8Array::from(&[None, Some(6), None, Some(6)]);
assert_eq!(result, expected);
}

#[test]
Expand Down
96 changes: 92 additions & 4 deletions src/compute/arithmetics/basic/rem.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use std::ops::Rem;

use num::{traits::CheckedRem, Zero};
use num::{traits::CheckedRem, NumCast, Zero};

use crate::datatypes::DataType;
use crate::{
array::{Array, PrimitiveArray},
compute::{
Expand All @@ -11,6 +12,9 @@ use crate::{
error::{ArrowError, Result},
types::NativeType,
};
use strength_reduce::{
StrengthReducedU16, StrengthReducedU32, StrengthReducedU64, StrengthReducedU8,
};

/// Remainder of two primitive arrays with the same type.
/// Panics if the divisor is zero or one pair of values overflows.
Expand Down Expand Up @@ -106,10 +110,73 @@ where
/// ```
pub fn rem_scalar<T>(lhs: &PrimitiveArray<T>, rhs: &T) -> PrimitiveArray<T>
where
// NOTE(review): the bound is listed twice because this view shows both the removed and
// the added side of the change; the second form adds `NumCast`, which the
// `rhs.to_uN()` conversions below presumably rely on - confirm.
T: NativeType + Rem<Output = T>,
T: NativeType + Rem<Output = T> + NumCast,
{
// Computes `lhs % rhs` element-wise through `unary`. For the unsigned integer types
// the divisor is wrapped once in a `StrengthReducedU*` value and that value is used
// as the right-hand side of every remainder; all other types use the plain `%`
// operator.
let rhs = *rhs;
// NOTE(review): the next line is the pre-change body (removed side of the diff); the
// `match` that follows is what replaced it.
unary(lhs, |a| a % rhs, lhs.data_type().clone())

// Dispatch on the logical data type associated with `T`.
match T::DATA_TYPE {
DataType::UInt64 => {
// Re-view the generic array as its concrete `u64` type; `unwrap` panics if the
// downcast fails.
let lhs = lhs.as_any().downcast_ref::<PrimitiveArray<u64>>().unwrap();
// Convert the generic scalar to `u64`; `unwrap` panics if the conversion yields `None`.
let rhs = rhs.to_u64().unwrap();

// Built once, outside the per-element closure.
// NOTE(review): the `strength_reduce` docs state that `new` panics on a zero
// divisor - confirm for the pinned version.
let reduced_rem = StrengthReducedU64::new(rhs);
// SAFETY: the successful downcast above shows that `lhs` is a
// `PrimitiveArray<u64>`, i.e. `T = u64`, so source and target of the transmute
// are the same type.
unsafe {
std::mem::transmute::<PrimitiveArray<u64>, PrimitiveArray<T>>(unary(
lhs,
|a| a % reduced_rem,
lhs.data_type().clone(),
))
}
}
DataType::UInt32 => {
// Same steps as the `UInt64` arm, with `u32` as the concrete type.
let lhs = lhs.as_any().downcast_ref::<PrimitiveArray<u32>>().unwrap();
let rhs = rhs.to_u32().unwrap();

let reduced_rem = StrengthReducedU32::new(rhs);
// SAFETY: the successful downcast above shows that `lhs` is a
// `PrimitiveArray<u32>`, i.e. `T = u32`.
unsafe {
std::mem::transmute::<PrimitiveArray<u32>, PrimitiveArray<T>>(unary(
lhs,
|a| a % reduced_rem,
lhs.data_type().clone(),
))
}
}
DataType::UInt16 => {
// Same steps as the `UInt64` arm, with `u16` as the concrete type.
let lhs = lhs.as_any().downcast_ref::<PrimitiveArray<u16>>().unwrap();
let rhs = rhs.to_u16().unwrap();

let reduced_rem = StrengthReducedU16::new(rhs);
// SAFETY: the successful downcast above shows that `lhs` is a
// `PrimitiveArray<u16>`, i.e. `T = u16`.
unsafe {
std::mem::transmute::<PrimitiveArray<u16>, PrimitiveArray<T>>(unary(
lhs,
|a| a % reduced_rem,
lhs.data_type().clone(),
))
}
}
DataType::UInt8 => {
// Same steps as the `UInt64` arm, with `u8` as the concrete type.
let lhs = lhs.as_any().downcast_ref::<PrimitiveArray<u8>>().unwrap();
let rhs = rhs.to_u8().unwrap();

let reduced_rem = StrengthReducedU8::new(rhs);
// SAFETY: the successful downcast above shows that `lhs` is a
// `PrimitiveArray<u8>`, i.e. `T = u8`.
unsafe {
std::mem::transmute::<PrimitiveArray<u8>, PrimitiveArray<T>>(unary(
lhs,
|a| a % reduced_rem,
lhs.data_type().clone(),
))
}
}
// Every other data type: plain element-wise remainder with the `%` operator of `T`.
_ => unary(lhs, |a| a % rhs, lhs.data_type().clone()),
}
}

/// Checked remainder of a primitive array of type T by a scalar T. If the
Expand Down Expand Up @@ -137,7 +204,7 @@ where

impl<T> ArrayRem<T> for PrimitiveArray<T>
where
T: NativeType + Rem<Output = T> + NotI128,
T: NativeType + Rem<Output = T> + NotI128 + NumCast,
{
type Output = Self;

Expand Down Expand Up @@ -221,6 +288,27 @@ mod tests {
// Trait testing
let result = a.rem(&2i32).unwrap();
assert_eq!(result, expected);

// check the strength reduced branches
let a = UInt64Array::from(&[None, Some(6), None, Some(5)]);
let result = rem_scalar(&a, &2u64);
let expected = UInt64Array::from(&[None, Some(0), None, Some(1)]);
assert_eq!(result, expected);

let a = UInt32Array::from(&[None, Some(6), None, Some(5)]);
let result = rem_scalar(&a, &2u32);
let expected = UInt32Array::from(&[None, Some(0), None, Some(1)]);
assert_eq!(result, expected);

let a = UInt16Array::from(&[None, Some(6), None, Some(5)]);
let result = rem_scalar(&a, &2u16);
let expected = UInt16Array::from(&[None, Some(0), None, Some(1)]);
assert_eq!(result, expected);

let a = UInt8Array::from(&[None, Some(6), None, Some(5)]);
let result = rem_scalar(&a, &2u8);
let expected = UInt8Array::from(&[None, Some(0), None, Some(1)]);
assert_eq!(result, expected);
}

#[test]
Expand Down
5 changes: 3 additions & 2 deletions src/compute/arithmetics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ pub mod time;

use std::ops::{Add, Div, Mul, Neg, Rem, Sub};

use num::Zero;
use num::{NumCast, Zero};

use crate::datatypes::{DataType, TimeUnit};
use crate::error::{ArrowError, Result};
Expand Down Expand Up @@ -265,7 +265,8 @@ where
+ Add<Output = T>
+ Sub<Output = T>
+ Mul<Output = T>
+ Rem<Output = T>,
+ Rem<Output = T>
+ NumCast,
{
match op {
Operator::Add => Ok(basic::add::add_scalar(lhs, rhs)),
Expand Down

0 comments on commit cf04021

Please sign in to comment.