Skip to content

Avoid Option while generating call site #222

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion core/src/main/scala/org/apache/spark/SparkContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -877,7 +877,8 @@ class SparkContext(
* has overridden the call site, this will return the user's version.
*/
private[spark] def getCallSite(): String = {
Option(getLocalProperty("externalCallSite")).getOrElse(Utils.formatCallSiteInfo())
val defaultCallSite = Utils.getCallSiteInfo
Option(getLocalProperty("externalCallSite")).getOrElse(defaultCallSite.toString)
}

/**
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/rdd/RDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1041,7 +1041,7 @@ abstract class RDD[T: ClassTag](

/** User code that created this RDD (e.g. `textFile`, `parallelize`). */
@transient private[spark] val creationSiteInfo = Utils.getCallSiteInfo
private[spark] def getCreationSite = Utils.formatCallSiteInfo(creationSiteInfo)
private[spark] def getCreationSite: String = creationSiteInfo.toString

private[spark] def elementClassTag: ClassTag[T] = classTag[T]

Expand Down
18 changes: 9 additions & 9 deletions core/src/main/scala/org/apache/spark/util/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -676,16 +676,22 @@ private[spark] object Utils extends Logging {
private val SPARK_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?\.[A-Z]""".r

private[spark] class CallSiteInfo(val lastSparkMethod: String, val firstUserFile: String,
val firstUserLine: Int, val firstUserClass: String)
val firstUserLine: Int, val firstUserClass: String) {

/** Returns a printable version of the call site info suitable for logs. */
override def toString = {
"%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine)
}
}

/**
* When called inside a class in the spark package, returns the name of the user code class
* (outside the spark package) that called into Spark, as well as which Spark method they called.
* This is used, for example, to tell users where in their code each RDD got created.
*/
def getCallSiteInfo: CallSiteInfo = {
val trace = Thread.currentThread.getStackTrace().filter( el =>
(!el.getMethodName.contains("getStackTrace")))
val trace = Thread.currentThread.getStackTrace()
.filterNot(_.getMethodName.contains("getStackTrace"))

// Keep crawling up the stack trace until we find the first function not inside of the spark
// package. We track the last (shallowest) contiguous Spark method. This might be an RDD
Expand Down Expand Up @@ -718,12 +724,6 @@ private[spark] object Utils extends Logging {
new CallSiteInfo(lastSparkMethod, firstUserFile, firstUserLine, firstUserClass)
}

/** Returns a printable version of the call site info suitable for logs. */
def formatCallSiteInfo(callSiteInfo: CallSiteInfo = Utils.getCallSiteInfo) = {
"%s at %s:%s".format(callSiteInfo.lastSparkMethod, callSiteInfo.firstUserFile,
callSiteInfo.firstUserLine)
}

/** Return a string containing part of a file from byte 'start' to 'end'. */
def offsetBytes(path: String, start: Long, end: Long): String = {
val file = new File(path)
Expand Down
36 changes: 35 additions & 1 deletion core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark

import org.scalatest.FunSuite
import org.scalatest.{Assertions, FunSuite}

class SparkContextInfoSuite extends FunSuite with LocalSparkContext {
test("getPersistentRDDs only returns RDDs that are marked as cached") {
Expand Down Expand Up @@ -56,4 +56,38 @@ class SparkContextInfoSuite extends FunSuite with LocalSparkContext {
rdd.collect()
assert(sc.getRDDStorageInfo.size === 1)
}

// Exercises SparkContext#getCallSite and RDD#getCreationSite from a package
// outside the org.apache.spark hierarchy, so the stack-walking in
// Utils.getCallSiteInfo attributes the call site to user code (see testPackage).
test("call sites report correct locations") {
  sc = new SparkContext("local", "test")
  testPackage.runCallSiteTest(sc)
}
}

/** Call site must be outside of usual org.apache.spark packages (see Utils#SPARK_CLASS_REGEX). */
package object testPackage extends Assertions {
  // Matches the "<method> at <file>:<line>" format produced by CallSiteInfo.toString.
  private val CALL_SITE_REGEX = "(.+) at (.+):([0-9]+)".r

  /**
   * Creates an RDD and verifies that both its creation site and the current call site
   * point back into this file. NOTE: the body is line-position sensitive — the call to
   * sc.getCallSite() must stay exactly two source lines below the definition of "rdd",
   * because the final assertion checks that exact offset.
   */
  def runCallSiteTest(sc: SparkContext): Unit = {
    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
    val rddCreationSite = rdd.getCreationSite
    val curCallSite = sc.getCallSite() // note: 2 lines after definition of "rdd"

    // First match: the RDD's creation site must name makeRDD in this file;
    // capture its line number so we can check the relative offset below.
    val rddCreationLine = rddCreationSite match {
      case CALL_SITE_REGEX(func, file, line) =>
        assert(func === "makeRDD")
        assert(file === "SparkContextInfoSuite.scala")
        line.toInt
      case _ => fail("Did not match expected call site format")
    }

    curCallSite match {
      case CALL_SITE_REGEX(func, file, line) =>
        assert(func === "getCallSite") // this is correct because we called it from outside of Spark
        assert(file === "SparkContextInfoSuite.scala")
        // rddCreationLine is already an Int; the original's extra .toInt was redundant.
        assert(line.toInt === rddCreationLine + 2)
      case _ => fail("Did not match expected call site format")
    }
  }
}