@@ -23,6 +23,13 @@ import java.util.concurrent.ConcurrentLinkedQueue
 
 import scala.collection.JavaConverters._
 
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.apache.parquet.hadoop.ParquetFileReader
+import org.apache.parquet.hadoop.util.HadoopInputFile
+import org.apache.parquet.schema.PrimitiveType
+import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
+import org.apache.parquet.schema.Type.Repetition
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.SparkContext
@@ -31,6 +38,7 @@ import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
@@ -522,11 +530,12 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
     Seq("json", "orc", "parquet", "csv").foreach { format =>
       val schema = StructType(
         StructField("cl1", IntegerType, nullable = false).withComment("test") ::
-          StructField("cl2", IntegerType, nullable = true) ::
-          StructField("cl3", IntegerType, nullable = true) :: Nil)
+        StructField("cl2", IntegerType, nullable = true) ::
+        StructField("cl3", IntegerType, nullable = true) :: Nil)
       val row = Row(3, null, 4)
       val df = spark.createDataFrame(sparkContext.parallelize(row :: Nil), schema)
 
+      // if we write and then read, the read will enforce schema to be nullable
       val tableName = "tab"
       withTable(tableName) {
         df.write.format(format).mode("overwrite").saveAsTable(tableName)
@@ -536,12 +545,41 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
           Row("cl1", "test") :: Nil)
         // Verify the schema
         val expectedFields = schema.fields.map(f => f.copy(nullable = true))
-        assert(spark.table(tableName).schema == schema.copy(fields = expectedFields))
+        assert(spark.table(tableName).schema === schema.copy(fields = expectedFields))
         }
       }
     }
   }
 
+  test("parquet - column nullability -- write only") {
+    val schema = StructType(
+      StructField("cl1", IntegerType, nullable = false) ::
+      StructField("cl2", IntegerType, nullable = true) :: Nil)
+    val row = Row(3, 4)
+    val df = spark.createDataFrame(sparkContext.parallelize(row :: Nil), schema)
+
+    withTempPath { dir =>
+      val path = dir.getAbsolutePath
+      df.write.mode("overwrite").parquet(path)
+      val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0)
+
+      val hadoopInputFile = HadoopInputFile.fromPath(new Path(file), new Configuration())
+      val f = ParquetFileReader.open(hadoopInputFile)
+      val parquetSchema = f.getFileMetaData.getSchema.getColumns.asScala
+        .map(_.getPrimitiveType)
+      f.close()
+
+      // the write keeps nullable info from the schema
+      val expectedParquetSchema = Seq(
+        new PrimitiveType(Repetition.REQUIRED, PrimitiveTypeName.INT32, "cl1"),
+        new PrimitiveType(Repetition.OPTIONAL, PrimitiveTypeName.INT32, "cl2")
+      )
+
+      assert(expectedParquetSchema === parquetSchema)
+    }
+
+  }
+
   test("SPARK-17230: write out results of decimal calculation") {
     val df = spark.range(99, 101)
       .selectExpr("id", "cast(id as long) * cast('1.0' as decimal(38, 18)) as num")
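
Below is a minimal, standalone sketch (not part of the diff above) of the parquet-hadoop calls the new test relies on: it opens a Parquet file footer with `ParquetFileReader` and prints whether each column was written as `REQUIRED` (nullable = false in the Spark schema) or `OPTIONAL` (nullable = true). The object name and the `/tmp/example.parquet` path are placeholders, not part of the change.

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.parquet.hadoop.ParquetFileReader
import org.apache.parquet.hadoop.util.HadoopInputFile

import scala.collection.JavaConverters._

object PrintParquetNullability {
  def main(args: Array[String]): Unit = {
    // Placeholder path; point this at any Parquet file, e.g. one written by the test above.
    val inputFile = HadoopInputFile.fromPath(new Path("/tmp/example.parquet"), new Configuration())
    val reader = ParquetFileReader.open(inputFile)
    try {
      // REQUIRED corresponds to nullable = false in the Catalyst schema, OPTIONAL to nullable = true.
      reader.getFileMetaData.getSchema.getColumns.asScala
        .map(_.getPrimitiveType)
        .foreach(t => println(s"${t.getName}: ${t.getRepetition}"))
    } finally {
      reader.close()
    }
  }
}
```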