Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(flow): avg func rewrite to sum/count #3955

Merged
merged 6 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/flow/src/compute/render.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,12 @@ mod test {
for now in time_range {
state.set_current_ts(now);
state.run_available_with_schedule(df);
if !state.get_err_collector().is_empty() {
panic!(
"Errors occur: {:?}",
state.get_err_collector().get_all_blocking()
)
}
assert!(state.get_err_collector().is_empty());
if let Some(expected) = expected.get(&now) {
assert_eq!(*output.borrow(), *expected, "at ts={}", now);
Expand Down
102 changes: 100 additions & 2 deletions src/flow/src/compute/render/reduce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -729,15 +729,113 @@ mod test {
use std::cell::RefCell;
use std::rc::Rc;

use datatypes::data_type::ConcreteDataType;
use datatypes::data_type::{ConcreteDataType, ConcreteDataType as CDT};
use hydroflow::scheduled::graph::Hydroflow;

use super::*;
use crate::compute::render::test::{get_output_handle, harness_test_ctx, run_and_check};
use crate::compute::state::DataflowState;
use crate::expr::{self, AggregateFunc, BinaryFunc, GlobalId, MapFilterProject};
use crate::expr::{self, AggregateFunc, BinaryFunc, GlobalId, MapFilterProject, UnaryFunc};
use crate::repr::{ColumnType, RelationType};

/// select avg(number) from number;
///
/// avg is rewritten by the planner into `sum(number) / count(number)`:
/// the reduce step computes two accumulators (sum, count) and a trailing
/// map projects their quotient, guarded against count == 0 with NULL.
#[test]
fn test_avg_eval() {
let mut df = Hydroflow::new();
let mut state = DataflowState::default();
let mut ctx = harness_test_ctx(&mut df, &mut state);

// Six input rows (1,2,3,1,2,3) all at ts=1: sum = 12, count = 6, avg = 2.
let rows = vec![
(Row::new(vec![1u32.into()]), 1, 1),
(Row::new(vec![2u32.into()]), 1, 1),
(Row::new(vec![3u32.into()]), 1, 1),
(Row::new(vec![1u32.into()]), 1, 1),
(Row::new(vec![2u32.into()]), 1, 1),
(Row::new(vec![3u32.into()]), 1, 1),
];
let collection = ctx.render_constant(rows.clone());
ctx.insert_global(GlobalId::User(1), collection);

// The two accumulators avg decomposes into: sum(col 0) and count(col 0).
let aggr_exprs = vec![
AggregateExpr {
func: AggregateFunc::SumUInt32,
expr: ScalarExpr::Column(0),
distinct: false,
},
AggregateExpr {
func: AggregateFunc::Count,
expr: ScalarExpr::Column(0),
distinct: false,
},
];
// avg = if count != 0 { sum / cast(count as u64) } else { NULL }.
// NOTE(review): the zero literal carries a u32 value but is declared with
// int64_datatype() — confirm this mirrors what the planner actually emits.
let avg_expr = ScalarExpr::If {
cond: Box::new(ScalarExpr::Column(1).call_binary(
ScalarExpr::Literal(Value::from(0u32), CDT::int64_datatype()),
BinaryFunc::NotEq,
)),
then: Box::new(ScalarExpr::Column(0).call_binary(
ScalarExpr::Column(1).call_unary(UnaryFunc::Cast(CDT::uint64_datatype())),
BinaryFunc::DivUInt64,
)),
els: Box::new(ScalarExpr::Literal(Value::Null, CDT::uint64_datatype())),
};
// Full plan under test: Get -> Reduce (sum + count) -> Mfp (divide, project).
// NOTE(review): despite its name, `expected` here is the *input* plan fed to
// render_plan below (it is shadowed later by the expected output rows).
let expected = TypedPlan {
typ: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)]),
plan: Plan::Mfp {
input: Box::new(
Plan::Reduce {
input: Box::new(
Plan::Get {
id: crate::expr::Id::Global(GlobalId::User(1)),
}
// NOTE(review): source column typed int64 while the rows above hold
// u32 values and the aggregate is SumUInt32 — confirm intended.
.with_types(RelationType::new(vec![
ColumnType::new(ConcreteDataType::int64_datatype(), false),
])),
),
// Global aggregate (no GROUP BY): key plan projects to the empty
// key []; value plan keeps the single column being aggregated.
key_val_plan: KeyValPlan {
key_plan: MapFilterProject::new(1)
.project(vec![])
.unwrap()
.into_safe(),
val_plan: MapFilterProject::new(1)
.project(vec![0])
.unwrap()
.into_safe(),
},
// Both aggregates are accumulable and non-distinct; they read input
// column 0 and land in output slots 0 (sum) and 1 (count).
reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
full_aggrs: aggr_exprs.clone(),
simple_aggrs: vec![
AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0),
AggrWithIndex::new(aggr_exprs[1].clone(), 0, 1),
],
distinct_aggrs: vec![],
}),
}
.with_types(RelationType::new(vec![
ColumnType::new(ConcreteDataType::uint32_datatype(), true),
ColumnType::new(ConcreteDataType::int64_datatype(), true),
])),
),
// Map over the 2 reduce outputs appends avg as column 2 and a copy of it
// as column 3, then projects only the copy.
mfp: MapFilterProject::new(2)
.map(vec![
avg_expr,
// TODO(discord9): optimize mfp so as to remove the indirect ref
ScalarExpr::Column(2),
])
.unwrap()
.project(vec![3])
.unwrap(),
},
};

let bundle = ctx.render_plan(expected).unwrap();

let output = get_output_handle(&mut ctx, bundle);
drop(ctx);
// Single output row at ts=1: avg = 12 / 6 = 2 as u64.
let expected = BTreeMap::from([(1, vec![(Row::new(vec![2u64.into()]), 1, 1)])]);
run_and_check(&mut state, &mut df, 1..2, expected, output);
}

/// SELECT DISTINCT col FROM table
///
/// table schema:
Expand Down
3 changes: 3 additions & 0 deletions src/flow/src/compute/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@ pub struct ErrCollector {
}

impl ErrCollector {
/// Drain and return every collected error, acquiring the lock synchronously.
///
/// NOTE(review): `blocking_lock` is the sync entry point to the async mutex —
/// tokio panics if it is called from within an async runtime context; intended
/// for sync/test call sites, confirm no async caller uses this.
pub fn get_all_blocking(&self) -> Vec<EvalError> {
self.inner.blocking_lock().drain(..).collect_vec()
}
/// Drain and return every collected error, awaiting the async lock.
pub async fn get_all(&self) -> Vec<EvalError> {
self.inner.lock().await.drain(..).collect_vec()
}
Expand Down
16 changes: 16 additions & 0 deletions src/flow/src/expr/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,22 @@ impl BinaryFunc {
)
}

/// Resolve the generic `Add` function to its concrete variant for `input_type`.
pub fn add(input_type: ConcreteDataType) -> Result<Self, Error> {
Self::specialization(GenericFn::Add, input_type)
}

/// Resolve the generic `Sub` function to its concrete variant for `input_type`.
pub fn sub(input_type: ConcreteDataType) -> Result<Self, Error> {
Self::specialization(GenericFn::Sub, input_type)
}

/// Resolve the generic `Mul` function to its concrete variant for `input_type`.
pub fn mul(input_type: ConcreteDataType) -> Result<Self, Error> {
Self::specialization(GenericFn::Mul, input_type)
}

/// Resolve the generic `Div` function to its concrete variant for `input_type`.
pub fn div(input_type: ConcreteDataType) -> Result<Self, Error> {
Self::specialization(GenericFn::Div, input_type)
}

/// Get the specialization of the binary function based on the generic function and the input type
pub fn specialization(generic: GenericFn, input_type: ConcreteDataType) -> Result<Self, Error> {
let rule = SPECIALIZATION.get_or_init(|| {
Expand Down
57 changes: 38 additions & 19 deletions src/flow/src/expr/relation/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,27 +136,44 @@ impl AggregateFunc {

/// Generate signature for each aggregate function
macro_rules! generate_signature {
($value:ident, { $($user_arm:tt)* },
[ $(
$auto_arm:ident=>($con_type:ident,$generic:ident)
),*
]) => {
($value:ident,
{ $($user_arm:tt)* },
[ $(
$auto_arm:ident=>($($arg:ident),*)
),*
]
) => {
match $value {
$($user_arm)*,
$(
Self::$auto_arm => Signature {
input: smallvec![
ConcreteDataType::$con_type(),
ConcreteDataType::$con_type(),
],
output: ConcreteDataType::$con_type(),
generic_fn: GenericFn::$generic,
},
Self::$auto_arm => gen_one_siginature!($($arg),*),
)*
}
};
}

/// Generate a single `Signature` value for one aggregate function, selected
/// by argument count.
///
/// NOTE(review): the macro name misspells "signature"; renaming it would also
/// require updating the call site inside `generate_signature`.
macro_rules! gen_one_siginature {
// Two args `(con_type, generic)`: input and output share one concrete type
// (input lists it twice — confirm the two-entry input is intentional).
(
$con_type:ident, $generic:ident
) => {
Signature {
input: smallvec![ConcreteDataType::$con_type(), ConcreteDataType::$con_type(),],
output: ConcreteDataType::$con_type(),
generic_fn: GenericFn::$generic,
}
};
// Three args `(in_type, out_type, generic)`: distinct input and output types
// with a single input entry — used for widening sums (e.g. int16 -> int64).
(
$in_type:ident, $out_type:ident, $generic:ident
) => {
Signature {
input: smallvec![ConcreteDataType::$in_type()],
output: ConcreteDataType::$out_type(),
generic_fn: GenericFn::$generic,
}
};
}

static SPECIALIZATION: OnceLock<HashMap<(GenericFn, ConcreteDataType), AggregateFunc>> =
OnceLock::new();

Expand Down Expand Up @@ -223,6 +240,8 @@ impl AggregateFunc {

/// All concrete datatypes with precision variants are returned as the largest possible variant.
/// As an exception, `count` has a signature of `null -> i64`, but it's actually `anytype -> i64`.
///
/// TODO(discord9): fix signature for sum: unsigned -> u64, signed -> i64
pub fn signature(&self) -> Signature {
generate_signature!(self, {
AggregateFunc::Count => Signature {
Expand Down Expand Up @@ -263,12 +282,12 @@ impl AggregateFunc {
MinTime => (time_second_datatype, Min),
MinDuration => (duration_second_datatype, Min),
MinInterval => (interval_year_month_datatype, Min),
SumInt16 => (int16_datatype, Sum),
SumInt32 => (int32_datatype, Sum),
SumInt64 => (int64_datatype, Sum),
SumUInt16 => (uint16_datatype, Sum),
SumUInt32 => (uint32_datatype, Sum),
SumUInt64 => (uint64_datatype, Sum),
SumInt16 => (int16_datatype, int64_datatype, Sum),
SumInt32 => (int32_datatype, int64_datatype, Sum),
SumInt64 => (int64_datatype, int64_datatype, Sum),
SumUInt16 => (uint16_datatype, uint64_datatype, Sum),
SumUInt32 => (uint32_datatype, uint64_datatype, Sum),
SumUInt64 => (uint64_datatype, uint64_datatype, Sum),
SumFloat32 => (float32_datatype, Sum),
SumFloat64 => (float64_datatype, Sum),
Any => (boolean_datatype, Any),
Expand Down
6 changes: 3 additions & 3 deletions src/flow/src/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ pub struct TypedPlan {
impl TypedPlan {
/// directly apply a mfp to the plan
pub fn mfp(self, mfp: MapFilterProject) -> Result<Self, Error> {
let new_type = self.typ.apply_mfp(&mfp, &[])?;
let new_type = self.typ.apply_mfp(&mfp)?;
let plan = match self.plan {
Plan::Mfp {
input,
Expand All @@ -68,14 +68,14 @@ impl TypedPlan {
pub fn projection(self, exprs: Vec<TypedExpr>) -> Result<Self, Error> {
let input_arity = self.typ.column_types.len();
let output_arity = exprs.len();
let (exprs, expr_typs): (Vec<_>, Vec<_>) = exprs
let (exprs, _expr_typs): (Vec<_>, Vec<_>) = exprs
.into_iter()
.map(|TypedExpr { expr, typ }| (expr, typ))
.unzip();
let mfp = MapFilterProject::new(input_arity)
.map(exprs)?
.project(input_arity..input_arity + output_arity)?;
let out_typ = self.typ.apply_mfp(&mfp, &expr_typs)?;
let out_typ = self.typ.apply_mfp(&mfp)?;
// special case for mfp to compose when the plan is already mfp
let plan = match self.plan {
Plan::Mfp {
Expand Down
15 changes: 8 additions & 7 deletions src/flow/src/repr/relation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,13 @@ impl RelationType {
/// then new key=`[1]`, new time index=`[0]`
///
/// note that this function will remove empty keys like key=`[]` will be removed
pub fn apply_mfp(&self, mfp: &MapFilterProject, expr_typs: &[ColumnType]) -> Result<Self> {
let all_types = self
.column_types
.iter()
.chain(expr_typs.iter())
.cloned()
.collect_vec();
pub fn apply_mfp(&self, mfp: &MapFilterProject) -> Result<Self> {
let mut all_types = self.column_types.clone();
for expr in &mfp.expressions {
let expr_typ = expr.typ(&self.column_types)?;
all_types.push(expr_typ);
}
let all_types = all_types;
let mfp_out_types = mfp
.projection
.iter()
Expand All @@ -131,6 +131,7 @@ impl RelationType {
})
})
.try_collect()?;

let old_to_new_col = BTreeMap::from_iter(
mfp.projection
.clone()
Expand Down
Loading
Loading