@@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.commons.lang3.StringUtils
 import org.apache.hadoop.fs.{BlockLocation, FileStatus, LocatedFileStatus, Path}
 
+import org.apache.spark.TaskContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
@@ -32,12 +33,13 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource}
-import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.sources.{BaseRelation, Filter}
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{TaskCompletionListener, Utils}
 import org.apache.spark.util.collection.BitSet
 
+
 trait DataSourceScanExec extends LeafExecNode with CodegenSupport {
   val relation: BaseRelation
   val tableIdentifier: Option[TableIdentifier]
@@ -332,6 +334,7 @@ case class FileSourceScanExec(
 
   override lazy val metrics =
     Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
+      "readBytes" -> SQLMetrics.createMetric(sparkContext, "number of read bytes"),
       "numFiles" -> SQLMetrics.createMetric(sparkContext, "number of files"),
       "metadataTime" -> SQLMetrics.createMetric(sparkContext, "metadata time (ms)"),
       "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time"))
@@ -344,9 +347,10 @@ case class FileSourceScanExec(
       WholeStageCodegenExec(this)(codegenStageId = 0).execute()
     } else {
       val numOutputRows = longMetric("numOutputRows")
-
+      val readBytes = longMetric("readBytes")
       if (needsUnsafeRowConversion) {
         inputRDD.mapPartitionsWithIndexInternal { (index, iter) =>
+          addReadBytesListener(readBytes)
           val proj = UnsafeProjection.create(schema)
           proj.initialize(index)
           iter.map( r => {
@@ -355,9 +359,12 @@ case class FileSourceScanExec(
           })
         }
       } else {
-        inputRDD.map { r =>
-          numOutputRows += 1
-          r
+        inputRDD.mapPartitions { iter =>
+          addReadBytesListener(readBytes)
+          iter.map { r =>
+            numOutputRows += 1
+            r
+          }
         }
       }
     }
@@ -534,4 +541,27 @@ case class FileSourceScanExec(
       QueryPlan.normalizePredicates(dataFilters, output),
       None)
   }
+
+  protected override def doProduce(ctx: CodegenContext): String = {
+    val readBytes = metricTerm(ctx, "readBytes")
+    ctx.addPartitionInitializationStatement(
+      s"""
+         |org.apache.spark.TaskContext.get()
+         |  .addTaskCompletionListener(new org.apache.spark.util.TaskCompletionListener() {
+         |    @Override
+         |    public void onTaskCompletion(org.apache.spark.TaskContext context) {
+         |      $readBytes.add(context.taskMetrics().inputMetrics().bytesRead());
+         |    }
+         |  });
+       """.stripMargin)
+    super.doProduce(ctx)
+  }
+
+  private def addReadBytesListener(metric: SQLMetric): Unit = {
+    TaskContext.get().addTaskCompletionListener(new TaskCompletionListener {
+      override def onTaskCompletion(context: TaskContext): Unit = {
+        metric.add(context.taskMetrics().inputMetrics.bytesRead)
+      }
+    })
+  }
 }
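
For reference, a minimal standalone sketch of the same TaskCompletionListener pattern, written against the public Spark API rather than the internal SQLMetric used in the patch. The object name, accumulator name, and Parquet path below are illustrative assumptions, not part of the change.

import org.apache.spark.TaskContext
import org.apache.spark.sql.SparkSession
import org.apache.spark.util.TaskCompletionListener

// Sketch only: a LongAccumulator stands in for the "readBytes" SQLMetric.
object ReadBytesSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("read-bytes-sketch")
      .master("local[*]")
      .getOrCreate()
    val readBytes = spark.sparkContext.longAccumulator("readBytes")

    spark.read.parquet("/tmp/example-table")  // hypothetical input path
      .rdd
      .mapPartitions { iter =>
        // Registered once per task: when the task completes, fold the bytes
        // it read from its input into the driver-side accumulator.
        TaskContext.get().addTaskCompletionListener(new TaskCompletionListener {
          override def onTaskCompletion(context: TaskContext): Unit = {
            readBytes.add(context.taskMetrics().inputMetrics.bytesRead)
          }
        })
        iter
      }
      .count()

    println(s"bytes read: ${readBytes.value}")
    spark.stop()
  }
}

Registering the listener inside mapPartitions mirrors addReadBytesListener above: the hook runs once per task at completion, so the byte count is added exactly once per task rather than per row.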