
Commit 8ce9bb5

Implement native decoding and decompression
Parent: 053b7cc

14 files changed: 291 additions & 239 deletions

native/core/src/execution/jni_api.rs

Lines changed: 33 additions & 18 deletions
@@ -17,6 +17,7 @@
 
 //! Define JNI APIs which can be called from Java/Scala.
 
+use super::{serde, utils::SparkArrowConvert, CometMemoryPool};
 use arrow::datatypes::DataType as ArrowDataType;
 use arrow_array::RecordBatch;
 use datafusion::{
@@ -40,8 +41,6 @@ use jni::{
 use std::time::{Duration, Instant};
 use std::{collections::HashMap, sync::Arc, task::Poll};
 
-use super::{serde, utils::SparkArrowConvert, CometMemoryPool};
-
 use crate::{
     errors::{try_unwrap_or_throw, CometError, CometResult},
     execution::{
@@ -53,13 +52,15 @@ use crate::{
 use datafusion_comet_proto::spark_operator::Operator;
 use datafusion_common::ScalarValue;
 use futures::stream::StreamExt;
+use jni::objects::JByteBuffer;
 use jni::{
     objects::GlobalRef,
     sys::{jboolean, jdouble, jintArray, jobjectArray, jstring},
 };
 use tokio::runtime::Runtime;
 
 use crate::execution::operators::ScanExec;
+use crate::execution::shuffle::read_ipc_compressed;
 use crate::execution::spark_plan::SparkPlan;
 use log::info;
 
@@ -95,7 +96,7 @@ struct ExecutionContext {
 
 /// Accept serialized query plan and return the address of the native query plan.
 /// # Safety
-/// This function is inheritly unsafe since it deals with raw pointers passed from JNI.
+/// This function is inherently unsafe since it deals with raw pointers passed from JNI.
 #[no_mangle]
 pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
     e: JNIEnv,
@@ -231,7 +232,7 @@ fn prepare_output(
     array_addrs: jlongArray,
     schema_addrs: jlongArray,
     output_batch: RecordBatch,
-    exec_context: &mut ExecutionContext,
+    validate: bool,
 ) -> CometResult<jlong> {
     let array_address_array = unsafe { JLongArray::from_raw(array_addrs) };
     let num_cols = env.get_array_length(&array_address_array)? as usize;
@@ -255,7 +256,7 @@
         )));
     }
 
-    if exec_context.debug_native {
+    if validate {
         // Validate the output arrays.
         for array in results.iter() {
             let array_data = array.to_data();
@@ -275,9 +276,6 @@
         i += 1;
     }
 
-    // Update metrics
-    update_metrics(env, exec_context)?;
-
     Ok(num_rows as jlong)
 }
 
@@ -298,7 +296,7 @@ fn pull_input_batches(exec_context: &mut ExecutionContext) -> Result<(), CometEr
 /// Accept serialized query plan and the addresses of Arrow Arrays from Spark,
 /// then execute the query. Return addresses of arrow vector.
 /// # Safety
-/// This function is inheritly unsafe since it deals with raw pointers passed from JNI.
+/// This function is inherently unsafe since it deals with raw pointers passed from JNI.
 #[no_mangle]
 pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
     e: JNIEnv,
@@ -356,22 +354,22 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
         let next_item = exec_context.stream.as_mut().unwrap().next();
         let poll_output = exec_context.runtime.block_on(async { poll!(next_item) });
 
+        // Update metrics
+        update_metrics(&mut env, exec_context)?;
+
         match poll_output {
             Poll::Ready(Some(output)) => {
+                // prepare output for FFI transfer
                 return prepare_output(
                     &mut env,
                     array_addrs,
                     schema_addrs,
                     output?,
-                    exec_context,
+                    exec_context.debug_native,
                 );
             }
             Poll::Ready(None) => {
                 // Reaches EOF of output.
-
-                // Update metrics
-                update_metrics(&mut env, exec_context)?;
-
                 if exec_context.explain_native {
                     if let Some(plan) = &exec_context.root_op {
                         let formatted_plan_str =
@@ -391,9 +389,6 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
             // A poll pending means there are more than one blocking operators,
             // we don't need go back-forth between JVM/Native. Just keeping polling.
             Poll::Pending => {
-                // Update metrics
-                update_metrics(&mut env, exec_context)?;
-
                 // Pull input batches
                 pull_input_batches(exec_context)?;
 
@@ -459,7 +454,7 @@ fn get_execution_context<'a>(id: i64) -> &'a mut ExecutionContext {
 
 /// Used by Comet shuffle external sorter to write sorted records to disk.
 /// # Safety
-/// This function is inheritly unsafe since it deals with raw pointers passed from JNI.
+/// This function is inherently unsafe since it deals with raw pointers passed from JNI.
 #[no_mangle]
 pub unsafe extern "system" fn Java_org_apache_comet_Native_writeSortedFileNative(
     e: JNIEnv,
@@ -544,3 +539,23 @@ pub extern "system" fn Java_org_apache_comet_Native_sortRowPartitionsNative(
         Ok(())
     })
 }
+
+#[no_mangle]
+/// Used by Comet native shuffle reader
+/// # Safety
+/// This function is inherently unsafe since it deals with raw pointers passed from JNI.
+pub unsafe extern "system" fn Java_org_apache_comet_Native_decodeShuffleBlock(
+    e: JNIEnv,
+    _class: JClass,
+    byte_buffer: JByteBuffer,
+    array_addrs: jlongArray,
+    schema_addrs: jlongArray,
+) -> jlong {
+    try_unwrap_or_throw(&e, |mut env| {
+        let raw_pointer = env.get_direct_buffer_address(&byte_buffer)?;
+        let length = env.get_direct_buffer_capacity(&byte_buffer)?;
+        let slice: &[u8] = unsafe { std::slice::from_raw_parts(raw_pointer, length) };
+        let batch = read_ipc_compressed(slice)?;
+        prepare_output(&mut env, array_addrs, schema_addrs, batch, false)
+    })
+}

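Note on the JNI contract above: the two long arrays passed to Java_org_apache_comet_Native_decodeShuffleBlock carry one Arrow C Data Interface address pair per column of the decoded batch (the shuffle block encodes the column count so the JVM can size these arrays; see shuffle_writer.rs below), and prepare_output exports each decoded column to those addresses before returning the row count. The helper below is a hypothetical, standalone sketch of that per-column hand-off using arrow-rs's ffi::to_ffi; it is not Comet's actual SparkArrowConvert-based code.

use arrow::array::Array;
use arrow::error::ArrowError;
use arrow::ffi::{to_ffi, FFI_ArrowArray, FFI_ArrowSchema};
use arrow::record_batch::RecordBatch;

/// Hypothetical illustration only: export each column of `batch` through the
/// Arrow C Data Interface to caller-provided struct addresses.
///
/// # Safety
/// Every address must point to a writable, properly aligned ArrowArray /
/// ArrowSchema struct owned by the caller (the JVM, in Comet's case).
unsafe fn export_columns(
    batch: &RecordBatch,
    array_addrs: &[i64],
    schema_addrs: &[i64],
) -> Result<i64, ArrowError> {
    assert_eq!(batch.num_columns(), array_addrs.len());
    assert_eq!(batch.num_columns(), schema_addrs.len());
    for (i, column) in batch.columns().iter().enumerate() {
        // build the C Data Interface structs for this column
        let (ffi_array, ffi_schema) = to_ffi(&column.to_data())?;
        // hand ownership to the caller by writing the structs at the given addresses
        std::ptr::write(array_addrs[i] as *mut FFI_ArrowArray, ffi_array);
        std::ptr::write(schema_addrs[i] as *mut FFI_ArrowSchema, ffi_schema);
    }
    Ok(batch.num_rows() as i64)
}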
native/core/src/execution/shuffle/mod.rs

Lines changed: 3 additions & 1 deletion
@@ -19,4 +19,6 @@ mod list;
 mod map;
 pub mod row;
 mod shuffle_writer;
-pub use shuffle_writer::{write_ipc_compressed, CompressionCodec, ShuffleWriterExec};
+pub use shuffle_writer::{
+    read_ipc_compressed, write_ipc_compressed, CompressionCodec, ShuffleWriterExec,
+};

native/core/src/execution/shuffle/shuffle_writer.rs

Lines changed: 25 additions & 8 deletions
@@ -21,6 +21,7 @@ use crate::{
     common::bit::ceil,
     errors::{CometError, CometResult},
 };
+use arrow::ipc::reader::StreamReader;
 use arrow::{datatypes::*, ipc::writer::StreamWriter};
 use async_trait::async_trait;
 use bytes::Buf;
@@ -1555,7 +1556,7 @@ pub enum CompressionCodec {
 pub fn write_ipc_compressed<W: Write + Seek>(
     batch: &RecordBatch,
     output: &mut W,
-    codec: &CompressionCodec,
+    compression_codec: &CompressionCodec,
     ipc_time: &Time,
 ) -> Result<usize> {
     if batch.num_rows() == 0 {
@@ -1565,10 +1566,14 @@
     let mut timer = ipc_time.timer();
     let start_pos = output.stream_position()?;
 
-    // write ipc_length placeholder
+    // write message length placeholder
     output.write_all(&[0u8; 8])?;
 
-    let output = match codec {
+    // write number of columns because JVM side needs to know how many addresses to allocate
+    let field_count = batch.schema().fields().len();
+    output.write_all(&field_count.to_le_bytes())?;
+
+    let output = match compression_codec {
         CompressionCodec::None => {
             let mut arrow_writer = StreamWriter::try_new(output, &batch.schema())?;
             arrow_writer.write(batch)?;
@@ -1587,18 +1592,25 @@
 
     // fill ipc length
     let end_pos = output.stream_position()?;
-    let ipc_length = end_pos - start_pos - 8;
+    let compressed_length = end_pos - start_pos - 8;
 
     // fill ipc length
     output.seek(SeekFrom::Start(start_pos))?;
-    output.write_all(&ipc_length.to_le_bytes()[..])?;
+    output.write_all(&compressed_length.to_le_bytes()[..])?;
     output.seek(SeekFrom::Start(end_pos))?;
 
     timer.stop();
 
     Ok((end_pos - start_pos) as usize)
 }
 
+pub fn read_ipc_compressed(bytes: &[u8]) -> Result<RecordBatch> {
+    let decoder = zstd::Decoder::new(bytes)?;
+    let mut reader = StreamReader::try_new(decoder, None)?;
+    // TODO check for None
+    reader.next().unwrap().map_err(|e| e.into())
+}
+
 /// A stream that yields no record batches which represent end of output.
 pub struct EmptyStream {
     /// Schema representing the data
@@ -1648,18 +1660,23 @@ mod test {
 
     #[test]
     #[cfg_attr(miri, ignore)] // miri can't call foreign function `ZSTD_createCCtx`
-    fn write_ipc_zstd() {
+    fn roundtrip_ipc_zstd() {
         let batch = create_batch(8192);
         let mut output = vec![];
         let mut cursor = Cursor::new(&mut output);
-        write_ipc_compressed(
+        let length = write_ipc_compressed(
             &batch,
             &mut cursor,
             &CompressionCodec::Zstd(1),
             &Time::default(),
         )
         .unwrap();
-        assert_eq!(40218, output.len());
+        assert_eq!(40226, output.len());
+        assert_eq!(40226, length);
+
+        let ipc_without_length_prefix = &output[16..];
+        let batch2 = read_ipc_compressed(ipc_without_length_prefix).unwrap();
+        assert_eq!(batch, batch2);
     }
 
     #[test]

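The block layout produced by write_ipc_compressed is: an 8-byte little-endian length covering everything after the length field, an 8-byte little-endian column count (written as a usize, so 8 bytes on 64-bit targets), and then the optionally zstd-compressed Arrow IPC stream holding a single batch; the round-trip test above skips the first 16 bytes so read_ipc_compressed sees only the compressed stream. The sketch below mirrors that framing for the Zstd case. It is not Comet code; it assumes only the arrow and zstd crates and a 64-bit target.

use std::io::Write;
use std::sync::Arc;

use arrow::array::Int32Array;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::ipc::reader::StreamReader;
use arrow::ipc::writer::StreamWriter;
use arrow::record_batch::RecordBatch;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )?;

    // encode one batch as a zstd-compressed Arrow IPC stream
    let mut encoder = zstd::Encoder::new(Vec::new(), 1)?;
    {
        let mut writer = StreamWriter::try_new(&mut encoder, schema.as_ref())?;
        writer.write(&batch)?;
        writer.finish()?;
    }
    let payload = encoder.finish()?;

    // prepend the two 8-byte little-endian prefixes: the length of everything
    // after the length field, then the column count the JVM uses to size its
    // Arrow address arrays
    let mut block = Vec::new();
    block.write_all(&((payload.len() + 8) as u64).to_le_bytes())?;
    block.write_all(&(batch.num_columns() as u64).to_le_bytes())?;
    block.write_all(&payload)?;

    // decode: skip the 16-byte header and read the batch back, as the test above does
    let decoder = zstd::Decoder::new(&block[16..])?;
    let mut reader = StreamReader::try_new(decoder, None)?;
    let decoded = reader.next().expect("expected one batch")?;
    assert_eq!(batch, decoded);
    Ok(())
}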
spark/src/main/scala/org/apache/comet/Native.scala

Lines changed: 17 additions & 0 deletions
@@ -19,6 +19,8 @@
 
 package org.apache.comet
 
+import java.nio.ByteBuffer
+
 import org.apache.spark.CometTaskMemoryManager
 import org.apache.spark.sql.comet.CometMetricNode
 
@@ -139,4 +141,19 @@ class Native extends NativeBase {
    * the size of the array.
    */
   @native def sortRowPartitionsNative(addr: Long, size: Long): Unit
+
+  /**
+   * Decompress and decode a native shuffle block.
+   * @param shuffleBlock
+   *   the encoded and compressed shuffle block.
+   * @param arrayAddrs
+   *   the addresses of the Arrow array structures to populate, one per column.
+   * @param schemaAddrs
+   *   the addresses of the Arrow schema structures to populate, one per column.
+   */
+  @native def decodeShuffleBlock(
+      shuffleBlock: ByteBuffer,
+      arrayAddrs: Array[Long],
+      schemaAddrs: Array[Long]): Long
+
 }

spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala

Lines changed: 22 additions & 2 deletions
@@ -19,20 +19,24 @@
 
 package org.apache.spark.sql.comet
 
+import java.io.DataInputStream
+import java.nio.channels.Channels
 import java.util.UUID
 import java.util.concurrent.{Future, TimeoutException, TimeUnit}
 
 import scala.concurrent.{ExecutionContext, Promise}
 import scala.concurrent.duration.NANOSECONDS
 import scala.util.control.NonFatal
 
-import org.apache.spark.{broadcast, Partition, SparkContext, TaskContext}
+import org.apache.spark.{broadcast, Partition, SparkContext, SparkEnv, TaskContext}
 import org.apache.spark.comet.shims.ShimCometBroadcastExchangeExec
+import org.apache.spark.io.CompressionCodec
 import org.apache.spark.launcher.SparkLauncher
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.Statistics
+import org.apache.spark.sql.comet.execution.shuffle.ArrowReaderIterator
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec}
@@ -299,8 +303,24 @@ class CometBatchRDD(
   override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
     val partition = split.asInstanceOf[CometBatchPartition]
     partition.value.value.toIterator
-      .flatMap(CometExec.decodeBatches(_, this.getClass.getSimpleName))
+      .flatMap(decodeBatches(_, this.getClass.getSimpleName))
   }
+
+  /**
+   * Decodes the byte arrays back to ColumnarBatches.
+   */
+  private def decodeBatches(bytes: ChunkedByteBuffer, source: String): Iterator[ColumnarBatch] = {
+    if (bytes.size == 0) {
+      return Iterator.empty
+    }
+
+    // decompress with the Spark codec, not Comet's, so this is not compatible with Comet shuffle
+    val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
+    val cbbis = bytes.toInputStream()
+    val ins = new DataInputStream(codec.compressedInputStream(cbbis))
+    new ArrowReaderIterator(Channels.newChannel(ins), source)
+  }
+
 }
 
 class CometBatchPartition(

spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala

Lines changed: 1 addition & 0 deletions
@@ -136,6 +136,7 @@ object CometMetricNode {
     "mempool_time" -> SQLMetrics.createNanoTimingMetric(sc, "memory pool time"),
     "repart_time" -> SQLMetrics.createNanoTimingMetric(sc, "repartition time"),
     "ipc_time" -> SQLMetrics.createNanoTimingMetric(sc, "encoding and compression time"),
+    "decodeTime" -> SQLMetrics.createNanoTimingMetric(sc, "decoding and decompression time"),
     "spill_count" -> SQLMetrics.createMetric(sc, "number of spills"),
     "spilled_bytes" -> SQLMetrics.createMetric(sc, "spilled bytes"),
     "input_batches" -> SQLMetrics.createMetric(sc, "number of input batches"))
