feat: Improve performance of thetasketch distinct (#1102)
## Rationale
Our implementation of the `thetasketch_distinct` UDF is rough, leading to
poor performance: the aggregate stage costs `17s` in our production...

In this PR, I refactor the `thetasketch_distinct` implementation for better
performance; the aggregate stage now costs just `6s` for the same sample SQL.

## Detailed Changes
+ Update the `hyperloglog` crate for better performance.
+ Use `DatumView` instead of `DfScalarValue` in `HllDistinct` to eliminate
unnecessary clones (see the sketch below).
+ Fix the related tests: `thetasketch_distinct` is an approximate count, so the
expected values in the integration results shift slightly under the new
implementation.
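
For intuition, here is a minimal, self-contained sketch of the `DatumView`
change (toy enums, not the repo's actual `DatumView`/`DfScalarValue` types):
reading rows as owned scalars clones every string payload before hashing,
while borrowed views hash the column data in place.

```rust
use std::collections::HashSet;

// Stand-in for an owned scalar like `DfScalarValue`: reading a string row
// clones the payload.
#[derive(Clone, Hash, PartialEq, Eq)]
enum OwnedScalar {
    String(String),
}

// Stand-in for a borrowed view like `DatumView<'a>`: a row is just a
// reference into the column, no copy.
#[derive(Hash, PartialEq, Eq)]
enum View<'a> {
    String(&'a str),
}

fn main() {
    let column = vec!["a".to_string(), "b".to_string(), "a".to_string()];

    // Old style: one `String` clone per row.
    let owned: HashSet<OwnedScalar> = column
        .iter()
        .map(|s| OwnedScalar::String(s.clone()))
        .collect();

    // New style: zero-copy views, hashed directly.
    let views: HashSet<View<'_>> = column
        .iter()
        .map(|s| View::String(s.as_str()))
        .collect();

    assert_eq!(owned.len(), views.len()); // both see 2 distinct values
}
```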

## Test Plan
Tested with new unit tests and integration tests.
Rachelint authored Jul 25, 2023
1 parent 4f37dbf commit da7a799
Showing 7 changed files with 221 additions and 68 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock


103 changes: 103 additions & 0 deletions common_types/src/datum.rs
@@ -21,6 +21,7 @@ use crate::{hex, string::StringBytes, time::Timestamp};

const DATE_FORMAT: &str = "%Y-%m-%d";
const TIME_FORMAT: &str = "%H:%M:%S%.3f";
const NULL_VALUE_FOR_HASH: u128 = u128::MAX;

#[derive(Debug, Snafu)]
pub enum Error {
@@ -90,6 +91,25 @@ pub enum Error {

pub type Result<T> = std::result::Result<T, Error>;

// Float wrapper over f32/f64. Just because we cannot build std::hash::Hash for
// floats directly we have to do it through type wrapper
// Fork from datafusion:
// https://github.com/apache/arrow-datafusion/blob/1a0542acbc01e5243471ae0fc3586c2f1f40013b/datafusion/common/src/scalar.rs#L1493
struct Fl<T>(T);

macro_rules! hash_float_value {
($(($t:ty, $i:ty)),+) => {
$(impl std::hash::Hash for Fl<$t> {
#[inline]
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
state.write(&<$i>::from_ne_bytes(self.0.to_ne_bytes()).to_ne_bytes())
}
})+
};
}

hash_float_value!((f64, u64), (f32, u32));

// FIXME(yingwen): How to handle timezone?

/// Data type of datum
@@ -1138,6 +1158,37 @@ impl<'a> DatumView<'a> {
}
}
}

pub fn as_str(&self) -> Option<&str> {
match self {
DatumView::String(v) => Some(v),
_ => None,
}
}
}

impl<'a> std::hash::Hash for DatumView<'a> {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
match self {
DatumView::Null => NULL_VALUE_FOR_HASH.hash(state),
DatumView::Timestamp(v) => v.hash(state),
DatumView::Double(v) => Fl(*v).hash(state),
DatumView::Float(v) => Fl(*v).hash(state),
DatumView::Varbinary(v) => v.hash(state),
DatumView::String(v) => v.hash(state),
DatumView::UInt64(v) => v.hash(state),
DatumView::UInt32(v) => v.hash(state),
DatumView::UInt16(v) => v.hash(state),
DatumView::UInt8(v) => v.hash(state),
DatumView::Int64(v) => v.hash(state),
DatumView::Int32(v) => v.hash(state),
DatumView::Int16(v) => v.hash(state),
DatumView::Int8(v) => v.hash(state),
DatumView::Boolean(v) => v.hash(state),
DatumView::Date(v) => v.hash(state),
DatumView::Time(v) => v.hash(state),
}
}
}

impl DatumKind {
@@ -1359,6 +1410,11 @@ impl From<DatumKind> for DataType {

#[cfg(test)]
mod tests {
use std::{
collections::hash_map::DefaultHasher,
hash::{Hash, Hasher},
};

use super::*;

#[test]
@@ -1581,4 +1637,51 @@ mod tests {
}
}
}

fn get_hash<V: Hash>(v: &V) -> u64 {
let mut hasher = DefaultHasher::new();
v.hash(&mut hasher);
hasher.finish()
}

macro_rules! assert_datum_view_hash {
($v:expr, $Kind: ident) => {
let expected = get_hash(&DatumView::$Kind($v));
let actual = get_hash(&$v);
assert_eq!(expected, actual);
};
}

#[test]
fn test_hash() {
assert_datum_view_hash!(Timestamp::new(42), Timestamp);
assert_datum_view_hash!(42_i32, Date);
assert_datum_view_hash!(424_i64, Time);
assert_datum_view_hash!(b"abcde", Varbinary);
assert_datum_view_hash!("12345", String);
assert_datum_view_hash!(42424242_u64, UInt64);
assert_datum_view_hash!(424242_u32, UInt32);
assert_datum_view_hash!(4242_u16, UInt16);
assert_datum_view_hash!(42_u8, UInt8);
assert_datum_view_hash!(-42424242_i64, Int64);
assert_datum_view_hash!(-42424242_i32, Int32);
assert_datum_view_hash!(-4242_i16, Int16);
assert_datum_view_hash!(-42_i8, Int8);
assert_datum_view_hash!(true, Boolean);

// Null case.
let null_expected = get_hash(&NULL_VALUE_FOR_HASH);
let null_actual = get_hash(&DatumView::Null);
assert_eq!(null_expected, null_actual);

// Float case.
let float_expected = get_hash(&Fl(42.0_f32));
let float_actual = get_hash(&DatumView::Float(42.0));
assert_eq!(float_expected, float_actual);

// Double case.
let double_expected = get_hash(&Fl(-42.0_f64));
let double_actual = get_hash(&DatumView::Double(-42.0));
assert_eq!(double_expected, double_actual);
}
}
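
Two notes on the float hashing added above. First, `hash_float_value!` expands
to one `Hash` impl per `(float, int)` pair; for `(f32, u32)` the generated body
is essentially `state.write(&u32::from_ne_bytes(self.0.to_ne_bytes()).to_ne_bytes())`,
i.e. the float's raw bytes are fed to the hasher through a same-width integer.
Second, a caveat of bit-pattern hashing: values that compare equal as floats
can still hash differently, e.g. `0.0` and `-0.0`, so hash-based distinct
counting may see them as two values. A quick check (my own illustration, not
part of the commit):

```rust
fn main() {
    // 0.0 and -0.0 compare equal as floats...
    assert!(0.0_f64 == -0.0_f64);
    // ...but their byte patterns (and thus bit-pattern hashes) differ.
    assert_ne!(0.0_f64.to_ne_bytes(), (-0.0_f64).to_ne_bytes());
}
```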
2 changes: 1 addition & 1 deletion df_operator/Cargo.toml
@@ -18,7 +18,7 @@ chrono = { workspace = true }
common_types = { workspace = true }
datafusion = { workspace = true }
generic_error = { workspace = true }
-hyperloglog = { git = "https://github.com/jedisct1/rust-hyperloglog.git", rev = "ed1b9b915072ba90c6b93fbfbba30c03215ba682", features = ["with_serde"] }
+hyperloglog = { git = "https://github.com/jedisct1/rust-hyperloglog.git", rev = "425487ce910f26636fbde8c4d640b538431aad50", features = ["with_serde"] }
macros = { workspace = true }
smallvec = { workspace = true }
snafu = { workspace = true }
72 changes: 41 additions & 31 deletions df_operator/src/aggregate.rs
@@ -5,6 +5,7 @@
use std::{fmt, ops::Deref};

use arrow::array::ArrayRef as DfArrayRef;
+use common_types::column::ColumnBlock;
use datafusion::{
error::{DataFusionError, Result as DfResult},
physical_plan::Accumulator as DfAccumulator,
@@ -14,7 +15,7 @@ use generic_error::GenericError;
use macros::define_result;
use snafu::Snafu;

-use crate::functions::{ScalarValue, ScalarValueRef};
+use crate::functions::ScalarValue;

#[derive(Debug, Snafu)]
#[snafu(visibility(pub(crate)))]
@@ -28,6 +29,7 @@ pub enum Error {

define_result!(Error);

+// TODO: Use `Datum` rather than `ScalarValue`.
pub struct State(Vec<DfScalarValue>);

impl State {
@@ -43,24 +45,20 @@ impl From<ScalarValue> for State {
}
}

-pub struct Input<'a>(&'a [DfScalarValue]);
+pub struct Input<'a>(&'a [ColumnBlock]);

impl<'a> Input<'a> {
-    pub fn iter(&self) -> impl Iterator<Item = ScalarValueRef> {
-        self.0.iter().map(ScalarValueRef::from)
+    pub fn num_columns(&self) -> usize {
+        self.0.len()
}

-    pub fn len(&self) -> usize {
-        self.0.len()
+    pub fn column(&self, col_idx: usize) -> Option<&ColumnBlock> {
+        self.0.get(col_idx)
}

pub fn is_empty(&self) -> bool {
self.0.is_empty()
}

-    pub fn value(&self, index: usize) -> ScalarValueRef {
-        ScalarValueRef::from(&self.0[index])
-    }
}

pub struct StateRef<'a>(Input<'a>);
@@ -78,22 +76,24 @@ impl<'a> Deref for StateRef<'a> {
///
/// An accumulator knows how to:
/// * update its state from inputs via `update`
-/// * convert its internal state to a vector of scalar values
+/// * convert its internal state to column blocks
/// * update its state from multiple accumulators' states via `merge`
/// * compute the final value from its internal state via `evaluate`
pub trait Accumulator: Send + Sync + fmt::Debug {
/// Returns the state of the accumulator at the end of the accumulation.
// in the case of an average on which we track `sum` and `n`, this function
// should return a vector of two values, sum and n.
+    // TODO: Use `Datum` rather than `ScalarValue`.
fn state(&self) -> Result<State>;

-    /// updates the accumulator's state from a vector of scalars.
+    /// updates the accumulator's state from column blocks.
fn update(&mut self, values: Input) -> Result<()>;

-    /// updates the accumulator's state from a vector of scalars.
+    /// updates the accumulator's state from column blocks.
fn merge(&mut self, states: StateRef) -> Result<()>;

/// returns its value based on its current state.
+    // TODO: Use `Datum` rather than `ScalarValue`.
fn evaluate(&self) -> Result<ScalarValue>;
}

@@ -120,33 +120,43 @@ impl<T: Accumulator> DfAccumulator for ToDfAccumulator<T> {
if values.is_empty() {
return Ok(());
};
-        (0..values[0].len()).try_for_each(|index| {
-            let v = values
-                .iter()
-                .map(|array| DfScalarValue::try_from_array(array, index))
-                .collect::<DfResult<Vec<DfScalarValue>>>()?;
-            let input = Input(&v);
-
-            self.accumulator.update(input).map_err(|e| {
-                DataFusionError::Execution(format!("Accumulator failed to update, err:{e}"))
-            })
-        })
+        let column_blocks = values
+            .iter()
+            .map(|array| {
+                ColumnBlock::try_cast_arrow_array_ref(array).map_err(|e| {
+                    DataFusionError::Execution(format!(
+                        "Accumulator failed to cast arrow array to column block, column, err:{e}"
+                    ))
+                })
+            })
+            .collect::<DfResult<Vec<_>>>()?;
+
+        let input = Input(&column_blocks);
+        self.accumulator.update(input).map_err(|e| {
+            DataFusionError::Execution(format!("Accumulator failed to update, err:{e}"))
+        })
    }

fn merge_batch(&mut self, states: &[DfArrayRef]) -> DfResult<()> {
if states.is_empty() {
return Ok(());
};
-        (0..states[0].len()).try_for_each(|index| {
-            let v = states
-                .iter()
-                .map(|array| DfScalarValue::try_from_array(array, index))
-                .collect::<DfResult<Vec<DfScalarValue>>>()?;
-            let state_ref = StateRef(Input(&v));
-
-            self.accumulator.merge(state_ref).map_err(|e| {
-                DataFusionError::Execution(format!("Accumulator failed to merge, err:{e}"))
-            })
-        })
+        let column_blocks = states
+            .iter()
+            .map(|array| {
+                ColumnBlock::try_cast_arrow_array_ref(array).map_err(|e| {
+                    DataFusionError::Execution(format!(
+                        "Accumulator failed to cast arrow array to column block, column, err:{e}"
+                    ))
+                })
+            })
+            .collect::<DfResult<Vec<_>>>()?;
+
+        let state_ref = StateRef(Input(&column_blocks));
+        self.accumulator.merge(state_ref).map_err(|e| {
+            DataFusionError::Execution(format!("Accumulator failed to merge, err:{e}"))
+        })
    }

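This batch-level rewrite is the heart of the speedup: the old
`update_batch`/`merge_batch` materialized one owned `DfScalarValue` per row and
per column via `try_from_array`, while the new code casts each Arrow array to a
`ColumnBlock` once per batch and then reads rows as borrowed views. A rough
sketch of the same pattern against plain arrow-rs (toy logic only;
`ColumnBlock`/`DatumView` are this repo's types, not arrow's):

```rust
use std::sync::Arc;

use arrow::array::{Array, ArrayRef, StringArray};

fn main() {
    let array: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "a"]));

    // Cast once per batch: type dispatch is hoisted out of the row loop...
    let strings = array
        .as_any()
        .downcast_ref::<StringArray>()
        .expect("expected a StringArray");

    // ...then each row is read as a borrowed &str with no per-row allocation.
    for row_idx in 0..strings.len() {
        if !strings.is_null(row_idx) {
            let value: &str = strings.value(row_idx);
            println!("row {row_idx}: {value}");
        }
    }
}
```
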
54 changes: 39 additions & 15 deletions df_operator/src/udfs/thetasketch_distinct.rs
@@ -103,21 +103,27 @@ struct HllDistinct {
hll: HyperLogLog,
}

-// TODO(yingwen): Avoid base64 encode/decode if datafusion supports converting
-// binary datatype to scalarvalue.
+// TODO: maybe we can remove base64 encoding?
impl HllDistinct {
fn merge_impl(&mut self, states: StateRef) -> Result<()> {
// The states are serialize from hll.
-        ensure!(states.len() == 1, InvalidStateLen);
-        let value_ref = states.value(0);
-        let hll_string = value_ref.as_str().context(StateNotString)?;
-        let hll_bytes = base64::decode(hll_string).context(DecodeBase64)?;
-        // Try to deserialize the hll.
-        let hll = bincode::deserialize(&hll_bytes).context(DecodeHll)?;
-
-        // Merge the hll, note that the two hlls must created or serialized from the
-        // same template hll.
-        self.hll.merge(&hll);
+        ensure!(states.num_columns() == 1, InvalidStateLen);
+        let merged_col = states.column(0).unwrap();
+
+        let num_rows = merged_col.num_rows();
+        for row_idx in 0..num_rows {
+            let datum = merged_col.datum_view(row_idx);
+            // Try to deserialize the hll.
+            let hll_string = datum.as_str().context(StateNotString)?;
+            let hll_bytes = base64::decode(hll_string).context(DecodeBase64)?;
+            let hll = bincode::deserialize(&hll_bytes).context(DecodeHll)?;
+
+            // Merge the hll, note that the two hlls must created or serialized from the
+            // same template hll.
+            self.hll.merge(&hll);
+        }

Ok(())
}
@@ -132,6 +138,7 @@ impl fmt::Debug for HllDistinct {
}

impl Accumulator for HllDistinct {
+    // TODO: maybe we can remove base64 encoding?
fn state(&self) -> aggregate::Result<State> {
// Serialize `self.hll` to bytes.
let buf = bincode::serialize(&self.hll).box_err().context(GetState)?;
@@ -142,10 +149,27 @@ impl Accumulator for HllDistinct {
Ok(State::from(ScalarValue::from(hll_string)))
}

-    fn update(&mut self, values: Input) -> aggregate::Result<()> {
-        for value_ref in values.iter() {
-            // Insert value into hll.
-            self.hll.insert(&value_ref);
-        }
+    fn update(&mut self, input: Input) -> aggregate::Result<()> {
+        if input.is_empty() {
+            return Ok(());
+        }
+
+        // Has found it not empty, so we can unwrap here.
+        let first_col = input.column(0).unwrap();
+        let num_rows = first_col.num_rows();
+        if num_rows == 0 {
+            return Ok(());
+        }
+
+        // Loop over the datums in the column blocks, insert them into hll.
+        let num_cols = input.num_columns();
+        for col_idx in 0..num_cols {
+            let col = input.column(col_idx).unwrap();
+            for row_idx in 0..num_rows {
+                let datum = col.datum_view(row_idx);
+                // Insert datum into hll.
+                self.hll.insert(&datum);
+            }
+        }

Ok(())
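For reference, the accumulator state crosses partial aggregates as a base64
string: `state()` serializes the HLL with bincode and base64-encodes it into a
string scalar, and `merge_impl` reverses the trip before merging. A minimal
sketch of that round trip, with a hypothetical stand-in struct for the sketch
state (the real type comes from the `hyperloglog` crate):

```rust
use serde::{Deserialize, Serialize};

// Hypothetical stand-in for the HLL's serializable state.
#[derive(Serialize, Deserialize, PartialEq, Debug)]
struct SketchState {
    registers: Vec<u8>,
}

fn main() {
    let state = SketchState { registers: vec![0, 3, 1, 4] };

    // state(): struct -> bincode bytes -> base64 string scalar.
    let bytes = bincode::serialize(&state).expect("serialize");
    let encoded = base64::encode(&bytes);

    // merge_impl(): base64 string -> bytes -> struct, then merge.
    let decoded = base64::decode(&encoded).expect("valid base64");
    let restored: SketchState = bincode::deserialize(&decoded).expect("deserialize");

    assert_eq!(state, restored);
}
```

Because `thetasketch_distinct` is an approximate, HLL-based count, the expected
values in the integration result diff below shift slightly (147 → 148 and
115 → 113) under the new implementation.
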
@@ -423,7 +423,7 @@ affected_rows: 400
SELECT thetasketch_distinct(`value`) FROM `02_function_thetasketch_distinct_table`;

thetasketch_distinct(02_function_thetasketch_distinct_table.value),
-UInt64(147),
+UInt64(148),


SELECT
@@ -439,7 +439,7 @@ ORDER BY
`arch` DESC;

arch,thetasketch_distinct(02_function_thetasketch_distinct_table.value),
-String("x86"),UInt64(115),
+String("x86"),UInt64(113),
String("arm"),UInt64(117),

