
Commit

Spark datasource autologging: Improve fidelity of REPL ID / context information (mlflow#4551)

* Initial commit

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Test case for getReplId

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Java test case

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Whitespace

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Format

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Comment and ordering

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Test case fix

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Address review comments

Signed-off-by: dbczumar <corey.zumar@databricks.com>
dbczumar authored Jul 13, 2021
1 parent c991b99 commit adcdf30
Showing 5 changed files with 99 additions and 8 deletions.
3 changes: 2 additions & 1 deletion mlflow/_spark_autologging.py
@@ -16,6 +16,7 @@
     autologging_is_disabled,
     ExceptionSafeClass,
 )
+from mlflow.utils.databricks_utils import get_repl_id as get_databricks_repl_id
 from mlflow.spark import FLAVOR_NAME
 
 _JAVA_PACKAGE = "org.mlflow.spark.autologging"
@@ -156,7 +157,7 @@ def _get_repl_id():
     local properties, and expect that the PythonSubscriber for the current Python process only
     receives events for datasource reads triggered by the current process.
     """
-    repl_id = SparkContext.getOrCreate().getLocalProperty("spark.databricks.replId")
+    repl_id = get_databricks_repl_id()
     if repl_id:
         return repl_id
     main_file = sys.argv[0] if len(sys.argv) > 0 else "<console>"
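Taken together with the get_repl_id helper added to mlflow/utils/databricks_utils.py further down, the resolution order in _get_repl_id after this change can be sketched as follows. This is a condensed illustration rather than the verbatim module code; in particular, the random-suffix fallback for non-Databricks environments is an assumption, since the rest of the function is collapsed in this diff.

import sys
import uuid

from mlflow.utils.databricks_utils import get_repl_id as get_databricks_repl_id


def _get_repl_id_sketch():
    # Prefer the Databricks-provided REPL ID: dbutils entry point first, then the
    # "spark.databricks.replId" Spark local property (see databricks_utils.get_repl_id below).
    repl_id = get_databricks_repl_id()
    if repl_id:
        return repl_id
    # Outside Databricks, derive an identifier from the main script instead.
    main_file = sys.argv[0] if len(sys.argv) > 0 else "<console>"
    return "{}-{}".format(main_file, uuid.uuid4().hex)  # illustrative fallback only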
@@ -9,6 +9,7 @@ import scala.collection.JavaConverters._
 import org.apache.spark.sql.SparkSession
 import org.slf4j.LoggerFactory
 
+import scala.util.{Try, Success, Failure}
 import scala.util.control.NonFatal
 
 /**
@@ -42,15 +43,35 @@ private[autologging] trait MlflowAutologEventPublisherImpl {
         "autologging."))
   }
 
-  // Exposed for testing
-  private[autologging] def getSparkDataSourceListener: SparkListener = {
-    // Get SparkContext & determine if REPL id is set - if not, then we log irrespective of repl
-    // ID, but if so, we log conditionally on repl ID
+  /**
+    * @returns True if Spark is running in a REPL-aware context. False otherwise.
+    */
+  private def isInReplAwareContext: Boolean = {
+    // Attempt to fetch the `spark.databricks.replId` property from the Spark Context.
+    // The presence of this ID is a clear indication that we are in a REPL-aware environment
     val sc = spark.sparkContext
     val replId = Option(sc.getLocalProperty("spark.databricks.replId"))
-    replId match {
-      case None => new SparkDataSourceListener(this)
-      case Some(_) => new ReplAwareSparkDataSourceListener(this)
+    if (replId.isDefined) {
+      return true
     }
+
+    // If the `spark.databricks.replId` is absent, we may still be in a Databricks environment,
+    // which is REPL-aware. To check, we look for the presence of a Databricks-specific cluster ID
+    // tag in the Spark configuration
+    val clusterId = spark.conf.getOption("spark.databricks.clusterUsageTags.clusterId")
+    if (clusterId.isDefined) {
+      return true
+    }
+
+    false
   }
 
+  // Exposed for testing
+  private[autologging] def getSparkDataSourceListener: SparkListener = {
+    if (isInReplAwareContext) {
+      new ReplAwareSparkDataSourceListener(this)
+    } else {
+      new SparkDataSourceListener(this)
+    }
+  }
+
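For readers more comfortable with PySpark, the two signals that isInReplAwareContext inspects can be probed from Python as well. The helper below is an illustrative analogue only (its name is made up and it is not part of this commit); it assumes a live SparkSession.

from pyspark.sql import SparkSession


def is_repl_aware_context(spark: SparkSession) -> bool:
    # Signal 1: a Databricks REPL ID exposed as a Spark local property.
    if spark.sparkContext.getLocalProperty("spark.databricks.replId") is not None:
        return True
    # Signal 2: a Databricks cluster ID present in the Spark configuration.
    if spark.conf.get("spark.databricks.clusterUsageTags.clusterId", None) is not None:
        return True
    return False

With a plain local SparkSession neither signal is set, so the publisher registers the plain SparkDataSourceListener, which is what the first assertion of the new test below verifies before the cluster ID is injected.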
@@ -281,4 +281,14 @@ class SparkAutologgingSuite extends FunSuite with Matchers with BeforeAndAfterAll
     assert(MlflowAutologEventPublisher.sparkQueryListener.isInstanceOf[SparkDataSourceListener])
   }
 
+  test("Delegates to repl-ID-aware listener if Databricks cluster ID is set in Spark Conf") {
+    // Verify instance created by init() in beforeEach is not REPL-ID-aware
+    assert(MlflowAutologEventPublisher.sparkQueryListener.isInstanceOf[SparkDataSourceListener])
+    assert(!MlflowAutologEventPublisher.sparkQueryListener.isInstanceOf[ReplAwareSparkDataSourceListener])
+    MlflowAutologEventPublisher.stop()
+
+    spark.conf.set("spark.databricks.clusterUsageTags.clusterId", "myCoolClusterId")
+    MlflowAutologEventPublisher.init()
+    assert(MlflowAutologEventPublisher.sparkQueryListener.isInstanceOf[ReplAwareSparkDataSourceListener])
+  }
 }
27 changes: 27 additions & 0 deletions mlflow/utils/databricks_utils.py
@@ -171,6 +171,33 @@ def get_job_group_id():
     return None
 
 
+def get_repl_id():
+    """
+    :return: The ID of the current Databricks Python REPL
+    """
+    # Attempt to fetch the REPL ID from the Python REPL's entrypoint object. This REPL ID
+    # is guaranteed to be set upon REPL startup in DBR / MLR 9.0
+    try:
+        dbutils = _get_dbutils()
+        repl_id = dbutils.entry_point.getReplId()
+        if repl_id is not None:
+            return repl_id
+    except Exception:
+        pass
+
+    # If the REPL ID entrypoint property is unavailable due to an older runtime version (< 9.0),
+    # attempt to fetch the REPL ID from the Spark Context. This property may not be available
+    # until several seconds after REPL startup
+    try:
+        from pyspark import SparkContext
+
+        repl_id = SparkContext.getOrCreate().getLocalProperty("spark.databricks.replId")
+        if repl_id is not None:
+            return repl_id
+    except Exception:
+        pass
+
+
 def get_job_id():
     try:
         return _get_command_context().jobId().get()
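Note that when both lookups fail (for example, outside Databricks), get_repl_id falls through both except blocks and returns None implicitly. A minimal way to exercise the helper, assuming only that mlflow is importable:

from mlflow.utils import databricks_utils

repl_id = databricks_utils.get_repl_id()
if repl_id is not None:
    print("Running inside a Databricks REPL:", repl_id)
else:
    print("No Databricks REPL ID available; callers fall back to their own identifier")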
32 changes: 32 additions & 0 deletions tests/utils/test_databricks_utils.py
@@ -219,3 +219,35 @@ def test_is_in_databricks_runtime():
         # pylint: disable=unused-import,import-error,no-name-in-module,unused-variable
         import pyspark.databricks
     assert not databricks_utils.is_in_databricks_runtime()
+
+
+def test_get_repl_id():
+    # Outside of Databricks environments, the Databricks REPL ID should be absent
+    assert databricks_utils.get_repl_id() is None
+
+    mock_dbutils = mock.MagicMock()
+    mock_dbutils.entry_point.getReplId.return_value = "testReplId1"
+    with mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils):
+        assert databricks_utils.get_repl_id() == "testReplId1"
+
+    mock_sparkcontext_inst = mock.MagicMock()
+    mock_sparkcontext_inst.getLocalProperty.return_value = "testReplId2"
+    mock_sparkcontext_class = mock.MagicMock()
+    mock_sparkcontext_class.getOrCreate.return_value = mock_sparkcontext_inst
+    mock_spark = mock.MagicMock()
+    mock_spark.SparkContext = mock_sparkcontext_class
+
+    import builtins
+
+    original_import = builtins.__import__
+
+    def mock_import(name, *args, **kwargs):
+        if name == "pyspark":
+            return mock_spark
+        else:
+            return original_import(name, *args, **kwargs)
+
+    with mock.patch(
+        "builtins.__import__", side_effect=mock_import,
+    ):
+        assert databricks_utils.get_repl_id() == "testReplId2"
