-
Notifications
You must be signed in to change notification settings - Fork 796
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add IPC FileDecoder #5249
Add IPC FileDecoder #5249
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -451,7 +451,7 @@ pub fn read_dictionary( | |
batch: crate::DictionaryBatch, | ||
schema: &Schema, | ||
dictionaries_by_id: &mut HashMap<i64, ArrayRef>, | ||
metadata: &crate::MetadataVersion, | ||
metadata: &MetadataVersion, | ||
) -> Result<(), ArrowError> { | ||
if batch.isDelta() { | ||
return Err(ArrowError::InvalidArgumentError( | ||
|
@@ -522,6 +522,174 @@ fn parse_message(buf: &[u8]) -> Result<Message, ArrowError> { | |
.map_err(|err| ArrowError::ParseError(format!("Unable to get root as message: {err:?}"))) | ||
} | ||
|
||
/// Read the footer length from the last 10 bytes of an Arrow IPC file | ||
/// | ||
/// Expects a 4 byte footer length followed by `b"ARROW1"` | ||
pub fn read_footer_length(buf: [u8; 10]) -> Result<usize, ArrowError> { | ||
if buf[4..] != super::ARROW_MAGIC { | ||
return Err(ArrowError::ParseError( | ||
"Arrow file does not contain correct footer".to_string(), | ||
)); | ||
} | ||
|
||
// read footer length | ||
let footer_len = i32::from_le_bytes(buf[..4].try_into().unwrap()); | ||
footer_len | ||
.try_into() | ||
.map_err(|_| ArrowError::ParseError(format!("Invalid footer length: {footer_len}"))) | ||
} | ||
|
||
/// A low-level, push-based interface for reading an IPC file | ||
/// | ||
/// For a higher-level interface see [`FileReader`] | ||
/// | ||
/// ``` | ||
/// # use std::sync::Arc; | ||
/// # use arrow_array::*; | ||
/// # use arrow_array::types::Int32Type; | ||
/// # use arrow_buffer::Buffer; | ||
/// # use arrow_ipc::convert::fb_to_schema; | ||
/// # use arrow_ipc::reader::{FileDecoder, read_footer_length}; | ||
/// # use arrow_ipc::root_as_footer; | ||
/// # use arrow_ipc::writer::FileWriter; | ||
/// // Write an IPC file | ||
/// | ||
/// let batch = RecordBatch::try_from_iter([ | ||
/// ("a", Arc::new(Int32Array::from(vec![1, 2, 3])) as _), | ||
/// ("b", Arc::new(Int32Array::from(vec![1, 2, 3])) as _), | ||
/// ("c", Arc::new(DictionaryArray::<Int32Type>::from_iter(["hello", "hello", "world"])) as _), | ||
/// ]).unwrap(); | ||
/// | ||
/// let schema = batch.schema(); | ||
/// | ||
/// let mut out = Vec::with_capacity(1024); | ||
/// let mut writer = FileWriter::try_new(&mut out, schema.as_ref()).unwrap(); | ||
/// writer.write(&batch).unwrap(); | ||
/// writer.finish().unwrap(); | ||
/// | ||
/// drop(writer); | ||
/// | ||
/// // Read IPC file | ||
/// | ||
/// let buffer = Buffer::from_vec(out); | ||
/// let trailer_start = buffer.len() - 10; | ||
/// let footer_len = read_footer_length(buffer[trailer_start..].try_into().unwrap()).unwrap(); | ||
/// let footer = root_as_footer(&buffer[trailer_start - footer_len..trailer_start]).unwrap(); | ||
/// | ||
/// let back = fb_to_schema(footer.schema().unwrap()); | ||
/// assert_eq!(&back, schema.as_ref()); | ||
/// | ||
/// let mut decoder = FileDecoder::new(schema, footer.version()); | ||
/// | ||
/// // Read dictionaries | ||
/// for block in footer.dictionaries().iter().flatten() { | ||
/// let block_len = block.bodyLength() as usize + block.metaDataLength() as usize; | ||
/// let data = buffer.slice_with_length(block.offset() as _, block_len); | ||
/// decoder.read_dictionary(&block, &data).unwrap(); | ||
/// } | ||
/// | ||
/// // Read record batch | ||
/// let batches = footer.recordBatches().unwrap(); | ||
/// assert_eq!(batches.len(), 1); // Only wrote a single batch | ||
/// | ||
/// let block = batches.get(0); | ||
/// let block_len = block.bodyLength() as usize + block.metaDataLength() as usize; | ||
/// let data = buffer.slice_with_length(block.offset() as _, block_len); | ||
/// let back = decoder.read_record_batch(block, &data).unwrap().unwrap(); | ||
/// | ||
/// assert_eq!(batch, back); | ||
/// ``` | ||
#[derive(Debug)] | ||
pub struct FileDecoder { | ||
schema: SchemaRef, | ||
dictionaries: HashMap<i64, ArrayRef>, | ||
version: MetadataVersion, | ||
projection: Option<Vec<usize>>, | ||
} | ||
|
||
impl FileDecoder { | ||
/// Create a new [`FileDecoder`] with the given schema and version | ||
pub fn new(schema: SchemaRef, version: MetadataVersion) -> Self { | ||
Self { | ||
schema, | ||
version, | ||
dictionaries: Default::default(), | ||
projection: None, | ||
} | ||
} | ||
|
||
/// Specify a projection | ||
pub fn with_projection(mut self, projection: Vec<usize>) -> Self { | ||
self.projection = Some(projection); | ||
self | ||
} | ||
|
||
fn read_message<'a>(&self, buf: &'a [u8]) -> Result<Message<'a>, ArrowError> { | ||
let message = parse_message(buf)?; | ||
|
||
// some old test data's footer metadata is not set, so we account for that | ||
if self.version != MetadataVersion::V1 && message.version() != self.version { | ||
return Err(ArrowError::IpcError( | ||
"Could not read IPC message as metadata versions mismatch".to_string(), | ||
)); | ||
} | ||
Ok(message) | ||
} | ||
|
||
/// Read the dictionary with the given block and data buffer | ||
pub fn read_dictionary(&mut self, block: &Block, buf: &Buffer) -> Result<(), ArrowError> { | ||
let message = self.read_message(buf)?; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not bit deal but seems we didn't check for metadata version for |
||
match message.header_type() { | ||
crate::MessageHeader::DictionaryBatch => { | ||
let batch = message.header_as_dictionary_batch().unwrap(); | ||
read_dictionary( | ||
&buf.slice(block.metaDataLength() as _), | ||
batch, | ||
&self.schema, | ||
&mut self.dictionaries, | ||
&message.version(), | ||
) | ||
} | ||
t => Err(ArrowError::ParseError(format!( | ||
"Expecting DictionaryBatch in dictionary blocks, found {t:?}." | ||
))), | ||
} | ||
} | ||
|
||
/// Read the RecordBatch with the given block and data buffer | ||
pub fn read_record_batch( | ||
&self, | ||
block: &Block, | ||
buf: &Buffer, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is perhaps worth noting that this interface won't allow pushing down column projection to IO, I think this is a bridge we cross when we add support for this. |
||
) -> Result<Option<RecordBatch>, ArrowError> { | ||
let message = self.read_message(buf)?; | ||
match message.header_type() { | ||
crate::MessageHeader::Schema => Err(ArrowError::IpcError( | ||
"Not expecting a schema when messages are read".to_string(), | ||
)), | ||
crate::MessageHeader::RecordBatch => { | ||
let batch = message.header_as_record_batch().ok_or_else(|| { | ||
ArrowError::IpcError("Unable to read IPC message as record batch".to_string()) | ||
})?; | ||
// read the block that makes up the record batch into a buffer | ||
read_record_batch( | ||
&buf.slice(block.metaDataLength() as _), | ||
batch, | ||
self.schema.clone(), | ||
&self.dictionaries, | ||
self.projection.as_deref(), | ||
&message.version(), | ||
) | ||
.map(Some) | ||
} | ||
crate::MessageHeader::NONE => Ok(None), | ||
t => Err(ArrowError::InvalidArgumentError(format!( | ||
"Reading types other than record batches not yet supported, unable to read {t:?}" | ||
))), | ||
} | ||
} | ||
} | ||
|
||
/// Build an Arrow [`FileReader`] with custom options. | ||
#[derive(Debug)] | ||
pub struct FileReaderBuilder { | ||
|
@@ -599,17 +767,10 @@ impl FileReaderBuilder { | |
reader.seek(SeekFrom::End(-10))?; | ||
reader.read_exact(&mut buffer)?; | ||
|
||
if buffer[4..] != super::ARROW_MAGIC { | ||
return Err(ArrowError::ParseError( | ||
"Arrow file does not contain correct footer".to_string(), | ||
)); | ||
} | ||
|
||
// read footer length | ||
let footer_len = i32::from_le_bytes(buffer[..4].try_into().unwrap()); | ||
let footer_len = read_footer_length(buffer)?; | ||
|
||
// read footer | ||
let mut footer_data = vec![0; footer_len as usize]; | ||
let mut footer_data = vec![0; footer_len]; | ||
reader.seek(SeekFrom::End(-10 - footer_len as i64))?; | ||
reader.read_exact(&mut footer_data)?; | ||
|
||
|
@@ -641,50 +802,26 @@ impl FileReaderBuilder { | |
} | ||
} | ||
|
||
let mut decoder = FileDecoder::new(Arc::new(schema), footer.version()); | ||
if let Some(projection) = self.projection { | ||
decoder = decoder.with_projection(projection) | ||
} | ||
|
||
// Create an array of optional dictionary value arrays, one per field. | ||
let mut dictionaries_by_id = HashMap::new(); | ||
if let Some(dictionaries) = footer.dictionaries() { | ||
for block in dictionaries { | ||
let buf = read_block(&mut reader, block)?; | ||
let message = parse_message(&buf)?; | ||
|
||
match message.header_type() { | ||
crate::MessageHeader::DictionaryBatch => { | ||
let batch = message.header_as_dictionary_batch().unwrap(); | ||
read_dictionary( | ||
&buf.slice(block.metaDataLength() as _), | ||
batch, | ||
&schema, | ||
&mut dictionaries_by_id, | ||
&message.version(), | ||
)?; | ||
} | ||
t => { | ||
return Err(ArrowError::ParseError(format!( | ||
"Expecting DictionaryBatch in dictionary blocks, found {t:?}." | ||
))); | ||
} | ||
} | ||
decoder.read_dictionary(block, &buf)?; | ||
} | ||
} | ||
let projection = match self.projection { | ||
Some(projection_indices) => { | ||
let schema = schema.project(&projection_indices)?; | ||
Some((projection_indices, schema)) | ||
} | ||
_ => None, | ||
}; | ||
|
||
Ok(FileReader { | ||
reader, | ||
schema: Arc::new(schema), | ||
blocks: blocks.iter().copied().collect(), | ||
current_block: 0, | ||
total_blocks, | ||
dictionaries_by_id, | ||
metadata_version: footer.version(), | ||
decoder, | ||
custom_metadata, | ||
projection, | ||
}) | ||
} | ||
} | ||
|
@@ -694,45 +831,31 @@ pub struct FileReader<R: Read + Seek> { | |
/// Buffered file reader that supports reading and seeking | ||
reader: R, | ||
|
||
/// The schema that is read from the file header | ||
schema: SchemaRef, | ||
/// The decoder | ||
decoder: FileDecoder, | ||
|
||
/// The blocks in the file | ||
/// | ||
/// A block indicates the regions in the file to read to get data | ||
blocks: Vec<crate::Block>, | ||
blocks: Vec<Block>, | ||
|
||
/// A counter to keep track of the current block that should be read | ||
current_block: usize, | ||
|
||
/// The total number of blocks, which may contain record batches and other types | ||
total_blocks: usize, | ||
|
||
/// Optional dictionaries for each schema field. | ||
/// | ||
/// Dictionaries may be appended to in the streaming format. | ||
dictionaries_by_id: HashMap<i64, ArrayRef>, | ||
|
||
/// Metadata version | ||
metadata_version: crate::MetadataVersion, | ||
|
||
/// User defined metadata | ||
custom_metadata: HashMap<String, String>, | ||
|
||
/// Optional projection and projected_schema | ||
projection: Option<(Vec<usize>, Schema)>, | ||
} | ||
|
||
impl<R: Read + Seek> fmt::Debug for FileReader<R> { | ||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::result::Result<(), fmt::Error> { | ||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { | ||
f.debug_struct("FileReader<R>") | ||
.field("schema", &self.schema) | ||
.field("decoder", &self.decoder) | ||
.field("blocks", &self.blocks) | ||
.field("current_block", &self.current_block) | ||
.field("total_blocks", &self.total_blocks) | ||
.field("dictionaries_by_id", &self.dictionaries_by_id) | ||
.field("metadata_version", &self.metadata_version) | ||
.field("projection", &self.projection) | ||
.finish_non_exhaustive() | ||
} | ||
} | ||
|
@@ -761,7 +884,7 @@ impl<R: Read + Seek> FileReader<R> { | |
|
||
/// Return the schema of the file | ||
pub fn schema(&self) -> SchemaRef { | ||
self.schema.clone() | ||
self.decoder.schema.clone() | ||
} | ||
|
||
/// Read a specific record batch | ||
|
@@ -785,41 +908,7 @@ impl<R: Read + Seek> FileReader<R> { | |
|
||
// read length | ||
let buffer = read_block(&mut self.reader, block)?; | ||
let message = parse_message(&buffer)?; | ||
|
||
// some old test data's footer metadata is not set, so we account for that | ||
if self.metadata_version != MetadataVersion::V1 | ||
&& message.version() != self.metadata_version | ||
{ | ||
return Err(ArrowError::IpcError( | ||
"Could not read IPC message as metadata versions mismatch".to_string(), | ||
)); | ||
} | ||
|
||
match message.header_type() { | ||
crate::MessageHeader::Schema => Err(ArrowError::IpcError( | ||
"Not expecting a schema when messages are read".to_string(), | ||
)), | ||
crate::MessageHeader::RecordBatch => { | ||
let batch = message.header_as_record_batch().ok_or_else(|| { | ||
ArrowError::IpcError("Unable to read IPC message as record batch".to_string()) | ||
})?; | ||
// read the block that makes up the record batch into a buffer | ||
read_record_batch( | ||
&buffer.slice(block.metaDataLength() as _), | ||
batch, | ||
self.schema(), | ||
&self.dictionaries_by_id, | ||
self.projection.as_ref().map(|x| x.0.as_ref()), | ||
&message.version(), | ||
) | ||
.map(Some) | ||
} | ||
crate::MessageHeader::NONE => Ok(None), | ||
t => Err(ArrowError::InvalidArgumentError(format!( | ||
"Reading types other than record batches not yet supported, unable to read {t:?}" | ||
))), | ||
} | ||
self.decoder.read_record_batch(block, &buffer) | ||
} | ||
|
||
/// Gets a reference to the underlying reader. | ||
|
@@ -852,7 +941,7 @@ impl<R: Read + Seek> Iterator for FileReader<R> { | |
|
||
impl<R: Read + Seek> RecordBatchReader for FileReader<R> { | ||
fn schema(&self) -> SchemaRef { | ||
self.schema.clone() | ||
self.schema() | ||
} | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is still somewhat verbose, but I couldn't see an easy way to reduce this further that didn't end up in knots with self-referential structs (as flatbuffers borrow data)