Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 176 additions & 0 deletions rust/lance-core/src/utils/mask.rs
Original file line number Diff line number Diff line change
Expand Up @@ -962,6 +962,182 @@ pub fn bitmap_to_ranges(bitmap: &RoaringBitmap) -> Vec<Range<u64>> {
ranges
}

/// A set of stable row ids backed by a 64-bit Roaring bitmap.
///
/// This is a thin wrapper around [`RoaringTreemap`]. It represents a
/// collection of unique row ids and provides the common row-set
/// operations defined by [`RowSetOps`].
#[derive(Clone, Debug, Default, PartialEq)]
pub struct RowIdSet {
inner: RoaringTreemap,
}

impl RowIdSet {
/// Creates an empty set of row ids.
pub fn new() -> Self {
Self::default()
}
/// Returns an iterator over the contained row ids in ascending order.
pub fn iter(&self) -> impl Iterator<Item = u64> + '_ {
self.inner.iter()
}
/// Returns the union of `self` and `other`.
pub fn union(&self, other: &Self) -> Self {
let mut result = self.clone();
result.inner |= &other.inner;
result
}
/// Returns the set difference `self \\ other`.
pub fn difference(&self, other: &Self) -> Self {
let mut result = self.clone();
result.inner -= &other.inner;
result
}
}

impl RowSetOps for RowIdSet {
type Row = u64;
fn is_empty(&self) -> bool {
self.inner.is_empty()
}
fn len(&self) -> Option<u64> {
Some(self.inner.len())
}
fn remove(&mut self, row: Self::Row) -> bool {
self.inner.remove(row)
}
fn contains(&self, row: Self::Row) -> bool {
self.inner.contains(row)
}
fn union_all(other: &[&Self]) -> Self {
let mut iter = other.iter();
let mut result = iter.next().map(|set| (*set).clone()).unwrap_or_default();
for set in iter {
result.inner |= &set.inner;
}
result
}
#[track_caller]
fn from_sorted_iter<I>(iter: I) -> Result<Self>
where
I: IntoIterator<Item = Self::Row>,
{
let mut inner = RoaringTreemap::new();
let mut last: Option<u64> = None;
for value in iter {
if let Some(prev) = last {
if value < prev {
return Err(Error::Internal {
message: "RowIdSet::from_sorted_iter called with non-sorted input"
.to_string(),
// Use the caller location since we aren't the one that got it out of order
location: std::panic::Location::caller().to_snafu_location(),
});
}
}
inner.insert(value);
last = Some(value);
}
Ok(Self { inner })
}
}

/// A mask over stable row ids based on an allow-list or block-list.
///
/// The semantics mirror [`RowAddrMask`], but operate on stable
/// row ids instead of physical row addresses.
#[derive(Clone, Debug, PartialEq)]
pub enum RowIdMask {
/// Only the ids in the set are selected.
AllowList(RowIdSet),
/// All ids are selected except those in the set.
BlockList(RowIdSet),
}

impl Default for RowIdMask {
fn default() -> Self {
// Empty block list means all rows are allowed
Self::BlockList(RowIdSet::default())
}
}
impl RowIdMask {
/// Create a mask allowing all rows, this is an alias for [`Default`].
pub fn all_rows() -> Self {
Self::default()
}
/// Create a mask that doesn't allow any row id.
pub fn allow_nothing() -> Self {
Self::AllowList(RowIdSet::default())
}
/// Create a mask from an allow list.
pub fn from_allowed(allow_list: RowIdSet) -> Self {
Self::AllowList(allow_list)
}
/// Create a mask from a block list.
pub fn from_block(block_list: RowIdSet) -> Self {
Self::BlockList(block_list)
}
/// True if the row id is selected by the mask, false otherwise.
pub fn selected(&self, row_id: u64) -> bool {
match self {
Self::AllowList(allow_list) => allow_list.contains(row_id),
Self::BlockList(block_list) => !block_list.contains(row_id),
}
}
/// Return the indices of the input row ids that are selected by the mask.
pub fn selected_indices<'a>(&self, row_ids: impl Iterator<Item = &'a u64> + 'a) -> Vec<u64> {
row_ids
.enumerate()
.filter_map(|(idx, row_id)| {
if self.selected(*row_id) {
Some(idx as u64)
} else {
None
}
})
.collect()
}
/// Also block the given ids.
///
/// * `AllowList(a)` -> `AllowList(a \\ block_list)`
/// * `BlockList(b)` -> `BlockList(b union block_list)`
pub fn also_block(self, block_list: RowIdSet) -> Self {
match self {
Self::AllowList(allow_list) => Self::AllowList(allow_list.difference(&block_list)),
Self::BlockList(existing) => Self::BlockList(existing.union(&block_list)),
}
}
/// Also allow the given ids.
///
/// * `AllowList(a)` -> `AllowList(a union allow_list)`
/// * `BlockList(b)` -> `BlockList(b \\ allow_list)`
pub fn also_allow(self, allow_list: RowIdSet) -> Self {
match self {
Self::AllowList(existing) => Self::AllowList(existing.union(&allow_list)),
Self::BlockList(block_list) => Self::BlockList(block_list.difference(&allow_list)),
}
}
/// Return the maximum number of row ids that could be selected by this mask.
///
/// Will be `None` if this is a `BlockList` (unbounded).
pub fn max_len(&self) -> Option<u64> {
match self {
Self::AllowList(selection) => selection.len(),
Self::BlockList(_) => None,
}
}
/// Iterate over the row ids that are selected by the mask.
///
/// This is only possible if this is an `AllowList`. For a `BlockList`
/// the domain of possible row ids is unbounded.
pub fn iter_ids(&self) -> Option<Box<dyn Iterator<Item = u64> + '_>> {
match self {
Self::AllowList(allow_list) => Some(Box::new(allow_list.iter())),
Self::BlockList(_) => None,
}
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down