SPARK-4687. Add a recursive option to the addFile API #3670
Changes from all commits: 46fe70a, 0239c3d, 31f15a9, 1941be3, ca83849, 38bf94d, 13da824, 70cd24d, f9fc77f
SparkContext.scala
@@ -24,29 +24,37 @@ import java.net.URI
import java.util.{Arrays, Properties, UUID}
import java.util.concurrent.atomic.AtomicInteger
import java.util.UUID.randomUUID

import scala.collection.{Map, Set}
import scala.collection.JavaConversions._
import scala.collection.generic.Growable
import scala.collection.mutable.HashMap
import scala.reflect.{ClassTag, classTag}

import akka.actor.Props

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, FloatWritable, IntWritable, LongWritable, NullWritable, Text, Writable}
import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, SequenceFileInputFormat, TextInputFormat}
import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable,
  FloatWritable, IntWritable, LongWritable, NullWritable, Text, Writable}
import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, SequenceFileInputFormat,
  TextInputFormat}
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHadoopJob}
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}

Review comment: nit: remove this line

import org.apache.mesos.MesosNativeLibrary
import akka.actor.Props

import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
import org.apache.spark.executor.TriggerThreadDump
import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat}
import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat,
  FixedLengthBinaryInputFormat}
import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd._
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SparkDeploySchedulerBackend, SimrSchedulerBackend}
import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend,
  SparkDeploySchedulerBackend, SimrSchedulerBackend}
import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
import org.apache.spark.scheduler.local.LocalBackend
import org.apache.spark.storage._
@@ -996,12 +1004,48 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
   * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs,
   * use `SparkFiles.get(fileName)` to find its download location.
   */
  def addFile(path: String) {
  def addFile(path: String): Unit = {
    addFile(path, false)
  }

  /**
   * Add a file to be downloaded with this Spark job on every node.
   * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
   * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs,
   * use `SparkFiles.get(fileName)` to find its download location.
   *
   * A directory can be given if the recursive option is set to true. Currently directories are only
   * supported for Hadoop-supported filesystems.
   */
  def addFile(path: String, recursive: Boolean): Unit = {
    val uri = new URI(path)
    val key = uri.getScheme match {
      case null | "file" => env.httpFileServer.addFile(new File(uri.getPath))
      case "local" => "file:" + uri.getPath
      case _ => path
    val schemeCorrectedPath = uri.getScheme match {
      case null | "local" => "file:" + uri.getPath
      case _ => path
    }

    val hadoopPath = new Path(schemeCorrectedPath)
    val scheme = new URI(schemeCorrectedPath).getScheme
    if (!Array("http", "https", "ftp").contains(scheme)) {
      val fs = hadoopPath.getFileSystem(hadoopConfiguration)
      if (!fs.exists(hadoopPath)) {
        throw new FileNotFoundException(s"Added file $hadoopPath does not exist.")
      }
      val isDir = fs.isDirectory(hadoopPath)
      if (!isLocal && scheme == "file" && isDir) {
        throw new SparkException(s"addFile does not support local directories when not running " +
          "local mode.")
      }
      if (!recursive && isDir) {
        throw new SparkException(s"Added file $hadoopPath is a directory and recursive is not " +
          "turned on.")
      }
    }

    val key = if (!isLocal && scheme == "file") {
      env.httpFileServer.addFile(new File(uri.getPath))
    } else {
      schemeCorrectedPath
    }
    val timestamp = System.currentTimeMillis
    addedFiles(key) = timestamp
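For orientation, here is a minimal sketch (not part of this PR's diff) of how the new recursive flag is meant to be used from application code; the directory path and file names are made-up examples:

import java.io.File
import org.apache.spark.{SparkConf, SparkContext, SparkFiles}

val sc = new SparkContext(new SparkConf().setAppName("recursive-addFile-example").setMaster("local"))

// Distribute a whole directory; without recursive = true, addFile throws for directories.
// A local directory is allowed here only because the context runs in local mode; non-local
// deployments need a path on a Hadoop-supported filesystem such as HDFS.
sc.addFile("/data/lookup-tables", recursive = true)   // hypothetical path

sc.parallelize(1 to 4).map { _ =>
  // On executors, files keep the added directory's name under the per-executor download root.
  new File(SparkFiles.get("lookup-tables")).listFiles().length
}.collect()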
@@ -1549,8 +1593,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
    val schedulingMode = getSchedulingMode.toString
    val addedJarPaths = addedJars.keys.toSeq
    val addedFilePaths = addedFiles.keys.toSeq
    val environmentDetails =
      SparkEnv.environmentDetails(conf, schedulingMode, addedJarPaths, addedFilePaths)
    val environmentDetails = SparkEnv.environmentDetails(conf, schedulingMode, addedJarPaths,
      addedFilePaths)
    val environmentUpdate = SparkListenerEnvironmentUpdate(environmentDetails)
    listenerBus.post(environmentUpdate)
  }
Utils.scala
@@ -359,8 +359,10 @@ private[spark] object Utils extends Logging {
  }

  /**
   * Download a file to target directory. Supports fetching the file in a variety of ways,
   * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
   * Download a file or directory to target directory. Supports fetching the file in a variety of
   * ways, including HTTP, Hadoop-compatible filesystems, and files on a standard filesystem, based
   * on the URL parameter. Fetching directories is only supported from Hadoop-compatible
   * filesystems.
   *
   * If `useCache` is true, first attempts to fetch the file to a local cache that's shared
   * across executors running the same application. `useCache` is used mainly for
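As context for this doc change, a hedged sketch of how an executor calls this helper when it picks up a dependency registered through addFile. The full fetchFile signature is not shown in this excerpt, so the parameter list below is an assumption (url, target directory, SparkConf, SecurityManager, Hadoop Configuration, timestamp, useCache); Utils and SecurityManager are private[spark], so this only illustrates Spark-internal code:

import java.io.File
import org.apache.hadoop.conf.Configuration
import org.apache.spark.{SecurityManager, SparkConf, SparkFiles}
import org.apache.spark.util.Utils

val conf = new SparkConf()
Utils.fetchFile(
  "hdfs:///tmp/lookup-tables",               // URL recorded by SparkContext.addFile (example value)
  new File(SparkFiles.getRootDirectory()),   // per-executor download root
  conf,
  new SecurityManager(conf),
  new Configuration(),
  System.currentTimeMillis,                  // timestamp stored alongside the added file
  useCache = true)                           // reuse a host-local cached copy when possible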
@@ -429,17 +431,18 @@ private[spark] object Utils extends Logging {
   *
   * @param url URL that `sourceFile` originated from, for logging purposes.
   * @param in InputStream to download.
   * @param tempFile File path to download `in` to.
   * @param destFile File path to move `tempFile` to.
   * @param fileOverwrite Whether to delete/overwrite an existing `destFile` that does not match
   *                      `sourceFile`
   */
  private def downloadFile(
      url: String,
      in: InputStream,
      tempFile: File,
      destFile: File,
      fileOverwrite: Boolean): Unit = {
    val tempFile = File.createTempFile("fetchFileTemp", null,
      new File(destFile.getParentFile.getAbsolutePath))
    logInfo(s"Fetching $url to $tempFile")

    try {
      val out = new FileOutputStream(tempFile)

@@ -478,7 +481,7 @@ private[spark] object Utils extends Logging {
      removeSourceFile: Boolean = false): Unit = {

    if (destFile.exists) {
      if (!Files.equal(sourceFile, destFile)) {
      if (!filesEqualRecursive(sourceFile, destFile)) {
        if (fileOverwrite) {
          logInfo(
            s"File $destFile exists and does not match contents of $url, replacing it with $url"
@@ -513,13 +516,44 @@ private[spark] object Utils extends Logging {
      Files.move(sourceFile, destFile)
Review comment: It's not clear from the docs, but does this work for directories? Could you add a unit test? (It doesn't seem like this is covered for directories by the current ones, after a quick look.)

Review comment: Tested this in isolation and it does work for directories. A new test that I added should exercise this path.

Review comment: Hmmm... I think it doesn't work if …

Review comment: We talked about this offline; it seems that, in Spark's case, both source and destination are on the same filesystem, so this should work. Guava does something akin to … So if we care about Windows here we should probably do a …
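Not from this PR: a small sketch of the kind of fallback the last comment alludes to, i.e. try a plain rename first and fall back to copy-then-delete when the rename cannot succeed (for example across filesystems). Only standard java.io and Guava calls are used, and the helper name is made up; it also only covers single files, while a directory move would need to recurse the way copyRecursive below does.

import java.io.{File, IOException}
import com.google.common.io.Files

def moveWithFallback(source: File, dest: File): Unit = {
  if (!source.renameTo(dest)) {   // cheap, and atomic when source and dest share a filesystem
    Files.copy(source, dest)      // otherwise copy the bytes...
    if (!source.delete()) {       // ...and remove the original
      throw new IOException(s"Failed to delete ${source.getPath} after copying it to ${dest.getPath}")
    }
  }
}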
    } else {
      logInfo(s"Copying ${sourceFile.getAbsolutePath} to ${destFile.getAbsolutePath}")
      Files.copy(sourceFile, destFile)
      copyRecursive(sourceFile, destFile)
    }
  }

  private def filesEqualRecursive(file1: File, file2: File): Boolean = {
    if (file1.isDirectory && file2.isDirectory) {
      val subfiles1 = file1.listFiles()
      val subfiles2 = file2.listFiles()
      if (subfiles1.size != subfiles2.size) {
        return false
      }
      subfiles1.sortBy(_.getName).zip(subfiles2.sortBy(_.getName)).forall {
        case (f1, f2) => filesEqualRecursive(f1, f2)
      }
    } else if (file1.isFile && file2.isFile) {
      Files.equal(file1, file2)
    } else {
      false
    }
  }

  private def copyRecursive(source: File, dest: File): Unit = {
    if (source.isDirectory) {
      if (!dest.mkdir()) {
        throw new IOException(s"Failed to create directory ${dest.getPath}")
      }
      val subfiles = source.listFiles()
      subfiles.foreach(f => copyRecursive(f, new File(dest, f.getName)))
    } else {
      Files.copy(source, dest)
    }
  }

  /**
   * Download a file to target directory. Supports fetching the file in a variety of ways,
   * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
   * Download a file or directory to target directory. Supports fetching the file in a variety of
   * ways, including HTTP, Hadoop-compatible filesystems, and files on a standard filesystem, based
   * on the URL parameter. Fetching directories is only supported from Hadoop-compatible
   * filesystems.
   *
   * Throws SparkException if the target file already exists and has different contents than
   * the requested file.
@@ -531,14 +565,11 @@ private[spark] object Utils extends Logging {
      conf: SparkConf,
      securityMgr: SecurityManager,
      hadoopConf: Configuration) {
    val tempFile = File.createTempFile("fetchFileTemp", null, new File(targetDir.getAbsolutePath))
    val targetFile = new File(targetDir, filename)
    val uri = new URI(url)
    val fileOverwrite = conf.getBoolean("spark.files.overwrite", defaultValue = false)
    Option(uri.getScheme).getOrElse("file") match {
      case "http" | "https" | "ftp" =>
        logInfo("Fetching " + url + " to " + tempFile)

        var uc: URLConnection = null
        if (securityMgr.isAuthenticationEnabled()) {
          logDebug("fetchFile with security enabled")
@@ -555,17 +586,44 @@ private[spark] object Utils extends Logging {
        uc.setReadTimeout(timeout)
        uc.connect()
        val in = uc.getInputStream()
        downloadFile(url, in, tempFile, targetFile, fileOverwrite)
        downloadFile(url, in, targetFile, fileOverwrite)
      case "file" =>
        // In the case of a local file, copy the local file to the target directory.
        // Note the difference between uri vs url.
        val sourceFile = if (uri.isAbsolute) new File(uri) else new File(url)
        copyFile(url, sourceFile, targetFile, fileOverwrite)
      case _ =>
        // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
        val fs = getHadoopFileSystem(uri, hadoopConf)
        val in = fs.open(new Path(uri))
        downloadFile(url, in, tempFile, targetFile, fileOverwrite)
        val path = new Path(uri)
        fetchHcfsFile(path, new File(targetDir, path.getName), fs, conf, hadoopConf, fileOverwrite)
    }
  }

  /**
   * Fetch a file or directory from a Hadoop-compatible filesystem.
   *
   * Visible for testing
   */
  private[spark] def fetchHcfsFile(
      path: Path,
      targetDir: File,
      fs: FileSystem,
      conf: SparkConf,
      hadoopConf: Configuration,
      fileOverwrite: Boolean): Unit = {
    if (!targetDir.mkdir()) {
      throw new IOException(s"Failed to create directory ${targetDir.getPath}")
    }
    fs.listStatus(path).foreach { fileStatus =>
      val innerPath = fileStatus.getPath
      if (fileStatus.isDir) {
        fetchHcfsFile(innerPath, new File(targetDir, innerPath.getName), fs, conf, hadoopConf,
          fileOverwrite)
      } else {
        val in = fs.open(innerPath)
        val targetFile = new File(targetDir, innerPath.getName)
        downloadFile(innerPath.toString, in, targetFile, fileOverwrite)
      }
    }
  }
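Since fetchHcfsFile is marked visible for testing, here is a hedged sketch (not the PR's actual test) of how it could be exercised against the local filesystem, which also implements the Hadoop FileSystem API; the file names are arbitrary:

import java.io.File
import com.google.common.base.Charsets.UTF_8
import com.google.common.io.Files
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.SparkConf
import org.apache.spark.util.Utils

val sourceDir = Utils.createTempDir()
Files.write("some text", new File(sourceDir, "part-00000"), UTF_8)

val hadoopConf = new Configuration()
val fs = FileSystem.getLocal(hadoopConf)
val targetDir = new File(Utils.createTempDir(), "fetched")   // must not exist yet; fetchHcfsFile creates it

Utils.fetchHcfsFile(new Path(sourceDir.toURI), targetDir, fs, new SparkConf(), hadoopConf, fileOverwrite = false)
assert(new File(targetDir, "part-00000").exists())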
SparkContextSuite.scala
@@ -17,10 +17,17 @@

package org.apache.spark

import java.io.File

import com.google.common.base.Charsets._
import com.google.common.io.Files

import org.scalatest.FunSuite

import org.apache.hadoop.io.BytesWritable

import org.apache.spark.util.Utils

class SparkContextSuite extends FunSuite with LocalSparkContext {

  test("Only one SparkContext may be active at a time") {
@@ -72,4 +79,74 @@ class SparkContextSuite extends FunSuite with LocalSparkContext {
    val byteArray2 = converter.convert(bytesWritable)
    assert(byteArray2.length === 0)
  }

  test("addFile works") {
    val file = File.createTempFile("someprefix", "somesuffix")
    val absolutePath = file.getAbsolutePath
    try {
      Files.write("somewords", file, UTF_8)
      val length = file.length()
      sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
      sc.addFile(file.getAbsolutePath)
      sc.parallelize(Array(1), 1).map(x => {
        val gotten = new File(SparkFiles.get(file.getName))

Review comment: I guess using …

Review comment: AFAIK you can serialize …

scala> import java.io.File
import java.io.File

scala> import org.apache.commons.lang3.SerializationUtils
import org.apache.commons.lang3.SerializationUtils

scala> SerializationUtils.serialize(new File("/usr/share/dict/words"))
res1: Array[Byte] = Array(-84, -19, 0, 5, 115, 114, 0, 12, 106, 97, 118, 97, 46, 105, 111, 46, 70, 105, 108, 101, 4, 45, -92, 69, 14, 13, -28, -1, 3, 0, 1, 76, 0, 4, 112, 97, 116, 104, 116, 0, 18, 76, 106, 97, 118, 97, 47, 108, 97, 110, 103, 47, 83, 116, 114, 105, 110, 103, 59, 120, 112, 116, 0, 21, 47, 117, 115, 114, 47, 115, 104, 97, 114, 101, 47, 100, 105, 99, 116, 47, 119, 111, 114, 100, 115, 119, 2, 0, 47, 120)

        if (!gotten.exists()) {
          throw new SparkException("file doesn't exist")
        }
        if (length != gotten.length()) {
          throw new SparkException(
            s"file has different length $length than added file ${gotten.length()}")
        }
        if (absolutePath == gotten.getAbsolutePath) {
          throw new SparkException("file should have been copied")
        }
        x
      }).count()
    } finally {
      sc.stop()
    }
  }

  test("addFile recursive works") {
    val pluto = Utils.createTempDir()
    val neptune = Utils.createTempDir(pluto.getAbsolutePath)
    val saturn = Utils.createTempDir(neptune.getAbsolutePath)
    val alien1 = File.createTempFile("alien", "1", neptune)
    val alien2 = File.createTempFile("alien", "2", saturn)

    try {
      sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
      sc.addFile(neptune.getAbsolutePath, true)
      sc.parallelize(Array(1), 1).map(x => {
        val sep = File.separator
        if (!new File(SparkFiles.get(neptune.getName + sep + alien1.getName)).exists()) {
          throw new SparkException("can't access file under root added directory")
        }
        if (!new File(SparkFiles.get(neptune.getName + sep + saturn.getName + sep + alien2.getName))
          .exists()) {
          throw new SparkException("can't access file in nested directory")
        }
        if (new File(SparkFiles.get(pluto.getName + sep + neptune.getName + sep + alien1.getName))
          .exists()) {
          throw new SparkException("file exists that shouldn't")
        }
        x
      }).count()
    } finally {
      sc.stop()
    }
  }

  test("addFile recursive can't add directories by default") {
    val dir = Utils.createTempDir()

    try {
      sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
      intercept[SparkException] {
        sc.addFile(dir.getAbsolutePath)
      }
    } finally {
      sc.stop()
    }
  }
}