Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions arrow-flight/src/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use std::{collections::VecDeque, fmt::Debug, pin::Pin, sync::Arc, task::Poll};
use crate::{error::Result, FlightData, FlightDescriptor, SchemaAsIpc};

use arrow_array::{Array, ArrayRef, RecordBatch, RecordBatchOptions, UnionArray};
use arrow_ipc::writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions};
use arrow_ipc::writer::{CompressionContext, DictionaryTracker, IpcDataGenerator, IpcWriteOptions};

use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaRef, UnionMode};
use bytes::Bytes;
Expand Down Expand Up @@ -647,6 +647,7 @@ struct FlightIpcEncoder {
options: IpcWriteOptions,
data_gen: IpcDataGenerator,
dictionary_tracker: DictionaryTracker,
compression_context: CompressionContext,
}

impl FlightIpcEncoder {
Expand All @@ -655,6 +656,7 @@ impl FlightIpcEncoder {
options,
data_gen: IpcDataGenerator::default(),
dictionary_tracker: DictionaryTracker::new(error_on_replacement),
compression_context: CompressionContext::default(),
}
}

Expand All @@ -666,9 +668,12 @@ impl FlightIpcEncoder {
/// Convert a `RecordBatch` to a Vec of `FlightData` representing
/// dictionaries and a `FlightData` representing the batch
fn encode_batch(&mut self, batch: &RecordBatch) -> Result<(Vec<FlightData>, FlightData)> {
let (encoded_dictionaries, encoded_batch) =
self.data_gen
.encoded_batch(batch, &mut self.dictionary_tracker, &self.options)?;
let (encoded_dictionaries, encoded_batch) = self.data_gen.encode(
batch,
&mut self.dictionary_tracker,
&self.options,
&mut self.compression_context,
)?;

let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect();
let flight_batch = encoded_batch.into();
Expand Down Expand Up @@ -1596,9 +1601,15 @@ mod tests {
) -> (Vec<FlightData>, FlightData) {
let data_gen = IpcDataGenerator::default();
let mut dictionary_tracker = DictionaryTracker::new(false);
let mut compression_context = CompressionContext::default();

let (encoded_dictionaries, encoded_batch) = data_gen
.encoded_batch(batch, &mut dictionary_tracker, options)
.encode(
batch,
&mut dictionary_tracker,
options,
&mut compression_context,
)
.expect("DictionaryTracker configured above to not error on replacement");

let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect();
Expand Down
10 changes: 8 additions & 2 deletions arrow-flight/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use std::sync::Arc;
use arrow_array::{ArrayRef, RecordBatch};
use arrow_buffer::Buffer;
use arrow_ipc::convert::fb_to_schema;
use arrow_ipc::writer::CompressionContext;
use arrow_ipc::{reader, root_as_message, writer, writer::IpcWriteOptions};
use arrow_schema::{ArrowError, Schema, SchemaRef};

Expand Down Expand Up @@ -91,10 +92,15 @@ pub fn batches_to_flight_data(

let data_gen = writer::IpcDataGenerator::default();
let mut dictionary_tracker = writer::DictionaryTracker::new(false);
let mut compression_context = CompressionContext::default();

for batch in batches.iter() {
let (encoded_dictionaries, encoded_batch) =
data_gen.encoded_batch(batch, &mut dictionary_tracker, &options)?;
let (encoded_dictionaries, encoded_batch) = data_gen.encode(
batch,
&mut dictionary_tracker,
&options,
&mut compression_context,
)?;

dictionaries.extend(encoded_dictionaries.into_iter().map(Into::into));
flight_data.push(encoded_batch.into());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ use arrow::{
array::ArrayRef,
buffer::Buffer,
datatypes::SchemaRef,
ipc::{self, reader, writer},
ipc::{
self, reader,
writer::{self, CompressionContext},
},
record_batch::RecordBatch,
};
use arrow_flight::{
Expand Down Expand Up @@ -90,6 +93,8 @@ async fn upload_data(

let mut original_data_iter = original_data.iter().enumerate();

let mut compression_context = CompressionContext::default();

if let Some((counter, first_batch)) = original_data_iter.next() {
let metadata = counter.to_string().into_bytes();
// Preload the first batch into the channel before starting the request
Expand All @@ -99,6 +104,7 @@ async fn upload_data(
first_batch,
&options,
&mut dict_tracker,
&mut compression_context,
)
.await?;

Expand All @@ -121,6 +127,7 @@ async fn upload_data(
batch,
&options,
&mut dict_tracker,
&mut compression_context,
)
.await?;

Expand Down Expand Up @@ -150,11 +157,12 @@ async fn send_batch(
batch: &RecordBatch,
options: &writer::IpcWriteOptions,
dictionary_tracker: &mut writer::DictionaryTracker,
compression_context: &mut CompressionContext,
) -> Result {
let data_gen = writer::IpcDataGenerator::default();

let (encoded_dictionaries, encoded_batch) = data_gen
.encoded_batch(batch, dictionary_tracker, options)
.encode(batch, dictionary_tracker, options, compression_context)
.expect("DictionaryTracker configured above to not error on replacement");

let dictionary_flight_data: Vec<FlightData> =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,12 @@ impl FlightService for FlightServiceImpl {
.enumerate()
.flat_map(|(counter, batch)| {
let (encoded_dictionaries, encoded_batch) = data_gen
.encoded_batch(batch, &mut dictionary_tracker, &options)
.encode(
batch,
&mut dictionary_tracker,
&options,
&mut Default::default(),
)
.expect("DictionaryTracker configured above to not error on replacement");

let dictionary_flight_data = encoded_dictionaries.into_iter().map(Into::into);
Expand Down
73 changes: 62 additions & 11 deletions arrow-ipc/src/compression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,41 @@ use arrow_schema::ArrowError;
const LENGTH_NO_COMPRESSED_DATA: i64 = -1;
const LENGTH_OF_PREFIX_DATA: i64 = 8;

/// Additional context that may be needed for compression.
///
/// In the case of zstd, this will contain the zstd context, which can be reused between subsequent
/// compression calls to avoid the performance overhead of initialising a new context for every
/// compression.
pub struct CompressionContext {
#[cfg(feature = "zstd")]
compressor: zstd::bulk::Compressor<'static>,
Copy link
Contributor

@Dandandan Dandandan Sep 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This always contains zstd::bulk::Compressor even when using lz4 compression?

Copy link
Contributor

@Dandandan Dandandan Sep 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we do the same for lz4_flex::frame::FrameEncoder, does it help?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This always contains zstd::bulk::Compressor even when using lz4 compression?

If using lz4 compression, I imagine the zstd feature wouldn't be enabled and this would just be an empty struct, right?

Can we do the same for lz4_flex::frame::FrameEncoder, does it help?

I imagine that it probably would help, although I didn't investigate as my use case was only focussed on zstd. My motivation behind adding this CompressionContext was that eventually it would be a good place to put something like this for lz4. Maybe we could do this in a followup issue/PR?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}

// We allow `derivable_impls` here because this impl becomes derivable when the zstd feature is
// not enabled; with the zstd feature, however, we want to be explicit about the compression level.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thank you for this context

#[allow(clippy::derivable_impls)]
impl Default for CompressionContext {
fn default() -> Self {
CompressionContext {
// safety: `new` here will only return error here if using an invalid compression level
#[cfg(feature = "zstd")]
compressor: zstd::bulk::Compressor::new(zstd::DEFAULT_COMPRESSION_LEVEL)
.expect("can use default compression level"),
}
}
}

impl std::fmt::Debug for CompressionContext {
    /// Writes the context as a `CompressionContext` debug struct.
    ///
    /// With the `zstd` feature enabled, the `compressor` field is rendered as a
    /// fixed placeholder string rather than the compressor's internal state.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("CompressionContext");

        // Only emitted when the zstd-backed field actually exists on the struct.
        #[cfg(feature = "zstd")]
        builder.field("compressor", &"zstd::bulk::Compressor");

        builder.finish()
    }
}

/// Represents compressing an IPC stream using a particular compression algorithm
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionCodec {
Expand Down Expand Up @@ -58,6 +93,7 @@ impl CompressionCodec {
&self,
input: &[u8],
output: &mut Vec<u8>,
context: &mut CompressionContext,
) -> Result<usize, ArrowError> {
let uncompressed_data_len = input.len();
let original_output_len = output.len();
Expand All @@ -67,7 +103,7 @@ impl CompressionCodec {
} else {
// write compressed data directly into the output buffer
output.extend_from_slice(&uncompressed_data_len.to_le_bytes());
self.compress(input, output)?;
self.compress(input, output, context)?;

let compression_len = output.len() - original_output_len;
if compression_len > uncompressed_data_len {
Expand Down Expand Up @@ -115,10 +151,15 @@ impl CompressionCodec {

/// Compress the data in input buffer and write to output buffer
/// using the specified compression
fn compress(&self, input: &[u8], output: &mut Vec<u8>) -> Result<(), ArrowError> {
fn compress(
&self,
input: &[u8],
output: &mut Vec<u8>,
context: &mut CompressionContext,
) -> Result<(), ArrowError> {
match self {
CompressionCodec::Lz4Frame => compress_lz4(input, output),
CompressionCodec::Zstd => compress_zstd(input, output),
CompressionCodec::Zstd => compress_zstd(input, output, context),
}
}

Expand Down Expand Up @@ -175,17 +216,23 @@ fn decompress_lz4(_input: &[u8], _decompressed_size: usize) -> Result<Vec<u8>, A
}

#[cfg(feature = "zstd")]
fn compress_zstd(input: &[u8], output: &mut Vec<u8>) -> Result<(), ArrowError> {
use std::io::Write;
let mut encoder = zstd::Encoder::new(output, 0)?;
encoder.write_all(input)?;
encoder.finish()?;
fn compress_zstd(
input: &[u8],
output: &mut Vec<u8>,
context: &mut CompressionContext,
) -> Result<(), ArrowError> {
let result = context.compressor.compress(input)?;
output.extend_from_slice(&result);
Ok(())
}

#[cfg(not(feature = "zstd"))]
#[allow(clippy::ptr_arg)]
fn compress_zstd(_input: &[u8], _output: &mut Vec<u8>) -> Result<(), ArrowError> {
fn compress_zstd(
_input: &[u8],
_output: &mut Vec<u8>,
_context: &mut CompressionContext,
) -> Result<(), ArrowError> {
Err(ArrowError::InvalidArgumentError(
"zstd IPC compression requires the zstd feature".to_string(),
))
Expand Down Expand Up @@ -227,7 +274,9 @@ mod tests {
let input_bytes = b"hello lz4";
let codec = super::CompressionCodec::Lz4Frame;
let mut output_bytes: Vec<u8> = Vec::new();
codec.compress(input_bytes, &mut output_bytes).unwrap();
codec
.compress(input_bytes, &mut output_bytes, &mut Default::default())
.unwrap();
let result = codec
.decompress(output_bytes.as_slice(), input_bytes.len())
.unwrap();
Expand All @@ -240,7 +289,9 @@ mod tests {
let input_bytes = b"hello zstd";
let codec = super::CompressionCodec::Zstd;
let mut output_bytes: Vec<u8> = Vec::new();
codec.compress(input_bytes, &mut output_bytes).unwrap();
codec
.compress(input_bytes, &mut output_bytes, &mut Default::default())
.unwrap();
let result = codec
.decompress(output_bytes.as_slice(), input_bytes.len())
.unwrap();
Expand Down
14 changes: 12 additions & 2 deletions arrow-ipc/src/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2702,7 +2702,12 @@ mod tests {
let gen = IpcDataGenerator {};
let mut dict_tracker = DictionaryTracker::new(false);
let (_, encoded) = gen
.encoded_batch(&batch, &mut dict_tracker, &Default::default())
.encode(
&batch,
&mut dict_tracker,
&Default::default(),
&mut Default::default(),
)
.unwrap();

let message = root_as_message(&encoded.ipc_message).unwrap();
Expand Down Expand Up @@ -2740,7 +2745,12 @@ mod tests {
let gen = IpcDataGenerator {};
let mut dict_tracker = DictionaryTracker::new(false);
let (_, encoded) = gen
.encoded_batch(&batch, &mut dict_tracker, &Default::default())
.encode(
&batch,
&mut dict_tracker,
&Default::default(),
&mut Default::default(),
)
.unwrap();

let message = root_as_message(&encoded.ipc_message).unwrap();
Expand Down
Loading
Loading