native/core/src/execution/jni_api.rs (20 additions, 2 deletions)
@@ -37,6 +37,7 @@ use jni::{
sys::{jbyteArray, jint, jlong, jlongArray},
JNIEnv,
};
use std::time::{Duration, Instant};
use std::{collections::HashMap, sync::Arc, task::Poll};

use super::{serde, utils::SparkArrowConvert, CometMemoryPool};
@@ -81,6 +82,8 @@ struct ExecutionContext {
pub runtime: Runtime,
/// Native metrics
pub metrics: Arc<GlobalRef>,
/// The time it took to create the native plan and configure the context
pub plan_creation_time: Duration,
/// DataFusion SessionContext
pub session_ctx: Arc<SessionContext>,
/// Whether to enable additional debugging checks & messages
@@ -109,6 +112,8 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
// Init JVM classes
JVMClasses::init(&mut env);

let start = Instant::now();

let array = unsafe { JPrimitiveArray::from_raw(serialized_query) };
let bytes = env.convert_byte_array(array)?;

@@ -167,6 +172,8 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
// dictionaries will be dropped as well.
let session = prepare_datafusion_session_context(&configs, task_memory_manager)?;

let plan_creation_time = start.elapsed();

let exec_context = Box::new(ExecutionContext {
id,
spark_plan,
Expand All @@ -177,6 +184,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
conf: configs,
runtime,
metrics,
plan_creation_time,
session_ctx: Arc::new(session),
debug_native,
explain_native,
@@ -321,6 +329,8 @@ fn pull_input_batches(exec_context: &mut ExecutionContext) -> Result<(), CometError> {
pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
e: JNIEnv,
_class: JClass,
stage_id: jint,
partition: jint,
exec_context: jlong,
array_addrs: jlongArray,
schema_addrs: jlongArray,
@@ -335,20 +345,23 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
// Because we don't know if input arrays are dictionary-encoded when we create
// the query plan, we need to defer stream initialization to the first execution.
if exec_context.root_op.is_none() {
let start = Instant::now();
let planner = PhysicalPlanner::new(Arc::clone(&exec_context.session_ctx))
.with_exec_id(exec_context_id);
let (scans, root_op) = planner.create_plan(
&exec_context.spark_plan,
&mut exec_context.input_sources.clone(),
)?;
let physical_plan_time = start.elapsed();

exec_context.plan_creation_time += physical_plan_time;
exec_context.root_op = Some(Arc::clone(&root_op));
exec_context.scans = scans;

if exec_context.explain_native {
let formatted_plan_str =
DisplayableExecutionPlan::new(root_op.as_ref()).indent(true);
info!("Comet native query plan:\n {formatted_plan_str:}");
info!("Comet native query plan:\n{formatted_plan_str:}");
}

let task_ctx = exec_context.session_ctx.task_ctx();
@@ -388,7 +401,12 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
if let Some(plan) = &exec_context.root_op {
let formatted_plan_str =
DisplayableExecutionPlan::with_metrics(plan.as_ref()).indent(true);
info!("Comet native query plan with metrics:\n{formatted_plan_str:}");
info!(
"Comet native query plan with metrics:\
\n[Stage {} Partition {}] plan creation (including CometScans fetching first batches) took {:?}:\
\n{formatted_plan_str:}",
stage_id, partition, exec_context.plan_creation_time
);
}
}

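The timing above is split across two JNI calls: `createPlan` measures plan deserialization and session setup, and the first `executePlan` call then adds the DataFusion physical-planning cost (which is deferred until the first batch) before the total is logged per stage and partition. A minimal standalone sketch of that accumulation pattern; all names here are illustrative, not the actual Comet types:

```rust
use std::time::{Duration, Instant};

struct PlanTiming {
    /// Accumulated time spent creating the native plan
    plan_creation_time: Duration,
}

/// Phase 1 (createPlan): deserialize the serialized Spark plan and
/// set up the session, recording how long that took.
fn create_plan() -> PlanTiming {
    let start = Instant::now();
    // ... deserialize the plan bytes, prepare the SessionContext ...
    PlanTiming {
        plan_creation_time: start.elapsed(),
    }
}

/// Phase 2 (first executePlan): build the DataFusion physical plan,
/// which is deferred until first execution, and fold its cost into the total.
fn first_execution(timing: &mut PlanTiming) {
    let start = Instant::now();
    // ... PhysicalPlanner::create_plan(...) ...
    timing.plan_creation_time += start.elapsed();
}

fn main() {
    let mut timing = create_plan();
    first_execution(&mut timing);
    println!("plan creation took {:?}", timing.plan_creation_time);
}
```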
native/core/src/execution/operators/scan.rs (12 additions, 11 deletions)
@@ -15,16 +15,6 @@
// specific language governing permissions and limitations
// under the License.

use futures::Stream;
use itertools::Itertools;
use std::rc::Rc;
use std::{
any::Any,
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll},
};

use crate::{
errors::CometError,
execution::{
Expand All @@ -48,9 +38,18 @@ use datafusion::{
physical_plan::{ExecutionPlan, *},
};
use datafusion_common::{arrow_datafusion_err, DataFusionError, Result as DataFusionResult};
use futures::Stream;
use itertools::Itertools;
use jni::objects::JValueGen;
use jni::objects::{GlobalRef, JObject};
use jni::sys::jsize;
use std::rc::Rc;
use std::{
any::Any,
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll},
};

/// ScanExec reads batches of data from Spark via JNI. The source of the scan could be a file
/// scan or the result of reading a broadcast or shuffle exchange.
@@ -98,7 +97,6 @@ impl ScanExec {
let batch =
ScanExec::get_next(exec_context_id, input_source.as_obj(), data_types.len())?;
timer.stop();
baseline_metrics.record_output(batch.num_rows());
batch
} else {
InputBatch::EOF
@@ -162,6 +160,7 @@
// This is a unit test. We don't need to call JNI.
return Ok(());
}
let mut timer = self.baseline_metrics.elapsed_compute().timer();

let mut current_batch = self.batch.try_lock().unwrap();
if current_batch.is_none() {
@@ -173,6 +172,8 @@
*current_batch = Some(next_batch);
}

timer.stop();

Ok(())
}

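The scan.rs change drops the `record_output` call from the initialization path and instead wraps the batch fetch in the operator's `elapsed_compute` timer, so time spent blocking on the JNI call is attributed to the scan in the native metrics. A small sketch of that DataFusion metrics pattern, assuming the `datafusion` crate's `BaselineMetrics` API:

```rust
use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet};

fn fetch_next_batch(baseline_metrics: &BaselineMetrics) {
    // `elapsed_compute().timer()` returns a scoped timer guard; stopping it
    // adds the elapsed wall time to the operator's elapsed_compute counter,
    // which surfaces in the plan-with-metrics log output.
    let mut timer = baseline_metrics.elapsed_compute().timer();
    // ... blocking JNI call that pulls the next batch from Spark ...
    timer.stop();
}

fn main() {
    let metrics_set = ExecutionPlanMetricsSet::new();
    let baseline_metrics = BaselineMetrics::new(&metrics_set, /* partition */ 0);
    fetch_next_batch(&baseline_metrics);
}
```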
spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -95,7 +95,8 @@ class CometExecIterator(
nativeUtil.getNextBatch(
numOutputCols,
(arrayAddrs, schemaAddrs) => {
nativeLib.executePlan(plan, arrayAddrs, schemaAddrs)
val ctx = TaskContext.get()
nativeLib.executePlan(ctx.stageId(), ctx.partitionId(), plan, arrayAddrs, schemaAddrs)
Review thread on this change:

> **Reviewer (Contributor):** is it to track times for specific partitions?
>
> **Author:** Yes, because we create one plan per partition we have a lot of repeated plans and it is helpful to see the partition number.

})
}

spark/src/main/scala/org/apache/comet/Native.scala (10 additions, 1 deletion)
@@ -56,6 +56,10 @@ class Native extends NativeBase {
/**
* Execute a native query plan based on given input Arrow arrays.
*
* @param stage
* the stage ID, for informational purposes
* @param partition
* the partition ID, for informational purposes
* @param plan
* the address of the native query plan.
* @param arrayAddrs
@@ -65,7 +69,12 @@
* @return
* the number of rows, if -1, it means end of the output.
*/
@native def executePlan(plan: Long, arrayAddrs: Array[Long], schemaAddrs: Array[Long]): Long
@native def executePlan(
stage: Int,
partition: Int,
plan: Long,
arrayAddrs: Array[Long],
schemaAddrs: Array[Long]): Long

/**
* Release and drop the native query plan object and context object.
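For context on how the new Scala declaration lines up with the Rust entry point shown earlier: Scala `Int` maps to `jint`, `Long` to `jlong`, `Array[Long]` to `jlongArray`, and the exported symbol name encodes the `org.apache.comet.Native` class. A compilable sketch of that correspondence using the `jni` crate; the body is a stub, not Comet's implementation:

```rust
use jni::objects::JClass;
use jni::sys::{jint, jlong, jlongArray};
use jni::JNIEnv;

// The JNI linker resolves this symbol for org.apache.comet.Native.executePlan.
#[no_mangle]
pub extern "system" fn Java_org_apache_comet_Native_executePlan(
    _env: JNIEnv,
    _class: JClass,
    _stage_id: jint,      // Scala: stage: Int
    _partition: jint,     // Scala: partition: Int
    _exec_context: jlong, // Scala: plan: Long
    _array_addrs: jlongArray,
    _schema_addrs: jlongArray,
) -> jlong {
    // ... execute one step of the native plan and return the row count ...
    -1 // per the Scaladoc above, -1 signals end of output
}
```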