Skip to content

ARROW-4749: [Rust] Return Result for RecordBatch::new() #3800

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rust/arrow/benches/csv_writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ fn record_batches_to_csv() {
let c3 = PrimitiveArray::<UInt32Type>::from(vec![3, 2, 1]);
let c4 = PrimitiveArray::<BooleanType>::from(vec![Some(true), Some(false), None]);

let b = RecordBatch::new(
let b = RecordBatch::try_new(
Arc::new(schema),
vec![Arc::new(c1), Arc::new(c2), Arc::new(c3), Arc::new(c4)],
);
Expand Down
10 changes: 6 additions & 4 deletions rust/arrow/examples/dynamic_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ extern crate arrow;

use arrow::array::*;
use arrow::datatypes::*;
use arrow::error::Result;
use arrow::record_batch::*;

fn main() {
fn main() -> Result<()> {
// define schema
let schema = Schema::new(vec![
Field::new("id", DataType::Int32, false),
Expand Down Expand Up @@ -58,9 +59,10 @@ fn main() {
]);

// build a record batch
let batch = RecordBatch::new(Arc::new(schema), vec![Arc::new(id), Arc::new(nested)]);
let batch =
RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id), Arc::new(nested)])?;

process(&batch);
Ok(process(&batch))
}

/// Create a new batch by performing a projection of id, nested.c
Expand Down Expand Up @@ -88,7 +90,7 @@ fn process(batch: &RecordBatch) {
Field::new("sum", DataType::Float64, false),
]);

let _ = RecordBatch::new(
let _ = RecordBatch::try_new(
Arc::new(projected_schema),
vec![
id.clone(), // NOTE: this is cloning the Arc not the array data
Expand Down
5 changes: 4 additions & 1 deletion rust/arrow/src/csv/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,10 @@ impl<R: Read> Reader<R> {
let projected_schema = Arc::new(Schema::new(projected_fields));

match arrays {
Ok(arr) => Ok(Some(RecordBatch::new(projected_schema, arr))),
Ok(arr) => match RecordBatch::try_new(projected_schema, arr) {
Ok(batch) => Ok(Some(batch)),
Err(e) => Err(e),
},
Err(e) => Err(e),
}
}
Expand Down
14 changes: 8 additions & 6 deletions rust/arrow/src/csv/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@
//! let c3 = PrimitiveArray::<UInt32Type>::from(vec![3, 2, 1]);
//! let c4 = PrimitiveArray::<BooleanType>::from(vec![Some(true), Some(false), None]);
//!
//! let batch = RecordBatch::new(
//! let batch = RecordBatch::try_new(
//! Arc::new(schema),
//! vec![Arc::new(c1), Arc::new(c2), Arc::new(c3), Arc::new(c4)],
//! );
//! ).unwrap();
//!
//! let file = get_temp_file("out.csv", &[]);
//!
Expand Down Expand Up @@ -287,10 +287,11 @@ mod tests {
let c3 = PrimitiveArray::<UInt32Type>::from(vec![3, 2, 1]);
let c4 = PrimitiveArray::<BooleanType>::from(vec![Some(true), Some(false), None]);

let batch = RecordBatch::new(
let batch = RecordBatch::try_new(
Arc::new(schema),
vec![Arc::new(c1), Arc::new(c2), Arc::new(c3), Arc::new(c4)],
);
)
.unwrap();

let file = get_temp_file("columns.csv", &[]);

Expand Down Expand Up @@ -331,10 +332,11 @@ mod tests {
let c3 = PrimitiveArray::<UInt32Type>::from(vec![3, 2, 1]);
let c4 = PrimitiveArray::<BooleanType>::from(vec![Some(true), Some(false), None]);

let batch = RecordBatch::new(
let batch = RecordBatch::try_new(
Arc::new(schema),
vec![Arc::new(c1), Arc::new(c2), Arc::new(c3), Arc::new(c4)],
);
)
.unwrap();

let file = get_temp_file("custom_options.csv", &[]);

Expand Down
1 change: 1 addition & 0 deletions rust/arrow/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub enum ArrowError {
CsvError(String),
JsonError(String),
IoError(String),
InvalidArgumentError(String),
}

impl From<::std::io::Error> for ArrowError {
Expand Down
5 changes: 4 additions & 1 deletion rust/arrow/src/json/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,10 @@ impl<R: Read> Reader<R> {
let projected_schema = Arc::new(Schema::new(projected_fields));

match arrays {
Ok(arr) => Ok(Some(RecordBatch::new(projected_schema, arr))),
Ok(arr) => match RecordBatch::try_new(projected_schema, arr) {
Ok(batch) => Ok(Some(batch)),
Err(e) => Err(e),
},
Err(e) => Err(e),
}
}
Expand Down
79 changes: 64 additions & 15 deletions rust/arrow/src/record_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use std::sync::Arc;

use crate::array::*;
use crate::datatypes::*;
use crate::error::{ArrowError, Result};

/// A batch of column-oriented data
#[derive(Clone)]
Expand All @@ -34,36 +35,61 @@ pub struct RecordBatch {
}

impl RecordBatch {
pub fn new(schema: Arc<Schema>, columns: Vec<ArrayRef>) -> Self {
// assert that there are some columns
assert!(
columns.len() > 0,
"at least one column must be defined to create a record batch"
);
// assert that all columns have the same row count
/// Creates a `RecordBatch` from a schema and columns
///
/// Expects the following:
/// * the vec of columns to not be empty
/// * the schema and column data types to have equal lengths and match
/// * each array in columns to have the same length
pub fn try_new(schema: Arc<Schema>, columns: Vec<ArrayRef>) -> Result<Self> {
// check that there are some columns
if columns.is_empty() {
return Err(ArrowError::InvalidArgumentError(
"at least one column must be defined to create a record batch"
.to_string(),
));
}
// check that number of fields in schema match column length
if schema.fields().len() != columns.len() {
return Err(ArrowError::InvalidArgumentError(
"number of columns must match number of fields in schema".to_string(),
));
}
// check that all columns have the same row count, and match the schema
let len = columns[0].data().len();
for i in 1..columns.len() {
assert_eq!(
len,
columns[i].len(),
"all columns in a record batch must have the same length"
);
for i in 0..columns.len() {
if columns[i].len() != len {
return Err(ArrowError::InvalidArgumentError(
"all columns in a record batch must have the same length".to_string(),
));
}
if columns[i].data_type() != schema.field(i).data_type() {
return Err(ArrowError::InvalidArgumentError(format!(
"column types must match schema types, expected {:?} but found {:?} at column index {}",
schema.field(i).data_type(),
columns[i].data_type(),
i)));
}
}
RecordBatch { schema, columns }
Ok(RecordBatch { schema, columns })
}

/// Returns the schema of the record batch
///
/// The schema is shared via `Arc`, so callers can cheaply clone the
/// returned reference if they need ownership.
pub fn schema(&self) -> &Arc<Schema> {
    &self.schema
}

/// Number of columns in the record batch
///
/// Always equals `schema().fields().len()`, since `try_new` rejects any
/// mismatch between the column list and the schema.
pub fn num_columns(&self) -> usize {
    self.columns.len()
}

/// Number of rows in each column
///
/// Reads the length of the first column; this index is safe because
/// `try_new` rejects an empty column list, so every constructed batch
/// has at least one column, and all columns are checked to share the
/// same length.
pub fn num_rows(&self) -> usize {
    self.columns[0].data().len()
}

/// Get a reference to a column's array by index
///
/// # Panics
///
/// Panics if `i` is out of bounds (plain slice indexing; no bounds
/// check beyond the slice's own).
pub fn column(&self, i: usize) -> &ArrayRef {
    &self.columns[i]
}
Expand Down Expand Up @@ -103,7 +129,8 @@ mod tests {
let b = BinaryArray::from(array_data);

let record_batch =
RecordBatch::new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]);
RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])
.unwrap();

assert_eq!(5, record_batch.num_rows());
assert_eq!(2, record_batch.num_columns());
Expand All @@ -112,4 +139,26 @@ mod tests {
assert_eq!(5, record_batch.column(0).data().len());
assert_eq!(5, record_batch.column(1).data().len());
}

/// `try_new` must reject a batch whose column data type does not match
/// the corresponding schema field (Int64 data vs. an Int32 field).
#[test]
fn create_record_batch_schema_mismatch() {
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

    let a = Int64Array::from(vec![1, 2, 3, 4, 5]);

    let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]);
    // `is_err()` states the intent directly, avoiding the `!is_ok()`
    // double negative (flagged by clippy's `nonminimal_bool`).
    assert!(batch.is_err());
}

/// `try_new` must reject a batch whose column count (two arrays) does
/// not match the number of fields in the schema (one field).
#[test]
fn create_record_batch_record_mismatch() {
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

    let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
    let b = Int32Array::from(vec![1, 2, 3, 4, 5]);

    let batch =
        RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]);
    // `is_err()` states the intent directly, avoiding the `!is_ok()`
    // double negative (flagged by clippy's `nonminimal_bool`).
    assert!(batch.is_err());
}
}
53 changes: 31 additions & 22 deletions rust/datafusion/src/datasource/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,20 +102,25 @@ impl Table for MemTable {

let projected_schema = Arc::new(Schema::new(projected_columns?));

Ok(Rc::new(RefCell::new(MemBatchIterator {
schema: projected_schema.clone(),
index: 0,
batches: self
.batches
.iter()
.map(|batch| {
RecordBatch::new(
projected_schema.clone(),
columns.iter().map(|i| batch.column(*i).clone()).collect(),
)
})
.collect(),
})))
let batches = self
.batches
.iter()
.map(|batch| {
RecordBatch::try_new(
projected_schema.clone(),
columns.iter().map(|i| batch.column(*i).clone()).collect(),
)
})
.collect();

match batches {
Ok(batches) => Ok(Rc::new(RefCell::new(MemBatchIterator {
schema: projected_schema.clone(),
index: 0,
batches,
}))),
Err(e) => Err(ExecutionError::ArrowError(e)),
}
}
}

Expand Down Expand Up @@ -155,14 +160,15 @@ mod tests {
Field::new("c", DataType::Int32, false),
]));

let batch = RecordBatch::new(
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from(vec![1, 2, 3])),
Arc::new(Int32Array::from(vec![4, 5, 6])),
Arc::new(Int32Array::from(vec![7, 8, 9])),
],
);
)
.unwrap();

let provider = MemTable::new(schema, vec![batch]).unwrap();

Expand All @@ -183,14 +189,15 @@ mod tests {
Field::new("c", DataType::Int32, false),
]));

let batch = RecordBatch::new(
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from(vec![1, 2, 3])),
Arc::new(Int32Array::from(vec![4, 5, 6])),
Arc::new(Int32Array::from(vec![7, 8, 9])),
],
);
)
.unwrap();

let provider = MemTable::new(schema, vec![batch]).unwrap();

Expand All @@ -208,14 +215,15 @@ mod tests {
Field::new("c", DataType::Int32, false),
]));

let batch = RecordBatch::new(
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from(vec![1, 2, 3])),
Arc::new(Int32Array::from(vec![4, 5, 6])),
Arc::new(Int32Array::from(vec![7, 8, 9])),
],
);
)
.unwrap();

let provider = MemTable::new(schema, vec![batch]).unwrap();

Expand Down Expand Up @@ -243,14 +251,15 @@ mod tests {
Field::new("c", DataType::Int32, false),
]));

let batch = RecordBatch::new(
let batch = RecordBatch::try_new(
schema1.clone(),
vec![
Arc::new(Int32Array::from(vec![1, 2, 3])),
Arc::new(Int32Array::from(vec![4, 5, 6])),
Arc::new(Int32Array::from(vec![7, 8, 9])),
],
);
)
.unwrap();

match MemTable::new(schema2, vec![batch]) {
Err(ExecutionError::General(e)) => assert_eq!(
Expand Down
12 changes: 9 additions & 3 deletions rust/datafusion/src/execution/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -800,7 +800,10 @@ impl AggregateRelation {
}
}

Ok(Some(RecordBatch::new(self.schema.clone(), result_columns)))
Ok(Some(RecordBatch::try_new(
self.schema.clone(),
result_columns,
)?))
}

fn with_group_by(&mut self) -> Result<Option<RecordBatch>> {
Expand Down Expand Up @@ -1008,7 +1011,10 @@ impl AggregateRelation {
result_arrays.push(array?);
}

Ok(Some(RecordBatch::new(self.schema.clone(), result_arrays)))
Ok(Some(RecordBatch::try_new(
self.schema.clone(),
result_arrays,
)?))
}
}

Expand Down Expand Up @@ -1136,7 +1142,7 @@ mod tests {
.unwrap();

let aggr_schema = Arc::new(Schema::new(vec![
Field::new("c2", DataType::Int32, false),
Field::new("c2", DataType::UInt32, false),
Field::new("min", DataType::Float64, false),
Field::new("max", DataType::Float64, false),
Field::new("sum", DataType::Float64, false),
Expand Down
9 changes: 8 additions & 1 deletion rust/datafusion/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,15 @@ impl ExecutionContext {
.collect();
let compiled_aggr_expr = compiled_aggr_expr_result?;

let mut output_fields: Vec<Field> = vec![];
for expr in group_expr {
output_fields.push(expr_to_field(expr, input_schema.as_ref()));
}
for expr in aggr_expr {
output_fields.push(expr_to_field(expr, input_schema.as_ref()));
}
let rel = AggregateRelation::new(
Arc::new(Schema::empty()), //(expr_to_field(&compiled_group_expr, &input_schema))),
Arc::new(Schema::new(output_fields)),
input_rel,
compiled_group_expr,
compiled_aggr_expr,
Expand Down
Loading