Skip to content

Commit

Permalink
fix: Assert chunks are equal after physical cast to prevent OOB (#14873)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Mar 6, 2024
1 parent 041cdf0 commit 33b2286
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 6 deletions.
21 changes: 21 additions & 0 deletions crates/polars-core/src/datatypes/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,27 @@ impl DataType {
matches!(self, DataType::Boolean)
}

/// Check if this [`DataType`] is a list, regardless of the list's inner type
pub fn is_list(&self) -> bool {
matches!(self, DataType::List(_))
}

/// Check if this [`DataType`] is nested, i.e. a list or a struct
pub fn is_nested(&self) -> bool {
    if self.is_list() {
        return true;
    }
    self.is_struct()
}

/// Check if this [`DataType`] is a struct
///
/// Always returns `false` when the `dtype-struct` feature is disabled.
pub fn is_struct(&self) -> bool {
    // The `Struct` variant is only referenced when the feature is compiled in.
    #[cfg(feature = "dtype-struct")]
    {
        if matches!(self, DataType::Struct(_)) {
            return true;
        }
    }
    false
}

/// Check if this [`DataType`] is binary
pub fn is_binary(&self) -> bool {
matches!(self, DataType::Binary)
}
Expand Down
23 changes: 21 additions & 2 deletions crates/polars-ops/src/chunked_array/gather/chunked.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use std::borrow::Cow;
use std::fmt::Debug;

use polars_core::prelude::gather::_update_gather_sorted_flag;
use polars_core::prelude::*;
use polars_core::series::IsSorted;
Expand Down Expand Up @@ -67,9 +70,24 @@ pub trait TakeChunked {
unsafe fn take_opt_chunked_unchecked(&self, by: &[ChunkId]) -> Self;
}

/// Produce the representation of `s` that the chunked gather operates on.
///
/// Nested dtypes are borrowed unchanged; every other dtype is converted with
/// `to_physical_repr`.
///
/// # Panics
///
/// Panics with "implementation error" when the result does not have the same
/// number of chunks as `s`.
fn prepare_series(s: &Series) -> Cow<Series> {
    let n_chunks = s.chunks().len();
    let phys = match s.dtype() {
        dt if dt.is_nested() => Cow::Borrowed(s),
        _ => s.to_physical_repr(),
    };
    // A mismatch means the cast rechunked the data, in which case the gather
    // would go out of bounds.
    assert_eq!(phys.chunks().len(), n_chunks, "implementation error");
    phys
}

impl TakeChunked for Series {
unsafe fn take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Self {
let phys = self.to_physical_repr();
let phys = prepare_series(self);
use DataType::*;
let out = match phys.dtype() {
dt if dt.is_numeric() => {
Expand Down Expand Up @@ -122,7 +140,7 @@ impl TakeChunked for Series {

/// Take function that checks the null state in `ChunkIdx`.
unsafe fn take_opt_chunked_unchecked(&self, by: &[NullableChunkId]) -> Self {
let phys = self.to_physical_repr();
let phys = prepare_series(self);
use DataType::*;
let out = match phys.dtype() {
dt if dt.is_numeric() => {
Expand Down Expand Up @@ -177,6 +195,7 @@ impl TakeChunked for Series {
impl<T> TakeChunked for ChunkedArray<T>
where
T: PolarsDataType,
T::Array: Debug,
{
unsafe fn take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Self {
let arrow_dtype = self.dtype().to_arrow(true);
Expand Down
11 changes: 7 additions & 4 deletions crates/polars-ops/src/frame/join/hash_join/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,11 @@ pub trait JoinDispatch: IntoDf {
args: JoinArgs,
verbose: bool,
) -> PolarsResult<DataFrame> {
let ca_self = self.to_df();
let df_self = self.to_df();
#[cfg(feature = "dtype-categorical")]
_check_categorical_src(s_left.dtype(), s_right.dtype())?;

let mut left = ca_self.clone();
let mut left = df_self.clone();
let mut s_left = s_left.clone();
// Eagerly limit left if possible.
if let Some((offset, len)) = args.slice {
Expand All @@ -188,16 +188,19 @@ pub trait JoinDispatch: IntoDf {
}

// Ensure that the chunks are aligned otherwise we go OOB.
let mut right = other.clone();
let mut right = Cow::Borrowed(other);
let mut s_right = s_right.clone();
if left.should_rechunk() {
left.as_single_chunk_par();
s_left = s_left.rechunk();
}
if right.should_rechunk() {
right.as_single_chunk_par();
let mut other = other.clone();
other.as_single_chunk_par();
right = Cow::Owned(other);
s_right = s_right.rechunk();
}

let ids = sort_or_hash_left(&s_left, &s_right, verbose, args.validation, args.join_nulls)?;
left._finish_left_join(ids, &right.drop(s_right.name()).unwrap(), args)
}
Expand Down
2 changes: 2 additions & 0 deletions crates/polars/tests/it/chunks/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#[cfg(feature = "parquet")]
mod parquet;
38 changes: 38 additions & 0 deletions crates/polars/tests/it/chunks/parquet.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
use std::io::{Seek, SeekFrom};

use polars::prelude::*;

// Left join against a frame that was round-tripped through parquet with a
// row group size of 1 (presumably so it is read back as several chunks —
// TODO confirm).
#[test]
fn test_cast_join_14872() {
    let left = df!["ints" => [1]].unwrap();

    let mut right = df![
        "ints" => [0, 1],
        "strings" => vec![Series::new("", ["a"]); 2],
    ]
    .unwrap();

    // Write the right-hand frame to an in-memory parquet buffer.
    let mut cursor = std::io::Cursor::new(vec![]);
    ParquetWriter::new(&mut cursor)
        .with_row_group_size(Some(1))
        .finish(&mut right)
        .unwrap();

    // Rewind so the reader starts at the beginning of the buffer.
    let _ = cursor.seek(SeekFrom::Start(0));
    let right = ParquetReader::new(cursor).finish().unwrap();

    let join_args = JoinArgs::new(JoinType::Left);
    let joined = left.join(&right, ["ints"], ["ints"], join_args).unwrap();

    // Only the row with `ints == 1` is expected in the result.
    let expected = df![
        "ints" => [1],
        "strings" => vec![Series::new("", ["a"]); 1],
    ]
    .unwrap();

    assert!(expected.equals(&joined));
}
1 change: 1 addition & 0 deletions crates/polars/tests/it/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ mod schema;
mod time;

mod arrow;
mod chunks;

/// Relative path to the `foods1.csv` example dataset.
pub static FOODS_CSV: &str = "../../examples/datasets/foods1.csv";

0 comments on commit 33b2286

Please sign in to comment.