[SPARK-12882][SQL] simplify bucket tests and add more comments #10813

Closed · wants to merge 1 commit
BucketedReadSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.{Column, DataFrame, DataFrameWriter, QueryTest, SQLC
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.execution.Exchange
+import org.apache.spark.sql.execution.datasources.BucketSpec
import org.apache.spark.sql.execution.joins.SortMergeJoin
import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.test.TestHiveSingleton
@@ -61,15 +62,30 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
private val df1 = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k").as("df1")
private val df2 = (0 until 50).map(i => (i % 7, i % 11, i.toString)).toDF("i", "j", "k").as("df2")

+/**
+* A helper method to test the bucket read functionality using join. It will save `df1` and `df2`
+* to Hive tables, bucketed or not, according to the given bucket specs. Next we join these
+* 2 tables, first making sure the answer is correct, and then checking whether a shuffle
+* exists on each side, as expected according to `shuffleLeft` and `shuffleRight`.
+*/
private def testBucketing(
-bucketing1: DataFrameWriter => DataFrameWriter,
-bucketing2: DataFrameWriter => DataFrameWriter,
+bucketSpecLeft: Option[BucketSpec],
+bucketSpecRight: Option[BucketSpec],
joinColumns: Seq[String],
shuffleLeft: Boolean,
shuffleRight: Boolean): Unit = {
withTable("bucketed_table1", "bucketed_table2") {
-bucketing1(df1.write.format("parquet")).saveAsTable("bucketed_table1")
-bucketing2(df2.write.format("parquet")).saveAsTable("bucketed_table2")
+def withBucket(writer: DataFrameWriter, bucketSpec: Option[BucketSpec]): DataFrameWriter = {
+bucketSpec.map { spec =>
+writer.bucketBy(
+spec.numBuckets,
+spec.bucketColumnNames.head,
+spec.bucketColumnNames.tail: _*)
+}.getOrElse(writer)
+}
+
+withBucket(df1.write.format("parquet"), bucketSpecLeft).saveAsTable("bucketed_table1")
+withBucket(df2.write.format("parquet"), bucketSpecRight).saveAsTable("bucketed_table2")

withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
val t1 = hiveContext.table("bucketed_table1")
@@ -95,42 +111,42 @@
}
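
For context: `BucketSpec` (imported above from `org.apache.spark.sql.execution.datasources`) is the plain metadata holder that replaces the old `DataFrameWriter => DataFrameWriter` parameters. A minimal sketch of its shape, inferred from how `withBucket` and the tests below use it:

```scala
// Sketch only, inferred from usage in this diff (spec.numBuckets,
// spec.bucketColumnNames, and calls like BucketSpec(8, Seq("i", "j"), Nil)).
case class BucketSpec(
  numBuckets: Int,
  bucketColumnNames: Seq[String],
  sortColumnNames: Seq[String])
```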

test("avoid shuffle when join 2 bucketed tables") {
-val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
-testBucketing(bucketing, bucketing, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
+val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
+testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
}

// Enable this test after https://issues.apache.org/jira/browse/SPARK-12704 is fixed.
ignore("avoid shuffle when join keys are a super-set of bucket keys") {
-val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i")
-testBucketing(bucketing, bucketing, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
+val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil))
+testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
}

test("only shuffle one side when join bucketed table and non-bucketed table") {
-val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
-testBucketing(bucketing, identity, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
+val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
+testBucketing(bucketSpec, None, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
}

test("only shuffle one side when 2 bucketed tables have different bucket number") {
-val bucketing1 = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
-val bucketing2 = (writer: DataFrameWriter) => writer.bucketBy(5, "i", "j")
-testBucketing(bucketing1, bucketing2, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
+val bucketSpec1 = Some(BucketSpec(8, Seq("i", "j"), Nil))
+val bucketSpec2 = Some(BucketSpec(5, Seq("i", "j"), Nil))
+testBucketing(bucketSpec1, bucketSpec2, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
}

test("only shuffle one side when 2 bucketed tables have different bucket keys") {
-val bucketing1 = (writer: DataFrameWriter) => writer.bucketBy(8, "i")
-val bucketing2 = (writer: DataFrameWriter) => writer.bucketBy(8, "j")
-testBucketing(bucketing1, bucketing2, Seq("i"), shuffleLeft = false, shuffleRight = true)
+val bucketSpec1 = Some(BucketSpec(8, Seq("i"), Nil))
+val bucketSpec2 = Some(BucketSpec(8, Seq("j"), Nil))
+testBucketing(bucketSpec1, bucketSpec2, Seq("i"), shuffleLeft = false, shuffleRight = true)
}

test("shuffle when join keys are not equal to bucket keys") {
-val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i")
-testBucketing(bucketing, bucketing, Seq("j"), shuffleLeft = true, shuffleRight = true)
+val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil))
+testBucketing(bucketSpec, bucketSpec, Seq("j"), shuffleLeft = true, shuffleRight = true)
}

test("shuffle when join 2 bucketed tables with bucketing disabled") {
-val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
+val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
withSQLConf(SQLConf.BUCKETING_ENABLED.key -> "false") {
-testBucketing(bucketing, bucketing, Seq("i", "j"), shuffleLeft = true, shuffleRight = true)
+testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = true, shuffleRight = true)
}
}
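
The assertions that back `shuffleLeft` and `shuffleRight` live in the collapsed lines of the hunk above. A hedged sketch of the idea, assuming the helper locates the `SortMergeJoin` in the physical plan (both `Exchange` and `SortMergeJoin` are imported at the top of this suite), where `joined` stands for the join of the two tables:

```scala
// Hypothetical sketch, not the verbatim collapsed code: a shuffle materializes
// as an Exchange operator in the physical plan, so each side of the join can
// be checked for one directly.
val plan = joined.queryExecution.executedPlan
val join = plan.collect { case j: SortMergeJoin => j }.head
assert(join.left.find(_.isInstanceOf[Exchange]).isDefined == shuffleLeft)
assert(join.right.find(_.isInstanceOf[Exchange]).isDefined == shuffleRight)
```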

BucketedWriteSuite.scala
@@ -65,39 +65,55 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle

private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")

+/**
+* A helper method to check the bucket write functionality at a low level, i.e. check the written
+* bucket files to see if the data is correct. Callers should pass in the data directory that the
+* bucket files were written to, the format of the data (parquet, json, etc.), and the bucketing
+* information.
+*/
private def testBucketing(
dataDir: File,
source: String,
+numBuckets: Int,
bucketCols: Seq[String],
sortCols: Seq[String] = Nil): Unit = {
val allBucketFiles = dataDir.listFiles().filterNot(f =>
f.getName.startsWith(".") || f.getName.startsWith("_")
)
-val groupedBucketFiles = allBucketFiles.groupBy(f => BucketingUtils.getBucketId(f.getName).get)
-assert(groupedBucketFiles.size <= 8)

-for ((bucketId, bucketFiles) <- groupedBucketFiles) {
-for (bucketFilePath <- bucketFiles.map(_.getAbsolutePath)) {
-val types = df.select((bucketCols ++ sortCols).map(col): _*).schema.map(_.dataType)
-val columns = (bucketCols ++ sortCols).zip(types).map {
-case (colName, dt) => col(colName).cast(dt)
-}
-val readBack = sqlContext.read.format(source).load(bucketFilePath).select(columns: _*)

-if (sortCols.nonEmpty) {
-checkAnswer(readBack.sort(sortCols.map(col): _*), readBack.collect())
-}
+for (bucketFile <- allBucketFiles) {
+val bucketId = BucketingUtils.getBucketId(bucketFile.getName).get
+assert(bucketId >= 0 && bucketId < numBuckets)

-val qe = readBack.select(bucketCols.map(col): _*).queryExecution
-val rows = qe.toRdd.map(_.copy()).collect()
-val getBucketId = UnsafeProjection.create(
-HashPartitioning(qe.analyzed.output, 8).partitionIdExpression :: Nil,
-qe.analyzed.output)
+// We may lose the type information after writing (e.g. the json format doesn't keep
+// schema information), so here we get the types from the original dataframe.
+val types = df.select((bucketCols ++ sortCols).map(col): _*).schema.map(_.dataType)
+val columns = (bucketCols ++ sortCols).zip(types).map {
+case (colName, dt) => col(colName).cast(dt)
+}

-for (row <- rows) {
-val actualBucketId = getBucketId(row).getInt(0)
-assert(actualBucketId == bucketId)
-}
+// Read the bucket file into a dataframe, so that it's easier to test.
+val readBack = sqlContext.read.format(source)
+.load(bucketFile.getAbsolutePath)
+.select(columns: _*)

+// If we specified sort columns while writing the bucketed table, make sure the data in this
+// bucket file is already sorted.
+if (sortCols.nonEmpty) {
+checkAnswer(readBack.sort(sortCols.map(col): _*), readBack.collect())
+}

+// Go through all rows in this bucket file, calculate the bucket id from the bucket column
+// values, and make sure it equals the expected bucket id inferred from the file name.
+val qe = readBack.select(bucketCols.map(col): _*).queryExecution
+val rows = qe.toRdd.map(_.copy()).collect()
+val getBucketId = UnsafeProjection.create(
+HashPartitioning(qe.analyzed.output, numBuckets).partitionIdExpression :: Nil,
+qe.analyzed.output)

+for (row <- rows) {
+val actualBucketId = getBucketId(row).getInt(0)
+assert(actualBucketId == bucketId)
+}
}
}
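
The `UnsafeProjection` built from `HashPartitioning.partitionIdExpression` is what maps a row to its expected bucket. As an illustration only, that expression is essentially `pmod(murmur3hash(bucket columns), numBuckets)`; assuming the `hash` and `pmod` functions from `org.apache.spark.sql.functions` are available (they use the same Murmur3 hash), an equivalent DataFrame-level computation over the method's `readBack` would be roughly:

```scala
// Illustration under the stated assumption, not the suite's actual code:
// compute the bucket id per row from the bucket column values.
import org.apache.spark.sql.functions.{col, hash, lit, pmod}
val withComputedId = readBack.withColumn(
  "computed_bucket_id",
  pmod(hash(bucketCols.map(col): _*), lit(numBuckets)))
// Every computed_bucket_id should equal the id parsed from the file name.
```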
@@ -113,7 +129,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle

val tableDir = new File(hiveContext.warehousePath, "bucketed_table")
for (i <- 0 until 5) {
-testBucketing(new File(tableDir, s"i=$i"), source, Seq("j", "k"))
+testBucketing(new File(tableDir, s"i=$i"), source, 8, Seq("j", "k"))
}
}
}
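
The write that produces `tableDir` sits in the collapsed lines of this hunk; given the loop over the `i=0..4` partition directories and the bucket columns `j`, `k`, it presumably combines `partitionBy` with `bucketBy`, along these lines (an assumed reconstruction, not the verbatim collapsed code):

```scala
// Assumed shape of the collapsed write: partition by i, 8 buckets on (j, k).
df.write
  .format(source)
  .partitionBy("i")
  .bucketBy(8, "j", "k")
  .saveAsTable("bucketed_table")
```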
@@ -131,7 +147,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle

val tableDir = new File(hiveContext.warehousePath, "bucketed_table")
for (i <- 0 until 5) {
-testBucketing(new File(tableDir, s"i=$i"), source, Seq("j"), Seq("k"))
+testBucketing(new File(tableDir, s"i=$i"), source, 8, Seq("j"), Seq("k"))
}
}
}
@@ -146,7 +162,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
.saveAsTable("bucketed_table")

val tableDir = new File(hiveContext.warehousePath, "bucketed_table")
testBucketing(tableDir, source, Seq("i", "j"))
testBucketing(tableDir, source, 8, Seq("i", "j"))
}
}
}
@@ -161,7 +177,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
.saveAsTable("bucketed_table")

val tableDir = new File(hiveContext.warehousePath, "bucketed_table")
testBucketing(tableDir, source, Seq("i", "j"), Seq("k"))
testBucketing(tableDir, source, 8, Seq("i", "j"), Seq("k"))
}
}
}
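
The last two tests exercise the `sortBy` variant; under the same assumption as above, the collapsed write for the bucketed-and-sorted case would look roughly like this, matching `testBucketing(tableDir, source, 8, Seq("i", "j"), Seq("k"))`:

```scala
// Assumed reconstruction: 8 buckets on (i, j), each bucket file sorted by k.
df.write
  .format(source)
  .bucketBy(8, "i", "j")
  .sortBy("k")
  .saveAsTable("bucketed_table")
```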