From 99b3ed1696f0cf31f2bf1221f2e6a059364ad6f4 Mon Sep 17 00:00:00 2001 From: Bo Meng Date: Wed, 3 Jun 2015 14:04:16 -0700 Subject: [PATCH] only add required nonKeyColumns to the scan --- .../spark/sql/hbase/HBaseRelation.scala | 10 ++++- .../spark/sql/hbase/TpcMiniTestSuite.scala | 44 +++++++++++++++++-- 2 files changed, 48 insertions(+), 6 deletions(-) diff --git a/sql/hbase/src/main/scala/org/apache/spark/sql/hbase/HBaseRelation.scala b/sql/hbase/src/main/scala/org/apache/spark/sql/hbase/HBaseRelation.scala index 291dadb72de50..dfca510415f2f 100755 --- a/sql/hbase/src/main/scala/org/apache/spark/sql/hbase/HBaseRelation.scala +++ b/sql/hbase/src/main/scala/org/apache/spark/sql/hbase/HBaseRelation.scala @@ -755,7 +755,7 @@ private[hbase] case class HBaseRelation( var distinctProjectionList = projectionList.map(_.name) if (otherFilters.isDefined) { distinctProjectionList = - distinctProjectionList.union(otherFilters.get.references.toSeq.map(_.name)) + distinctProjectionList.union(otherFilters.get.references.toSeq.map(_.name)).distinct } // filter out the key columns distinctProjectionList = @@ -812,13 +812,19 @@ private[hbase] case class HBaseRelation( // to avoid a full projection distinctProjectionList = pushdownNameSet.toSeq.distinct if (distinctProjectionList.nonEmpty && distinctProjectionList.size < nonKeyColumns.size) { - distinctProjectionList.map { + distinctProjectionList.foreach { case p => val nkc = nonKeyColumns.find(_.sqlName == p).get scan.addColumn(nkc.familyRaw, nkc.qualifierRaw) } } } + } else if (otherFilters.isDefined) { + distinctProjectionList.foreach { + case p => + val nkc = nonKeyColumns.find(_.sqlName == p).get + scan.addColumn(nkc.familyRaw, nkc.qualifierRaw) + } } scan } diff --git a/sql/hbase/src/test/scala/org/apache/spark/sql/hbase/TpcMiniTestSuite.scala b/sql/hbase/src/test/scala/org/apache/spark/sql/hbase/TpcMiniTestSuite.scala index 7ccbe04d45bd0..53e256be76542 100644 --- a/sql/hbase/src/test/scala/org/apache/spark/sql/hbase/TpcMiniTestSuite.scala +++ b/sql/hbase/src/test/scala/org/apache/spark/sql/hbase/TpcMiniTestSuite.scala @@ -252,30 +252,66 @@ class TpcMiniTestSuite extends HBaseIntegrationTestBase { } test("Query 17") { - val sql = "SELECT count(ss_customer_sk) as count_customer FROM store_sales WHERE ss_customer_sk > 100" + val sql = "SELECT count(ss_customer_sk) AS count_customer FROM store_sales WHERE ss_customer_sk > 100" val rows = runSql(sql) assert(rows(0).get(0) == 83) } test("Query 18") { + val sql = "SELECT ss_quantity, ss_wholesale_cost, ss_list_price FROM store_sales WHERE ss_ticket_number = 3" + val rows = runSql(sql) + assert(rows.length == 14) + } + + test("Query 19") { + val sql = "SELECT ss_sold_date_sk, ss_sold_time_sk, ss_store_sk FROM store_sales WHERE ss_ticket_number = 3" + val rows = runSql(sql) + assert(rows.length == 14) + } + + test("Query 20") { + val sql = "SELECT ss_customer_sk, ss_promo_sk, ss_coupon_amt FROM store_sales WHERE ss_ticket_number = 3" + val rows = runSql(sql) + assert(rows.length == 14) + } + + test("Query 21") { + val sql = "SELECT ss_item_sk, ss_ticket_number, count(1) FROM store_sales WHERE ss_ticket_number >= 3 and ss_ticket_number <= 4 group by ss_item_sk, ss_ticket_number" + val rows = runSql(sql) + assert(rows.length == 24) + } + + test("Query 22") { + val sql = "SELECT ss_item_sk, ss_ticket_number, SUM(ss_wholesale_cost) AS sum_wholesale_cost FROM store_sales WHERE ss_ticket_number >= 3 AND ss_ticket_number <= 4 group by ss_item_sk, ss_ticket_number" + val rows = runSql(sql) + assert(rows.length == 23) + } + + test("Query 23") { + val sql = "SELECT ss_item_sk, ss_ticket_number, min(ss_wholesale_cost) as min_wholesale_cost, max(ss_wholesale_cost) as max_wholesale_cost, avg(ss_wholesale_cost) as avg_wholesale_cost FROM store_sales WHERE ss_ticket_number >= 3 and ss_ticket_number <= 3 GROUP BY ss_item_sk, ss_ticket_number" + val rows = runSql(sql) + assert(rows.length == 13) + } + + test("Query 24") { val sql = "SELECT ss_item_sk, ss_ticket_number FROM store_sales WHERE (ss_item_sk = 186 AND ss_ticket_number > 0)" val rows = runSql(sql) assert(rows.length == 1) } - test("Query 19") { + test("Query 25") { val sql = "SELECT * FROM store_sales WHERE ss_ticket_number > 6 and ss_sold_date_sk > 0" val rows = runSql(sql) assert(rows.length == 21) } - test("Query 20") { + test("Query 26") { val sql = "SELECT * FROM store_sales WHERE ss_ticket_number = 7 and ss_sold_date_sk > 0" val rows = runSql(sql) assert(rows.length == 12) } - test("Query 21") { + test("Query 27") { val sql = "SELECT * FROM store_sales WHERE ss_ticket_number + 0 = 3 and ss_sold_date_sk + 0 > 0" val rows = runSql(sql) assert(rows.length == 13)