@@ -24,6 +24,7 @@ import java.util.{Arrays, Comparator, Date, Locale}
 import java.util.concurrent.ConcurrentHashMap
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 import scala.util.control.NonFatal
 
 import com.google.common.primitives.Longs
@@ -148,13 +149,25 @@ class SparkHadoopUtil extends Logging {
   private[spark] def getFSBytesReadOnThreadCallback(): () => Long = {
     val f = () => FileSystem.getAllStatistics.asScala.map(_.getThreadStatistics.getBytesRead).sum
     val baseline = (Thread.currentThread().getId, f())
-    val bytesReadMap = new ConcurrentHashMap[Long, Long]()
 
-    () => {
-      bytesReadMap.put(Thread.currentThread().getId, f())
-      bytesReadMap.asScala.map { case (k, v) =>
-        v - (if (k == baseline._1) baseline._2 else 0)
-      }.sum
+    new Function0[Long] {
+      private val bytesReadMap = new mutable.HashMap[Long, Long]()
+
+      /**
+       * Returns a function that can be called to calculate Hadoop FileSystem bytes read.
+       * This function may be called in both spawned child threads and parent task thread (in
+       * PythonRDD), and Hadoop FileSystem uses thread local variables to track the statistics.
+       * So we need a map to track the bytes read from the child threads and parent thread,
+       * summing them together to get the bytes read of this task.
+       */
+      override def apply(): Long = {
+        bytesReadMap.synchronized {
+          bytesReadMap.put(Thread.currentThread().getId, f())
+          bytesReadMap.map { case (k, v) =>
+            v - (if (k == baseline._1) baseline._2 else 0)
+          }.sum
+        }
+      }
     }
   }
 
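
For context, the sketch below isolates the idea behind the new callback: Hadoop tracks bytes read in thread-local statistics, so each thread that invokes the callback records its own reading in a map guarded by synchronized, and the result is the sum of all recorded readings minus the baseline captured on the constructing (parent) thread. Everything in the sketch (perThreadBytes, read, makeCallback, ThreadAwareBytesReadSketch) is a hypothetical stand-in for illustration, not Spark or Hadoop API; a ThreadLocal counter plays the role of the per-thread FileSystem statistics.

import scala.collection.mutable

// Minimal, self-contained sketch of the pattern in the diff above. All names here are
// hypothetical; the real code reads Hadoop's per-thread FileSystem statistics instead
// of this ThreadLocal counter.
object ThreadAwareBytesReadSketch {

  // Stand-in for Hadoop's thread-local statistics: each thread only sees its own counter.
  private val perThreadBytes = new ThreadLocal[Long] {
    override def initialValue(): Long = 0L
  }

  // Simulate reading n bytes on the calling thread.
  def read(n: Long): Unit = perThreadBytes.set(perThreadBytes.get() + n)

  // Same idea as the new getFSBytesReadOnThreadCallback: capture a baseline on the
  // constructing (parent) thread, then have every caller record its own reading in a
  // synchronized map and return the sum of all readings minus the parent's baseline.
  def makeCallback(): () => Long = {
    val f = () => perThreadBytes.get()
    val baseline = (Thread.currentThread().getId, f())
    val bytesReadMap = new mutable.HashMap[Long, Long]()

    () => bytesReadMap.synchronized {
      bytesReadMap.put(Thread.currentThread().getId, f())
      bytesReadMap.map { case (k, v) =>
        v - (if (k == baseline._1) baseline._2 else 0L)
      }.sum
    }
  }

  def main(args: Array[String]): Unit = {
    read(100L)                    // bytes read on the parent thread before the task starts
    val callback = makeCallback() // baseline for the parent thread is 100

    read(40L)                     // 40 more bytes on the parent thread
    callback()                    // record the parent thread's current reading (140)

    val child = new Thread(new Runnable {
      override def run(): Unit = {
        read(25L)                 // bytes read on a spawned child thread
        callback()                // record the child thread's reading (25)
      }
    })
    child.start()
    child.join()

    // Parent delta (140 - 100) plus the child's reading (25) = 65 bytes for this task.
    println(callback())
  }
}

One plausible reason for switching from ConcurrentHashMap to a plain mutable.HashMap guarded by synchronized is that the put of the current thread's reading and the subsequent sum need to happen atomically as a pair; per-operation thread safety from ConcurrentHashMap alone does not guarantee that.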