|
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.avro

import java.io._
import java.net.URI
import java.util.zip.Deflater

import scala.util.control.NonFatal

import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
import com.esotericsoftware.kryo.io.{Input, Output}
import org.apache.avro.{Schema, SchemaBuilder}
import org.apache.avro.file.{DataFileConstants, DataFileReader}
import org.apache.avro.generic.{GenericDatumReader, GenericRecord}
import org.apache.avro.mapred.{AvroOutputFormat, FsInput}
import org.apache.avro.mapreduce.AvroJob
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.mapreduce.Job
import org.slf4j.LoggerFactory

import org.apache.spark.TaskContext
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile}
import org.apache.spark.sql.sources.{DataSourceRegister, Filter}
import org.apache.spark.sql.types.StructType

private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister {
  private val log = LoggerFactory.getLogger(getClass)

  override def equals(other: Any): Boolean = other match {
    case _: AvroFileFormat => true
    case _ => false
  }

  // Dummy hashCode() to appease ScalaStyle.
  override def hashCode(): Int = super.hashCode()

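  // Infers the Spark SQL schema for the whole dataset from a single sample file: either the
  // schema supplied through the "avroSchema" option or the writer schema embedded in that file,
  // converted to Catalyst types via SchemaConverters.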
  override def inferSchema(
      spark: SparkSession,
      options: Map[String, String],
      files: Seq[FileStatus]): Option[StructType] = {
    val conf = spark.sparkContext.hadoopConfiguration

    // Schema evolution is not supported yet. Here we only pick a single sample file to
    // figure out the schema of the whole dataset.
    val sampleFile =
      if (conf.getBoolean(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, true)) {
        files.find(_.getPath.getName.endsWith(".avro")).getOrElse {
          throw new FileNotFoundException(
            "No Avro files found. Hadoop option \"avro.mapred.ignore.inputs.without.extension\" " +
              "is set to true. Do all input files have the \".avro\" extension?"
          )
        }
      } else {
        files.headOption.getOrElse {
          throw new FileNotFoundException("No Avro files found.")
        }
      }

    // Users can specify an optional Avro JSON schema.
    val avroSchema = options.get(AvroFileFormat.AvroSchema)
      .map(new Schema.Parser().parse)
      .getOrElse {
        val in = new FsInput(sampleFile.getPath, conf)
        try {
          val reader = DataFileReader.openReader(in, new GenericDatumReader[GenericRecord]())
          try {
            reader.getSchema
          } finally {
            reader.close()
          }
        } finally {
          in.close()
        }
      }

    SchemaConverters.toSqlType(avroSchema).dataType match {
      case t: StructType => Some(t)
      case _ => throw new RuntimeException(
        s"""Avro schema cannot be converted to a Spark SQL StructType:
           |
           |${avroSchema.toString(true)}
           |""".stripMargin)
    }
  }

  override def shortName(): String = "avro"

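  // Avro container files interleave sync markers between blocks, so a single file can be split
  // across partitions and each task can seek to the first block after its start offset.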
  override def isSplitable(
      sparkSession: SparkSession,
      options: Map[String, String],
      path: Path): Boolean = true

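  // Converts the Catalyst write schema into an Avro record schema, attaches it to the Hadoop job,
  // and configures the output compression codec ("uncompressed", "snappy", or "deflate") from the
  // Spark SQL configuration before returning an OutputWriterFactory.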
  override def prepareWrite(
      spark: SparkSession,
      job: Job,
      options: Map[String, String],
      dataSchema: StructType): OutputWriterFactory = {
    val recordName = options.getOrElse("recordName", "topLevelRecord")
    val recordNamespace = options.getOrElse("recordNamespace", "")
    val build = SchemaBuilder.record(recordName).namespace(recordNamespace)
    val outputAvroSchema = SchemaConverters.convertStructToAvro(dataSchema, build, recordNamespace)

    AvroJob.setOutputKeySchema(job, outputAvroSchema)
    val AVRO_COMPRESSION_CODEC = "spark.sql.avro.compression.codec"
    val AVRO_DEFLATE_LEVEL = "spark.sql.avro.deflate.level"
    val COMPRESS_KEY = "mapred.output.compress"

    spark.conf.get(AVRO_COMPRESSION_CODEC, "snappy") match {
      case "uncompressed" =>
        log.info("writing uncompressed Avro records")
        job.getConfiguration.setBoolean(COMPRESS_KEY, false)

      case "snappy" =>
        log.info("compressing Avro output using Snappy")
        job.getConfiguration.setBoolean(COMPRESS_KEY, true)
        job.getConfiguration.set(AvroJob.CONF_OUTPUT_CODEC, DataFileConstants.SNAPPY_CODEC)

      case "deflate" =>
        val deflateLevel = spark.conf.get(
          AVRO_DEFLATE_LEVEL, Deflater.DEFAULT_COMPRESSION.toString).toInt
        log.info(s"compressing Avro output using deflate (level=$deflateLevel)")
        job.getConfiguration.setBoolean(COMPRESS_KEY, true)
        job.getConfiguration.set(AvroJob.CONF_OUTPUT_CODEC, DataFileConstants.DEFLATE_CODEC)
        job.getConfiguration.setInt(AvroOutputFormat.DEFLATE_LEVEL_KEY, deflateLevel)

      case unknown: String =>
        log.error(s"unsupported compression codec $unknown")
    }

    new AvroOutputWriterFactory(dataSchema, recordName, recordNamespace)
  }

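  // Builds a function that, for each PartitionedFile, opens an Avro DataFileReader over the
  // underlying file and returns an iterator of InternalRows covering only the blocks that fall
  // inside this partition's [start, start + length) byte range.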
  override def buildReader(
      spark: SparkSession,
      dataSchema: StructType,
      partitionSchema: StructType,
      requiredSchema: StructType,
      filters: Seq[Filter],
      options: Map[String, String],
      hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {

    val broadcastedConf =
      spark.sparkContext.broadcast(new AvroFileFormat.SerializableConfiguration(hadoopConf))

    (file: PartitionedFile) => {
      val log = LoggerFactory.getLogger(classOf[AvroFileFormat])
      val conf = broadcastedConf.value.value
      val userProvidedSchema = options.get(AvroFileFormat.AvroSchema).map(new Schema.Parser().parse)

      // TODO Remove this check once `FileFormat` gets a general file filtering interface method.
      // Doing input file filtering here is improper because we may generate empty tasks that
      // process no input files but still stress the scheduler. We should probably add a more
      // general input file filtering mechanism for `FileFormat` data sources. See SPARK-16317.
      if (
        conf.getBoolean(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, true) &&
        !file.filePath.endsWith(".avro")
      ) {
        Iterator.empty
      } else {
        val reader = {
          val in = new FsInput(new Path(new URI(file.filePath)), conf)
          try {
            val datumReader = userProvidedSchema match {
              case Some(userSchema) => new GenericDatumReader[GenericRecord](userSchema)
              case _ => new GenericDatumReader[GenericRecord]()
            }
            DataFileReader.openReader(in, datumReader)
          } catch {
            case NonFatal(e) =>
              log.error("Exception while opening DataFileReader", e)
              in.close()
              throw e
          }
        }

        // Ensure that the reader is closed even if the task fails or doesn't consume the entire
        // iterator of records.
        Option(TaskContext.get()).foreach { taskContext =>
          taskContext.addTaskCompletionListener { _ =>
            reader.close()
          }
        }

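        // Seek to the first sync marker at or after this partition's start offset; records before
        // it belong to the previous partition. Reading stops once the reader moves past the sync
        // point that follows the end of this partition's byte range.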
        reader.sync(file.start)
        val stop = file.start + file.length

        val rowConverter = SchemaConverters.createConverterToSQL(
          userProvidedSchema.getOrElse(reader.getSchema), requiredSchema)

        new Iterator[InternalRow] {
          // Used to convert `Row`s containing data columns into `InternalRow`s.
          private val encoderForDataColumns = RowEncoder(requiredSchema)

          private[this] var completed = false

          override def hasNext: Boolean = {
            if (completed) {
              false
            } else {
              val r = reader.hasNext && !reader.pastSync(stop)
              if (!r) {
                reader.close()
                completed = true
              }
              r
            }
          }

          override def next(): InternalRow = {
            if (reader.pastSync(stop)) {
              throw new NoSuchElementException("next on empty iterator")
            }
            val record = reader.next()
            val safeDataRow = rowConverter(record).asInstanceOf[GenericRow]

            // The safeDataRow is reused, so we must do a copy.
            encoderForDataColumns.toRow(safeDataRow)
          }
        }
      }
    }
  }
}

private[avro] object AvroFileFormat {
  val IgnoreFilesWithoutExtensionProperty = "avro.mapred.ignore.inputs.without.extension"

  val AvroSchema = "avroSchema"

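  // Hadoop's Configuration is not java.io.Serializable, so it is wrapped here with explicit Java
  // and Kryo serialization hooks, allowing it to be shipped to executors via the broadcast
  // variable used in buildReader.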
  class SerializableConfiguration(@transient var value: Configuration)
      extends Serializable with KryoSerializable {
    @transient private[avro] lazy val log = LoggerFactory.getLogger(getClass)

    private def writeObject(out: ObjectOutputStream): Unit = tryOrIOException {
      out.defaultWriteObject()
      value.write(out)
    }

    private def readObject(in: ObjectInputStream): Unit = tryOrIOException {
      value = new Configuration(false)
      value.readFields(in)
    }

    private def tryOrIOException[T](block: => T): T = {
      try {
        block
      } catch {
        case e: IOException =>
          log.error("Exception encountered", e)
          throw e
        case NonFatal(e) =>
          log.error("Exception encountered", e)
          throw new IOException(e)
      }
    }

    def write(kryo: Kryo, out: Output): Unit = {
      val dos = new DataOutputStream(out)
      value.write(dos)
      dos.flush()
    }

    def read(kryo: Kryo, in: Input): Unit = {
      value = new Configuration(false)
      value.readFields(new DataInputStream(in))
    }
  }
}
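For context, here is a minimal usage sketch of the data source defined above (not part of the patch). It assumes this format is on the classpath so that Spark resolves the short name "avro" through DataSourceRegister; the output path is a placeholder, and the options and configuration keys are the ones read by prepareWrite() and inferSchema().

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("avro-example").getOrCreate()

// Write with deflate compression; these are the conf keys consulted in prepareWrite().
spark.conf.set("spark.sql.avro.compression.codec", "deflate")
spark.conf.set("spark.sql.avro.deflate.level", "5")
spark.range(10).toDF("id")
  .write
  .format("avro")                         // resolved via DataSourceRegister.shortName()
  .option("recordName", "topLevelRecord")
  .save("/tmp/avro-example")              // placeholder path

// Read back; a reader schema could be supplied with .option("avroSchema", schemaJson).
val df = spark.read.format("avro").load("/tmp/avro-example")
df.show()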