@@ -21,7 +21,7 @@ import java.io.{File, PrintWriter}
 
 import scala.reflect.ClassTag
 
-import org.apache.spark.sql.{AnalysisException, QueryTest, Row, StatisticsTest}
+import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics}
 import org.apache.spark.sql.execution.command.{AnalyzeTableCommand, DDLUtils}
@@ -358,53 +358,189 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils
     }
   }
 
-  test("generate column-level statistics and load them from hive metastore") {
+  private def getStatsBeforeAfterUpdate(isAnalyzeColumns: Boolean): (Statistics, Statistics) = {
+    val tableName = "tbl"
+    var statsBeforeUpdate: Statistics = null
+    var statsAfterUpdate: Statistics = null
+    withTable(tableName) {
+      val tableIndent = TableIdentifier(tableName, Some("default"))
+      val catalog = spark.sessionState.catalog.asInstanceOf[HiveSessionCatalog]
+      sql(s"CREATE TABLE $tableName (key int) USING PARQUET")
+      sql(s"INSERT INTO $tableName SELECT 1")
+      if (isAnalyzeColumns) {
+        sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS key")
+      } else {
+        sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS")
+      }
+      // Table lookup will make the table cached.
+      catalog.lookupRelation(tableIndent)
+      statsBeforeUpdate = catalog.getCachedDataSourceTable(tableIndent)
+        .asInstanceOf[LogicalRelation].catalogTable.get.stats.get
+
+      sql(s"INSERT INTO $tableName SELECT 2")
+      if (isAnalyzeColumns) {
+        sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS key")
+      } else {
+        sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS")
+      }
+      catalog.lookupRelation(tableIndent)
+      statsAfterUpdate = catalog.getCachedDataSourceTable(tableIndent)
+        .asInstanceOf[LogicalRelation].catalogTable.get.stats.get
+    }
+    (statsBeforeUpdate, statsAfterUpdate)
+  }
+
+  test("test refreshing table stats of cached data source table by `ANALYZE TABLE` statement") {
+    val (statsBeforeUpdate, statsAfterUpdate) = getStatsBeforeAfterUpdate(isAnalyzeColumns = false)
+
+    assert(statsBeforeUpdate.sizeInBytes > 0)
+    assert(statsBeforeUpdate.rowCount.contains(1))
+
+    assert(statsAfterUpdate.sizeInBytes > statsBeforeUpdate.sizeInBytes)
+    assert(statsAfterUpdate.rowCount.contains(2))
+  }
+
+  test("test refreshing column stats of cached data source table by `ANALYZE TABLE` statement") {
+    val (statsBeforeUpdate, statsAfterUpdate) = getStatsBeforeAfterUpdate(isAnalyzeColumns = true)
+
+    assert(statsBeforeUpdate.sizeInBytes > 0)
+    assert(statsBeforeUpdate.rowCount.contains(1))
+    StatisticsTest.checkColStat(
+      dataType = IntegerType,
+      colStat = statsBeforeUpdate.colStats("key"),
+      expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)),
+      rsd = spark.sessionState.conf.ndvMaxError)
+
+    assert(statsAfterUpdate.sizeInBytes > statsBeforeUpdate.sizeInBytes)
+    assert(statsAfterUpdate.rowCount.contains(2))
+    StatisticsTest.checkColStat(
+      dataType = IntegerType,
+      colStat = statsAfterUpdate.colStats("key"),
+      expectedColStat = ColumnStat(InternalRow(0L, 2, 1, 2L)),
+      rsd = spark.sessionState.conf.ndvMaxError)
+  }
+
+  private def dataAndColStats(): (DataFrame, Seq[(StructField, ColumnStat)]) = {
     import testImplicits._
 
     val intSeq = Seq(1, 2)
     val stringSeq = Seq("a", "bb")
+    val binarySeq = Seq("a", "bb").map(_.getBytes)
     val booleanSeq = Seq(true, false)
-
     val data = intSeq.indices.map { i =>
-      (intSeq(i), stringSeq(i), booleanSeq(i))
+      (intSeq(i), stringSeq(i), binarySeq(i), booleanSeq(i))
     }
-    val tableName = "table"
-    withTable(tableName) {
-      val df = data.toDF("c1", "c2", "c3")
-      df.write.format("parquet").saveAsTable(tableName)
-      val expectedColStatsSeq = df.schema.map { f =>
-        val colStat = f.dataType match {
-          case IntegerType =>
-            ColumnStat(InternalRow(0L, intSeq.max, intSeq.min, intSeq.distinct.length.toLong))
-          case StringType =>
-            ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble,
-              stringSeq.map(_.length).max.toInt, stringSeq.distinct.length.toLong))
-          case BooleanType =>
-            ColumnStat(InternalRow(0L, booleanSeq.count(_.equals(true)).toLong,
-              booleanSeq.count(_.equals(false)).toLong))
-        }
-        (f, colStat)
+    val df = data.toDF("c1", "c2", "c3", "c4")
+    val expectedColStatsSeq = df.schema.map { f =>
+      val colStat = f.dataType match {
+        case IntegerType =>
+          ColumnStat(InternalRow(0L, intSeq.max, intSeq.min, intSeq.distinct.length.toLong))
+        case StringType =>
+          ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble,
+            stringSeq.map(_.length).max.toInt, stringSeq.distinct.length.toLong))
+        case BinaryType =>
+          ColumnStat(InternalRow(0L, binarySeq.map(_.length).sum / binarySeq.length.toDouble,
+            binarySeq.map(_.length).max.toInt))
+        case BooleanType =>
+          ColumnStat(InternalRow(0L, booleanSeq.count(_.equals(true)).toLong,
+            booleanSeq.count(_.equals(false)).toLong))
       }
+      (f, colStat)
+    }
+    (df, expectedColStatsSeq)
+  }
+
+  private def checkColStats(
+      tableName: String,
+      isDataSourceTable: Boolean,
+      expectedColStatsSeq: Seq[(StructField, ColumnStat)]): Unit = {
+    val readback = spark.table(tableName)
+    val stats = readback.queryExecution.analyzed.collect {
+      case rel: MetastoreRelation =>
+        assert(!isDataSourceTable, "Expected a Hive serde table, but got a data source table")
+        rel.catalogTable.stats.get
+      case rel: LogicalRelation =>
+        assert(isDataSourceTable, "Expected a data source table, but got a Hive serde table")
+        rel.catalogTable.get.stats.get
+    }
+    assert(stats.length == 1)
+    val columnStats = stats.head.colStats
+    assert(columnStats.size == expectedColStatsSeq.length)
+    expectedColStatsSeq.foreach { case (field, expectedColStat) =>
+      StatisticsTest.checkColStat(
+        dataType = field.dataType,
+        colStat = columnStats(field.name),
+        expectedColStat = expectedColStat,
+        rsd = spark.sessionState.conf.ndvMaxError)
+    }
+  }
 
-      sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS c1, c2, c3")
-      val readback = spark.table(tableName)
-      val relations = readback.queryExecution.analyzed.collect { case rel: LogicalRelation =>
-        val columnStats = rel.catalogTable.get.stats.get.colStats
-        expectedColStatsSeq.foreach { case (field, expectedColStat) =>
-          assert(columnStats.contains(field.name))
-          val colStat = columnStats(field.name)
+  test("generate and load column-level stats for data source table") {
+    val dsTable = "dsTable"
+    withTable(dsTable) {
+      val (df, expectedColStatsSeq) = dataAndColStats()
+      df.write.format("parquet").saveAsTable(dsTable)
+      sql(s"ANALYZE TABLE $dsTable COMPUTE STATISTICS FOR COLUMNS c1, c2, c3, c4")
+      checkColStats(dsTable, isDataSourceTable = true, expectedColStatsSeq)
+    }
+  }
+
+  test("generate and load column-level stats for hive serde table") {
+    val hTable = "hTable"
+    val tmp = "tmp"
+    withTable(hTable, tmp) {
+      val (df, expectedColStatsSeq) = dataAndColStats()
+      df.write.format("parquet").saveAsTable(tmp)
+      sql(s"CREATE TABLE $hTable (c1 int, c2 string, c3 binary, c4 boolean) STORED AS TEXTFILE")
+      sql(s"INSERT INTO $hTable SELECT * FROM $tmp")
+      sql(s"ANALYZE TABLE $hTable COMPUTE STATISTICS FOR COLUMNS c1, c2, c3, c4")
+      checkColStats(hTable, isDataSourceTable = false, expectedColStatsSeq)
+    }
+  }
+
+  // When caseSensitive is on, for columns with only case difference, they are different columns
+  // and we should generate column stats for all of them.
+  private def checkCaseSensitiveColStats(columnName: String): Unit = {
+    val tableName = "tbl"
+    withTable(tableName) {
+      val column1 = columnName.toLowerCase
+      val column2 = columnName.toUpperCase
+      withSQLConf("spark.sql.caseSensitive" -> "true") {
+        sql(s"CREATE TABLE $tableName (`$column1` int, `$column2` double) USING PARQUET")
+        sql(s"INSERT INTO $tableName SELECT 1, 3.0")
+        sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS `$column1`, `$column2`")
+        val readback = spark.table(tableName)
+        val relations = readback.queryExecution.analyzed.collect { case rel: LogicalRelation =>
+          val columnStats = rel.catalogTable.get.stats.get.colStats
+          assert(columnStats.size == 2)
+          StatisticsTest.checkColStat(
+            dataType = IntegerType,
+            colStat = columnStats(column1),
+            expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)),
+            rsd = spark.sessionState.conf.ndvMaxError)
           StatisticsTest.checkColStat(
-            dataType = field.dataType,
-            colStat = colStat,
-            expectedColStat = expectedColStat,
+            dataType = DoubleType,
+            colStat = columnStats(column2),
+            expectedColStat = ColumnStat(InternalRow(0L, 3.0d, 3.0d, 1L)),
             rsd = spark.sessionState.conf.ndvMaxError)
+          rel
         }
-        rel
+        assert(relations.size == 1)
       }
-      assert(relations.size == 1)
     }
   }
 
+  test("check column statistics for case sensitive column names") {
+    checkCaseSensitiveColStats(columnName = "c1")
+  }
+
+  test("check column statistics for case sensitive non-ascii column names") {
+    // scalastyle:off
+    // non ascii characters are not allowed in the source code, so we disable the scalastyle.
+    checkCaseSensitiveColStats(columnName = "列c")
+    // scalastyle:on
+  }
+
   test("estimates the size of a test MetastoreRelation") {
     val df = sql("""SELECT * FROM src""")
     val sizes = df.queryExecution.analyzed.collect { case mr: MetastoreRelation =>
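The new tests above exercise the full statistics lifecycle: write a table, run ANALYZE TABLE ... COMPUTE STATISTICS [FOR COLUMNS ...], and read the resulting Statistics back through the analyzed plan. The lines below are a minimal standalone sketch of that flow outside the test harness; they are not part of this commit, and they assume a SparkSession named `spark` with Hive support and the LogicalRelation / catalogTable.stats shape used at this revision.

// Minimal sketch (assumptions: SparkSession `spark` with Hive support, code at this revision).
spark.sql("CREATE TABLE t (key INT) USING PARQUET")
spark.sql("INSERT INTO t SELECT 1")

// Compute table-level and then column-level statistics, as the tests above do.
spark.sql("ANALYZE TABLE t COMPUTE STATISTICS")
spark.sql("ANALYZE TABLE t COMPUTE STATISTICS FOR COLUMNS key")

// Read the statistics back from the analyzed plan, mirroring the checkColStats helper.
import org.apache.spark.sql.execution.datasources.LogicalRelation
val stats = spark.table("t").queryExecution.analyzed.collect {
  case rel: LogicalRelation => rel.catalogTable.get.stats.get
}.head
println(s"size=${stats.sizeInBytes}, rows=${stats.rowCount}, columns=${stats.colStats.keySet}")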