
Commit 9366091

dansanduleac authored and robert3005 committed
Add conda support for R (apache#261)
1 parent 0d26e34 commit 9366091

18 files changed: +222 -74 lines changed


R/pkg/DESCRIPTION

Lines changed: 1 addition & 1 deletion
@@ -57,5 +57,5 @@ Collate:
     'types.R'
     'utils.R'
     'window.R'
-RoxygenNote: 5.0.1
+RoxygenNote: 6.0.1
 VignetteBuilder: knitr

R/pkg/R/RDD.R

Lines changed: 2 additions & 0 deletions
@@ -168,6 +168,7 @@ setMethod("getJRDD", signature(rdd = "PipelinedRDD"),
                        serializedFuncArr,
                        rdd@env$prev_serializedMode,
                        packageNamesArr,
+                       spark.buildCondaInstructions(),
                        broadcastArr,
                        callJMethod(prev_jrdd, "classTag"))
   } else {
@@ -177,6 +178,7 @@ setMethod("getJRDD", signature(rdd = "PipelinedRDD"),
                        rdd@env$prev_serializedMode,
                        serializedMode,
                        packageNamesArr,
+                       spark.buildCondaInstructions(),
                        broadcastArr,
                        callJMethod(prev_jrdd, "classTag"))
   }

R/pkg/R/context.R

Lines changed: 7 additions & 0 deletions
@@ -319,6 +319,13 @@ spark.addFile <- function(path, recursive = FALSE) {
   invisible(callJMethod(sc, "addFile", suppressWarnings(normalizePath(path)), recursive))
 }
 
+#' Construct condaBuildInstructions used to re-create the driver's conda
+#' environment on executors.
+spark.buildCondaInstructions <- function() {
+  sc <- callJMethod(getSparkContext(), "sc")
+  callJMethod(sc, "buildCondaInstructions")
+}
+
 #' Get the root directory that contains files added through spark.addFile.
 #'
 #' @rdname spark.getSparkFilesRootDirectory
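Note (not part of the diff): spark.buildCondaInstructions() only returns meaningful instructions when the session was started with the spark.conda.* settings exercised by the updated tests and run-tests.sh below. A minimal driver-side sketch in R; the conda binary path, channel, and package list here are illustrative placeholders, not values taken from this commit:

# Sketch only: configure conda bootstrapping before creating RDDs, so that
# spark.buildCondaInstructions() (called from getJRDD above) has something to ship.
library(SparkR)

config <- list(spark.conda.binaryPath = "/opt/conda/bin/conda",   # placeholder path
               spark.conda.channelUrls = "https://repo.continuum.io/pkgs/r/",
               spark.conda.bootstrapPackages = "r,r-plyr")
sparkSession <- sparkR.session(master = "local[2]", sparkConfig = config)

Executors then re-create a conda environment containing the bootstrap packages before launching R workers.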

R/pkg/tests/fulltests/test_includePackage.R

Lines changed: 13 additions & 11 deletions
@@ -18,7 +18,10 @@
 context("include R packages")
 
 # JavaSparkContext handle
-sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE)
+config <- list(spark.conda.channelUrls = "https://repo.continuum.io/pkgs/r/",
+               spark.conda.bootstrapPackages = "r,plyr")
+sparkSession <- sparkR.session(master = sparkRTestMaster,
+                               enableHiveSupport = FALSE, sparkConfig = config)
 sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)
 
 # Partitioned data
@@ -27,18 +30,17 @@ rdd <- parallelize(sc, nums, 2L)
 
 test_that("include inside function", {
   # Only run the test if plyr is installed.
-  if ("plyr" %in% rownames(installed.packages())) {
-    suppressPackageStartupMessages(library(plyr))
-    generateData <- function(x) {
-      suppressPackageStartupMessages(library(plyr))
-      attach(airquality)
-      result <- transform(Ozone, logOzone = log(Ozone))
-      result
-    }
 
-    data <- lapplyPartition(rdd, generateData)
-    actual <- collectRDD(data)
+  suppressPackageStartupMessages(library(plyr))
+  generateData <- function(x) {
+    suppressPackageStartupMessages(library(plyr))
+    attach(airquality)
+    result <- transform(Ozone, logOzone = log(Ozone))
+    result
   }
+
+  data <- lapplyPartition(rdd, generateData)
+  actual <- collectRDD(data)
 })
 
 test_that("use include package", {

R/run-tests.sh

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ FAILED=0
 LOGFILE=$FWDIR/unit-tests.out
 rm -f $LOGFILE
 
-SPARK_TESTING=1 NOT_CRAN=true $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf spark.hadoop.fs.defaultFS="file:///" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE
+SPARK_TESTING=1 NOT_CRAN=true $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --num-executors 1 --conf spark.hadoop.fs.defaultFS="file:///" --conf spark.conda.binaryPath=$CONDA_BIN --conf spark.conda.bootstrapPackages="r,r-essentials,r-plyr,r-testthat" --conf spark.conda.channelUrls="https://repo.continuum.io/pkgs/r,https://repo.continuum.io/pkgs/main,https://repo.continuum.io/pkgs/free,https://repo.continuum.io/pkgs/pro" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE
 FAILED=$((PIPESTATUS[0]||$FAILED))
 
 NUM_TEST_WARNING="$(grep -c -e 'Warnings ----------------' $LOGFILE)"

circle.yml

Lines changed: 7 additions & 1 deletion
@@ -2,7 +2,13 @@ machine:
   java:
     version: oraclejdk8
   post:
-    - sudo apt-get --assume-yes install r-base r-base-dev
+    - sudo sh -c 'echo "deb http://cran.rstudio.com/bin/linux/ubuntu trusty/" >> /etc/apt/sources.list'
+    - gpg --keyserver keyserver.ubuntu.com --recv-key E084DAB9
+    - gpg -a --export E084DAB9 | sudo apt-key add -
+    - sudo apt-get update
+    - sudo apt-get --assume-yes install r-base r-base-dev qpdf
+    - sudo chmod 777 /usr/local/lib/R/site-library
+    - /usr/lib/R/bin/R -e "install.packages(c('knitr', 'rmarkdown', 'testthat', 'e1071', 'survival', 'devtools', 'roxygen2', 'lintr'), repos='http://cran.us.r-project.org', lib='/usr/local/lib/R/site-library')"
     - |
       if [[ ! -d ${CONDA_ROOT} ]]; then
         echo "Installing Miniconda...";

core/src/main/scala/org/apache/spark/SparkEnv.scala

Lines changed: 31 additions & 1 deletion
@@ -17,7 +17,7 @@
 
 package org.apache.spark
 
-import java.io.File
+import java.io.{DataOutputStream, File, IOException}
 import java.net.Socket
 import java.util.Locale
 
@@ -76,6 +76,7 @@ class SparkEnv (
   case class PythonWorkerKey(pythonExec: Option[String], envVars: Map[String, String],
                              condaInstructions: Option[CondaSetupInstructions])
   private val pythonWorkers = mutable.HashMap[PythonWorkerKey, PythonWorkerFactory]()
+  private var rDaemonChannel: DataOutputStream = _
 
   // A general, soft-reference map for metadata needed during HadoopRDD split computation
   // (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats).
@@ -88,6 +89,7 @@ class SparkEnv (
     if (!isStopped) {
       isStopped = true
       pythonWorkers.values.foreach(_.stop())
+      destroyRDaemonChannel()
       mapOutputTracker.stop()
       shuffleManager.stop()
       broadcastManager.stop()
@@ -114,6 +116,34 @@ class SparkEnv (
     }
   }
 
+  private[spark] def setRDaemonChannel(daemonChannel: DataOutputStream) {
+    rDaemonChannel = daemonChannel
+  }
+
+  private[spark] def rDaemonExists(): Boolean = {
+    rDaemonChannel != null
+  }
+
+  private[spark] def destroyRDaemonChannel(): Unit = {
+    if (rDaemonChannel != null) {
+      rDaemonChannel.close()
+      rDaemonChannel = null
+    }
+  }
+
+  private[spark] def createRWorkerFromDaemon(port: Int) {
+    try {
+      rDaemonChannel.writeInt(port)
+      rDaemonChannel.flush()
+    } catch {
+      case e: IOException =>
+        // daemon process died
+        destroyRDaemonChannel()
+        // fail the current task, retry by scheduler
+        throw e
+    }
+  }
+
   private[spark]
   def createPythonWorker(pythonExec: Option[String], envVars: Map[String, String],
       condaInstructions: Option[CondaSetupInstructions]): java.net.Socket = {

core/src/main/scala/org/apache/spark/api/conda/CondaEnvironmentManager.scala

Lines changed: 21 additions & 0 deletions
@@ -36,7 +36,9 @@ import org.json4s.jackson.Json4sScalaModule
 import org.json4s.jackson.JsonMethods
 
 import org.apache.spark.SparkConf
+import org.apache.spark.SparkEnv
 import org.apache.spark.SparkException
+import org.apache.spark.api.conda.CondaEnvironment.CondaSetupInstructions
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.CONDA_BINARY_PATH
 import org.apache.spark.internal.config.CONDA_GLOBAL_PACKAGE_DIRS
@@ -216,4 +218,23 @@ object CondaEnvironmentManager extends Logging {
     val packageDirs = sparkConf.get(CONDA_GLOBAL_PACKAGE_DIRS)
     new CondaEnvironmentManager(condaBinaryPath, verbosity, packageDirs)
   }
+
+  /**
+   * Helper method to create a conda environment from [[CondaEnvironment.CondaSetupInstructions]].
+   * This is intended to be called on the executor with serialized instructions.
+   */
+  def createCondaEnvironment(instructions: CondaSetupInstructions): CondaEnvironment = {
+    val condaPackages = instructions.packages
+    val env = SparkEnv.get
+    val condaEnvManager = CondaEnvironmentManager.fromConf(env.conf)
+    val envDir = {
+      // Which local dir to create it in?
+      val localDirs = env.blockManager.diskBlockManager.localDirs
+      val hash = Utils.nonNegativeHash(condaPackages)
+      val dirId = hash % localDirs.length
+      Utils.createTempDir(localDirs(dirId).getAbsolutePath, "conda").getAbsolutePath
+    }
+    condaEnvManager.create(envDir, condaPackages, instructions.channels)
+  }
+
 }

core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala

Lines changed: 1 addition & 15 deletions
@@ -29,7 +29,6 @@ import org.apache.spark._
 import org.apache.spark.api.conda.CondaEnvironment.CondaSetupInstructions
 import org.apache.spark.api.conda.CondaEnvironmentManager
 import org.apache.spark.internal.Logging
-import org.apache.spark.internal.config.CONDA_BOOTSTRAP_PACKAGES
 import org.apache.spark.util.{RedirectThread, Utils}
 
 private[spark] class PythonWorkerFactory(requestedPythonExec: Option[String],
@@ -62,20 +61,7 @@ private[spark] class PythonWorkerFactory(requestedPythonExec: Option[String],
 
   private[this] val condaEnv = {
     // Set up conda environment if there are any conda packages requested
-    condaInstructions.map { instructions =>
-      val condaPackages = instructions.packages
-
-      val env = SparkEnv.get
-      val condaEnvManager = CondaEnvironmentManager.fromConf(env.conf)
-      val envDir = {
-        // Which local dir to create it in?
-        val localDirs = env.blockManager.diskBlockManager.localDirs
-        val hash = Utils.nonNegativeHash(condaPackages)
-        val dirId = hash % localDirs.length
-        Utils.createTempDir(localDirs(dirId).getAbsolutePath, "conda").getAbsolutePath
-      }
-      condaEnvManager.create(envDir, condaPackages, instructions.channels)
-    }
+    condaInstructions.map(CondaEnvironmentManager.createCondaEnvironment)
   }
 
   private[this] val envVars: Map[String, String] = {

core/src/main/scala/org/apache/spark/api/r/RRDD.scala

Lines changed: 10 additions & 1 deletion
@@ -24,6 +24,7 @@ import scala.collection.JavaConverters._
 import scala.reflect.ClassTag
 
 import org.apache.spark._
+import org.apache.spark.api.conda.CondaEnvironment.CondaSetupInstructions
 import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
 import org.apache.spark.api.python.PythonRDD
 import org.apache.spark.broadcast.Broadcast
@@ -39,11 +40,17 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
     packageNames: Array[Byte],
     broadcastVars: Array[Broadcast[Object]])
   extends RDD[U](parent) with Logging {
+
+  /**
+   * Get the conda instructions eagerly - when the RDD is created.
+   */
+  val condaInstructions: Option[CondaSetupInstructions] = context.buildCondaInstructions()
+
   override def getPartitions: Array[Partition] = parent.partitions
 
   override def compute(partition: Partition, context: TaskContext): Iterator[U] = {
     val runner = new RRunner[U](
-      func, deserializer, serializer, packageNames, broadcastVars, numPartitions)
+      func, deserializer, serializer, packageNames, broadcastVars, condaInstructions, numPartitions)
 
     // The parent may be also an RRDD, so we should launch it first.
     val parentIterator = firstParent[T].iterator(partition, context)
@@ -79,6 +86,7 @@ private class RRDD[T: ClassTag](
     deserializer: String,
     serializer: String,
     packageNames: Array[Byte],
+    condaSetupInstructions: Option[CondaSetupInstructions],
     broadcastVars: Array[Object])
   extends BaseRRDD[T, Array[Byte]](
     parent, -1, func, deserializer, serializer, packageNames,
@@ -94,6 +102,7 @@ private class StringRRDD[T: ClassTag](
     func: Array[Byte],
     deserializer: String,
     packageNames: Array[Byte],
+    condaSetupInstructions: Option[CondaSetupInstructions],
     broadcastVars: Array[Object])
   extends BaseRRDD[T, String](
     parent, -1, func, deserializer, SerializationFormats.STRING, packageNames,
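End-to-end, user code does not change; the conda instructions are attached when the RDD is pipelined. A sketch mirroring the updated test_includePackage.R (parallelize, lapplyPartition, and collectRDD are the internal SparkR helpers that test itself uses, and sc is obtained the same way as in that test):

# Assuming a session configured with spark.conda.bootstrapPackages that include plyr,
# worker code no longer needs an installed.packages() guard before library(plyr).
rdd <- parallelize(sc, 1:10, 2L)
withPlyr <- lapplyPartition(rdd, function(part) {
  suppressPackageStartupMessages(library(plyr))  # provided by the executor's conda env
  part
})
collectRDD(withPlyr)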
