Skip to content

Commit

Permalink
sql: add rangeany binary funcs
Browse files Browse the repository at this point in the history
  • Loading branch information
Sean Loiselle committed Jan 25, 2023
1 parent 1c1ed39 commit 6aed5dd
Show file tree
Hide file tree
Showing 6 changed files with 1,234 additions and 576 deletions.
8 changes: 7 additions & 1 deletion src/expr/src/scalar.proto
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,13 @@ message ProtoBinaryFunc {
google.protobuf.Empty mod_uint32 = 170;
google.protobuf.Empty mod_uint64 = 171;
ProtoRangeContainsInner range_contains_elem = 172;
ProtoRangeContainsInner range_contains_range = 173;
bool range_contains_range = 173;
google.protobuf.Empty range_overlaps = 174;
google.protobuf.Empty range_after = 175;
google.protobuf.Empty range_before = 176;
google.protobuf.Empty range_overleft = 177;
google.protobuf.Empty range_overright = 178;
google.protobuf.Empty range_adjacent = 179;
}
}

Expand Down
118 changes: 83 additions & 35 deletions src/expr/src/scalar/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1264,15 +1264,28 @@ where
Datum::from(range.contains_elem(&elem))
}

fn contains_range_range<'a, R: RangeOps<'a>>(a: Datum<'a>, b: Datum<'a>) -> Datum<'a>
where
<R as TryFrom<Datum<'a>>>::Error: std::fmt::Debug,
{
let l = a.unwrap_range();
let r = b.unwrap_range();
Datum::from(l.contains_range::<R>(&r))
macro_rules! range_fn {
($fn:expr) => {
paste::paste! {

fn [< range_ $fn >]<'a>(a: Datum<'a>, b: Datum<'a>) -> Datum<'a>
{
let l = a.unwrap_range();
let r = b.unwrap_range();
Datum::from(Range::<Datum<'a>>::$fn(&l, &r))
}
}
};
}

range_fn!(contains_range);
range_fn!(overlaps);
range_fn!(after);
range_fn!(before);
range_fn!(overleft);
range_fn!(overright);
range_fn!(adjacent);

fn eq<'a>(a: Datum<'a>, b: Datum<'a>) -> Datum<'a> {
Datum::from(a == b)
}
Expand Down Expand Up @@ -1915,7 +1928,13 @@ pub enum BinaryFunc {
PowerNumeric,
GetByte,
RangeContainsElem { elem_type: ScalarType, rev: bool },
RangeContainsRange { elem_type: ScalarType, rev: bool },
RangeContainsRange { rev: bool },
RangeOverlaps,
RangeAfter,
RangeBefore,
RangeOverleft,
RangeOverright,
RangeAdjacent,
}

impl BinaryFunc {
Expand Down Expand Up @@ -2222,15 +2241,13 @@ impl BinaryFunc {
}
_ => unreachable!(),
}),
BinaryFunc::RangeContainsRange { elem_type, rev: _ } => Ok(match elem_type {
ScalarType::Int32 => eager!(contains_range_range::<i32>),
ScalarType::Int64 => eager!(contains_range_range::<i64>),
ScalarType::Date => eager!(contains_range_range::<Date>),
ScalarType::Numeric { .. } => {
eager!(contains_range_range::<OrderedDecimal<Numeric>>)
}
_ => unreachable!(),
}),
BinaryFunc::RangeContainsRange { rev: _ } => Ok(eager!(range_contains_range)),
BinaryFunc::RangeOverlaps => Ok(eager!(range_overlaps)),
BinaryFunc::RangeAfter => Ok(eager!(range_after)),
BinaryFunc::RangeBefore => Ok(eager!(range_before)),
BinaryFunc::RangeOverleft => Ok(eager!(range_overleft)),
BinaryFunc::RangeOverright => Ok(eager!(range_overright)),
BinaryFunc::RangeAdjacent => Ok(eager!(range_adjacent)),
}
}

Expand Down Expand Up @@ -2384,9 +2401,14 @@ impl BinaryFunc {

GetByte => ScalarType::Int32.nullable(in_nullable),

RangeContainsElem { .. } | RangeContainsRange { .. } => {
ScalarType::Bool.nullable(in_nullable)
}
RangeContainsElem { .. }
| RangeContainsRange { .. }
| RangeOverlaps
| RangeAfter
| RangeBefore
| RangeOverleft
| RangeOverright
| RangeAdjacent => ScalarType::Bool.nullable(in_nullable),
}
}

Expand Down Expand Up @@ -2512,6 +2534,12 @@ impl BinaryFunc {
| ModNumeric
| RangeContainsElem { .. }
| RangeContainsRange { .. }
| RangeOverlaps
| RangeAfter
| RangeBefore
| RangeOverleft
| RangeOverright
| RangeAdjacent
)
}

Expand Down Expand Up @@ -2642,7 +2670,13 @@ impl BinaryFunc {
| ListElementConcat
| ElementListConcat
| RangeContainsElem { .. }
| RangeContainsRange { .. } => true,
| RangeContainsRange { .. }
| RangeOverlaps
| RangeAfter
| RangeBefore
| RangeOverleft
| RangeOverright
| RangeAdjacent => true,
ToCharTimestamp
| ToCharTimestampTz
| DateBinTimestamp
Expand Down Expand Up @@ -2904,6 +2938,12 @@ impl fmt::Display for BinaryFunc {
BinaryFunc::RangeContainsRange { rev, .. } => {
f.write_str(if *rev { "<@" } else { "@>" })
}
BinaryFunc::RangeOverlaps => f.write_str("&&"),
BinaryFunc::RangeAfter => f.write_str(">>"),
BinaryFunc::RangeBefore => f.write_str("<<"),
BinaryFunc::RangeOverleft => f.write_str("&<"),
BinaryFunc::RangeOverright => f.write_str("&>"),
BinaryFunc::RangeAdjacent => f.write_str("-|-"),
}
}
}
Expand Down Expand Up @@ -3105,9 +3145,15 @@ impl Arbitrary for BinaryFunc {
(bool::arbitrary(), mz_repr::arb_range_type())
.prop_map(|(rev, elem_type)| BinaryFunc::RangeContainsElem { elem_type, rev })
.boxed(),
(bool::arbitrary(), mz_repr::arb_range_type())
.prop_map(|(rev, elem_type)| BinaryFunc::RangeContainsRange { elem_type, rev })
bool::arbitrary()
.prop_map(|rev| BinaryFunc::RangeContainsRange { rev })
.boxed(),
Just(BinaryFunc::RangeOverlaps).boxed(),
Just(BinaryFunc::RangeAfter).boxed(),
Just(BinaryFunc::RangeBefore).boxed(),
Just(BinaryFunc::RangeOverleft).boxed(),
Just(BinaryFunc::RangeOverright).boxed(),
Just(BinaryFunc::RangeAdjacent).boxed(),
])
}
}
Expand Down Expand Up @@ -3290,12 +3336,13 @@ impl RustType<ProtoBinaryFunc> for BinaryFunc {
rev: *rev,
})
}
BinaryFunc::RangeContainsRange { elem_type, rev } => {
RangeContainsRange(crate::scalar::proto_binary_func::ProtoRangeContainsInner {
elem_type: Some(elem_type.into_proto()),
rev: *rev,
})
}
BinaryFunc::RangeContainsRange { rev } => RangeContainsRange(*rev),
BinaryFunc::RangeOverlaps => RangeOverlaps(()),
BinaryFunc::RangeAfter => RangeAfter(()),
BinaryFunc::RangeBefore => RangeBefore(()),
BinaryFunc::RangeOverleft => RangeOverleft(()),
BinaryFunc::RangeOverright => RangeOverright(()),
BinaryFunc::RangeAdjacent => RangeAdjacent(()),
};
ProtoBinaryFunc { kind: Some(kind) }
}
Expand Down Expand Up @@ -3484,12 +3531,13 @@ impl RustType<ProtoBinaryFunc> for BinaryFunc {
.into_rust_if_some("ProtoRangeContainsInner::elem_type")?,
rev: inner.rev,
}),
RangeContainsRange(inner) => Ok(BinaryFunc::RangeContainsRange {
elem_type: inner
.elem_type
.into_rust_if_some("ProtoRangeContainsInner::elem_type")?,
rev: inner.rev,
}),
RangeContainsRange(rev) => Ok(BinaryFunc::RangeContainsRange { rev }),
RangeOverlaps(()) => Ok(BinaryFunc::RangeOverlaps),
RangeAfter(()) => Ok(BinaryFunc::RangeAfter),
RangeBefore(()) => Ok(BinaryFunc::RangeBefore),
RangeOverleft(()) => Ok(BinaryFunc::RangeOverleft),
RangeOverright(()) => Ok(BinaryFunc::RangeOverright),
RangeAdjacent(()) => Ok(BinaryFunc::RangeAdjacent),
}
} else {
Err(TryFromProtoError::missing_field("ProtoBinaryFunc::kind"))
Expand Down
133 changes: 116 additions & 17 deletions src/repr/src/adt/range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ use proptest_derive::Arbitrary;
use serde::{Deserialize, Serialize};

use mz_lowertest::MzReflect;
use mz_ore::soft_assert;
use mz_proto::{RustType, TryFromProtoError};

use crate::scalar::DatumKind;
Expand Down Expand Up @@ -113,15 +112,7 @@ where
{
/// Increment `self` one step forward, if applicable. Return `None` if
/// overflows.
///
/// Types that do not have discrete steps should never call this function,
/// but we handle this with a soft assert because there's nothing inherently
/// wrong with it.
fn step(self) -> Option<Self> {
soft_assert!(
false,
"default implementation viable only for continuous value types, which should never be called"
);
Some(self)
}

Expand Down Expand Up @@ -206,7 +197,7 @@ impl<D> Range<D> {
}

/// Range implementations meant to work with `Range<Datum>` and `Range<DatumNested>`.
impl<'a, B: Copy + Ord + PartialOrd> Range<B>
impl<'a, B: Copy + Ord + PartialOrd + Display + Debug> Range<B>
where
Datum<'a>: From<B>,
{
Expand All @@ -220,16 +211,89 @@ where
}
}

pub fn contains_range<T: RangeOps<'a>>(&self, other: &Range<B>) -> bool
where
<T as TryFrom<Datum<'a>>>::Error: std::fmt::Debug,
{
pub fn contains_range(&self, other: &Range<B>) -> bool {
match (self.inner, other.inner) {
(None, None) | (Some(_), None) => return true,
(None, Some(_)) => return false,
(None, None) | (Some(_), None) => true,
(None, Some(_)) => false,
(Some(i), Some(j)) => i.lower <= j.lower && j.upper <= i.upper,
}
}

pub fn overlaps(&self, other: &Range<B>) -> bool {
match (self.inner, other.inner) {
(Some(s), Some(o)) => {
let r = match s.cmp(&o) {
Ordering::Equal => Ordering::Equal,
Ordering::Less => s.upper.range_bound_cmp(&o.lower),
Ordering::Greater => o.upper.range_bound_cmp(&s.lower),
};

// If smaller upper is >= larger lower, elements overlap.
matches!(r, Ordering::Greater | Ordering::Equal)
}
_ => false,
}
}

pub fn before(&self, other: &Range<B>) -> bool {
match (self.inner, other.inner) {
(Some(s), Some(o)) => {
matches!(s.upper.range_bound_cmp(&o.lower), Ordering::Less)
}
_ => false,
}
}

pub fn after(&self, other: &Range<B>) -> bool {
match (self.inner, other.inner) {
(Some(s), Some(o)) => {
matches!(s.lower.range_bound_cmp(&o.upper), Ordering::Greater)
}
_ => false,
}
}

pub fn overleft(&self, other: &Range<B>) -> bool {
match (self.inner, other.inner) {
(Some(s), Some(o)) => {
matches!(
s.upper.range_bound_cmp(&o.upper),
Ordering::Less | Ordering::Equal
)
}
_ => false,
}
}

pub fn overright(&self, other: &Range<B>) -> bool {
match (self.inner, other.inner) {
(Some(s), Some(o)) => {
matches!(
s.lower.range_bound_cmp(&o.lower),
Ordering::Greater | Ordering::Equal
)
}
_ => false,
}
}

pub fn adjacent(&self, other: &Range<B>) -> bool {
match (self.inner, other.inner) {
(Some(s), Some(o)) => {
// Look at each (lower, upper) pair.
for (lower, upper) in [(s.lower, o.upper), (o.lower, s.upper)] {
if let (Some(l), Some(u)) = (lower.bound, upper.bound) {
// If ..x](x.. or ..x)[x.., adjacent
if lower.inclusive ^ upper.inclusive && l == u {
return true;
}
}
}
false
}
_ => false,
}
}
}

impl<'a> Range<Datum<'a>> {
Expand Down Expand Up @@ -369,7 +433,7 @@ pub type RangeUpperBound<B> = RangeBound<B, true>;

// Generic RangeBound implementations meant to work over `RangeBound<Datum,..>`
// and `RangeBound<DatumNested,..>`.
impl<'a, const UPPER: bool, B: Copy> RangeBound<B, UPPER>
impl<'a, const UPPER: bool, B: Copy + Ord + PartialOrd + Display + Debug> RangeBound<B, UPPER>
where
Datum<'a>: From<B>,
{
Expand Down Expand Up @@ -403,6 +467,41 @@ where
Ordering::Less => !UPPER,
}
}

// Compares two `RangeBound`, which do not need to both be of the same
// `UPPER`.
fn range_bound_cmp<const OTHER_UPPER: bool>(
&self,
other: &RangeBound<B, OTHER_UPPER>,
) -> Ordering {
if UPPER == OTHER_UPPER {
return self.cmp(&RangeBound {
inclusive: other.inclusive,
bound: other.bound,
});
}

// Handle cases where either are infinite bounds, which have special
// semantics.
if self.bound.is_none() || other.bound.is_none() {
return if UPPER {
Ordering::Greater
} else {
Ordering::Less
};
}
// 1. Sort by bounds
let cmp = self.bound.cmp(&other.bound);
// 2. Tie break by sorting by inclusivity, which is inverted between
// lowers and uppers.
cmp.then(if self.inclusive && other.inclusive {
Ordering::Equal
} else if UPPER {
Ordering::Less
} else {
Ordering::Greater
})
}
}

impl<'a, const UPPER: bool> RangeBound<Datum<'a>, UPPER> {
Expand Down
Loading

0 comments on commit 6aed5dd

Please sign in to comment.