Skip to content

Commit 0a20d71

Browse files
authored
TPC-H test suite added (#136)
* added tpch sql files * functions updated to save temp view * main function skeleton done * load and clear done * fix clear * performQuery done * import cleanup, use OPAQUE_HOME * TPC-H 9 refactored to use SQL rather than DF operations * removed : Unit, unused imports * added TestUtils.scala * moved all common initialization to TestUtils * update name * begin rewriting TPCH.scala to store persistent tables * invalid table name error * TPCH conversion to class started * compiles * added second case, cleared up names * added TPC-H 6 to check that persistent state has no issues * added functions for the last two tables * addressed most logic changes * DataFrame only loaded once * apply method in companion object * full test suite added * added testFunc parameter to testAgainstSpark * ignore #18
1 parent 6031a4a commit 0a20d71

File tree

26 files changed

+1018
-161
lines changed

26 files changed

+1018
-161
lines changed

src/main/scala/edu/berkeley/cs/rise/opaque/benchmark/TPCH.scala

Lines changed: 92 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,21 @@
1717

1818
package edu.berkeley.cs.rise.opaque.benchmark
1919

20+
import scala.io.Source
21+
2022
import org.apache.spark.sql.DataFrame
21-
import org.apache.spark.sql.functions._
2223
import org.apache.spark.sql.types._
2324
import org.apache.spark.sql.SQLContext
2425

26+
import edu.berkeley.cs.rise.opaque.Utils
27+
2528
object TPCH {
29+
30+
val tableNames = Seq("part", "supplier", "lineitem", "partsupp", "orders", "nation", "region", "customer")
31+
2632
def part(
27-
sqlContext: SQLContext, securityLevel: SecurityLevel, size: String, numPartitions: Int)
33+
sqlContext: SQLContext, size: String)
2834
: DataFrame =
29-
securityLevel.applyTo(
3035
sqlContext.read.schema(
3136
StructType(Seq(
3237
StructField("p_partkey", IntegerType),
@@ -41,12 +46,10 @@ object TPCH {
4146
.format("csv")
4247
.option("delimiter", "|")
4348
.load(s"${Benchmark.dataDir}/tpch/$size/part.tbl")
44-
.repartition(numPartitions))
4549

4650
def supplier(
47-
sqlContext: SQLContext, securityLevel: SecurityLevel, size: String, numPartitions: Int)
51+
sqlContext: SQLContext, size: String)
4852
: DataFrame =
49-
securityLevel.applyTo(
5053
sqlContext.read.schema(
5154
StructType(Seq(
5255
StructField("s_suppkey", IntegerType),
@@ -59,12 +62,10 @@ object TPCH {
5962
.format("csv")
6063
.option("delimiter", "|")
6164
.load(s"${Benchmark.dataDir}/tpch/$size/supplier.tbl")
62-
.repartition(numPartitions))
6365

6466
def lineitem(
65-
sqlContext: SQLContext, securityLevel: SecurityLevel, size: String, numPartitions: Int)
67+
sqlContext: SQLContext, size: String)
6668
: DataFrame =
67-
securityLevel.applyTo(
6869
sqlContext.read.schema(
6970
StructType(Seq(
7071
StructField("l_orderkey", IntegerType),
@@ -86,12 +87,10 @@ object TPCH {
8687
.format("csv")
8788
.option("delimiter", "|")
8889
.load(s"${Benchmark.dataDir}/tpch/$size/lineitem.tbl")
89-
.repartition(numPartitions))
9090

9191
def partsupp(
92-
sqlContext: SQLContext, securityLevel: SecurityLevel, size: String, numPartitions: Int)
92+
sqlContext: SQLContext, size: String)
9393
: DataFrame =
94-
securityLevel.applyTo(
9594
sqlContext.read.schema(
9695
StructType(Seq(
9796
StructField("ps_partkey", IntegerType),
@@ -102,12 +101,10 @@ object TPCH {
102101
.format("csv")
103102
.option("delimiter", "|")
104103
.load(s"${Benchmark.dataDir}/tpch/$size/partsupp.tbl")
105-
.repartition(numPartitions))
106104

107105
def orders(
108-
sqlContext: SQLContext, securityLevel: SecurityLevel, size: String, numPartitions: Int)
106+
sqlContext: SQLContext, size: String)
109107
: DataFrame =
110-
securityLevel.applyTo(
111108
sqlContext.read.schema(
112109
StructType(Seq(
113110
StructField("o_orderkey", IntegerType),
@@ -122,12 +119,10 @@ object TPCH {
122119
.format("csv")
123120
.option("delimiter", "|")
124121
.load(s"${Benchmark.dataDir}/tpch/$size/orders.tbl")
125-
.repartition(numPartitions))
126122

127123
def nation(
128-
sqlContext: SQLContext, securityLevel: SecurityLevel, size: String, numPartitions: Int)
124+
sqlContext: SQLContext, size: String)
129125
: DataFrame =
130-
securityLevel.applyTo(
131126
sqlContext.read.schema(
132127
StructType(Seq(
133128
StructField("n_nationkey", IntegerType),
@@ -137,60 +132,85 @@ object TPCH {
137132
.format("csv")
138133
.option("delimiter", "|")
139134
.load(s"${Benchmark.dataDir}/tpch/$size/nation.tbl")
140-
.repartition(numPartitions))
141-
142-
143-
private def tpch9EncryptedDFs(
144-
sqlContext: SQLContext, securityLevel: SecurityLevel, size: String, numPartitions: Int)
145-
: (DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame) = {
146-
val partDF = part(sqlContext, securityLevel, size, numPartitions)
147-
val supplierDF = supplier(sqlContext, securityLevel, size, numPartitions)
148-
val lineitemDF = lineitem(sqlContext, securityLevel, size, numPartitions)
149-
val partsuppDF = partsupp(sqlContext, securityLevel, size, numPartitions)
150-
val ordersDF = orders(sqlContext, securityLevel, size, numPartitions)
151-
val nationDF = nation(sqlContext, securityLevel, size, numPartitions)
152-
(partDF, supplierDF, lineitemDF, partsuppDF, ordersDF, nationDF)
135+
136+
def region(
137+
sqlContext: SQLContext, size: String)
138+
: DataFrame =
139+
sqlContext.read.schema(
140+
StructType(Seq(
141+
StructField("r_regionkey", IntegerType),
142+
StructField("r_name", StringType),
143+
StructField("r_comment", StringType))))
144+
.format("csv")
145+
.option("delimiter", "|")
146+
.load(s"${Benchmark.dataDir}/tpch/$size/region.tbl")
147+
148+
def customer(
149+
sqlContext: SQLContext, size: String)
150+
: DataFrame =
151+
sqlContext.read.schema(
152+
StructType(Seq(
153+
StructField("c_custkey", IntegerType),
154+
StructField("c_name", StringType),
155+
StructField("c_address", StringType),
156+
StructField("c_nationkey", IntegerType),
157+
StructField("c_phone", StringType),
158+
StructField("c_acctbal", FloatType),
159+
StructField("c_mktsegment", StringType),
160+
StructField("c_comment", StringType))))
161+
.format("csv")
162+
.option("delimiter", "|")
163+
.load(s"${Benchmark.dataDir}/tpch/$size/customer.tbl")
164+
165+
def generateMap(
166+
sqlContext: SQLContext, size: String)
167+
: Map[String, DataFrame] = {
168+
Map("part" -> part(sqlContext, size),
169+
"supplier" -> supplier(sqlContext, size),
170+
"lineitem" -> lineitem(sqlContext, size),
171+
"partsupp" -> partsupp(sqlContext, size),
172+
"orders" -> orders(sqlContext, size),
173+
"nation" -> nation(sqlContext, size),
174+
"region" -> region(sqlContext, size),
175+
"customer" -> customer(sqlContext, size)
176+
),
177+
}
178+
179+
def apply(sqlContext: SQLContext, size: String) : TPCH = {
180+
val tpch = new TPCH(sqlContext, size)
181+
tpch.tableNames = tableNames
182+
tpch.nameToDF = generateMap(sqlContext, size)
183+
tpch.ensureCached()
184+
tpch
185+
}
186+
}
187+
188+
class TPCH(val sqlContext: SQLContext, val size: String) {
189+
190+
var tableNames : Seq[String] = Seq()
191+
var nameToDF : Map[String, DataFrame] = Map()
192+
193+
def ensureCached() = {
194+
for (name <- tableNames) {
195+
nameToDF.get(name).foreach(df => {
196+
Utils.ensureCached(df)
197+
Utils.ensureCached(Encrypted.applyTo(df))
198+
})
199+
}
200+
}
201+
202+
def setupViews(securityLevel: SecurityLevel, numPartitions: Int) = {
203+
for ((name, df) <- nameToDF) {
204+
securityLevel.applyTo(df.repartition(numPartitions)).createOrReplaceTempView(name)
205+
}
153206
}
154207

155-
/** TPC-H query 9 - Product Type Profit Measure Query */
156-
def tpch9(
157-
sqlContext: SQLContext,
158-
securityLevel: SecurityLevel,
159-
size: String,
160-
numPartitions: Int,
161-
quantityThreshold: Option[Int] = None) : DataFrame = {
162-
import sqlContext.implicits._
163-
val (partDF, supplierDF, lineitemDF, partsuppDF, ordersDF, nationDF) =
164-
tpch9EncryptedDFs(sqlContext, securityLevel, size, numPartitions)
165-
166-
val df =
167-
ordersDF.select($"o_orderkey", year($"o_orderdate").as("o_year")) // 6. orders
168-
.join(
169-
(nationDF// 4. nation
170-
.join(
171-
supplierDF // 3. supplier
172-
.join(
173-
partDF // 1. part
174-
.filter($"p_name".contains("maroon"))
175-
.select($"p_partkey")
176-
.join(partsuppDF, $"p_partkey" === $"ps_partkey"), // 2. partsupp
177-
$"ps_suppkey" === $"s_suppkey"),
178-
$"s_nationkey" === $"n_nationkey"))
179-
.join(
180-
// 5. lineitem
181-
quantityThreshold match {
182-
case Some(q) => lineitemDF.filter($"l_quantity" > lit(q))
183-
case None => lineitemDF
184-
},
185-
$"s_suppkey" === $"l_suppkey" && $"p_partkey" === $"l_partkey"),
186-
$"l_orderkey" === $"o_orderkey")
187-
.select(
188-
$"n_name",
189-
$"o_year",
190-
($"l_extendedprice" * (lit(1) - $"l_discount") - $"ps_supplycost" * $"l_quantity")
191-
.as("amount"))
192-
.groupBy("n_name", "o_year").agg(sum($"amount").as("sum_profit"))
193-
194-
df
195-
}
208+
def query(queryNumber: Int, securityLevel: SecurityLevel, sqlContext: SQLContext, numPartitions: Int) : DataFrame = {
209+
setupViews(securityLevel, numPartitions)
210+
211+
val queryLocation = sys.env.getOrElse("OPAQUE_HOME", ".") + "/src/test/resources/tpch/"
212+
val sqlStr = Source.fromFile(queryLocation + s"q$queryNumber.sql").getLines().mkString("\n")
213+
214+
sqlContext.sparkSession.sql(sqlStr)
215+
}
196216
}

src/test/resources/tpch/q1.sql

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
-- using default substitutions
2+
3+
select
4+
l_returnflag,
5+
l_linestatus,
6+
sum(l_quantity) as sum_qty,
7+
sum(l_extendedprice) as sum_base_price,
8+
sum(l_extendedprice * (1 - l_discount)) as sum_disc_price,
9+
sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge,
10+
avg(l_quantity) as avg_qty,
11+
avg(l_extendedprice) as avg_price,
12+
avg(l_discount) as avg_disc,
13+
count(*) as count_order
14+
from
15+
lineitem
16+
where
17+
l_shipdate <= date '1998-12-01' - interval '90' day
18+
group by
19+
l_returnflag,
20+
l_linestatus
21+
order by
22+
l_returnflag,
23+
l_linestatus

src/test/resources/tpch/q10.sql

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
-- using default substitutions
2+
3+
select
4+
c_custkey,
5+
c_name,
6+
sum(l_extendedprice * (1 - l_discount)) as revenue,
7+
c_acctbal,
8+
n_name,
9+
c_address,
10+
c_phone,
11+
c_comment
12+
from
13+
customer,
14+
orders,
15+
lineitem,
16+
nation
17+
where
18+
c_custkey = o_custkey
19+
and l_orderkey = o_orderkey
20+
and o_orderdate >= date '1993-10-01'
21+
and o_orderdate < date '1993-10-01' + interval '3' month
22+
and l_returnflag = 'R'
23+
and c_nationkey = n_nationkey
24+
group by
25+
c_custkey,
26+
c_name,
27+
c_acctbal,
28+
c_phone,
29+
n_name,
30+
c_address,
31+
c_comment
32+
order by
33+
revenue desc
34+
limit 20

src/test/resources/tpch/q11.sql

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
-- using default substitutions
2+
3+
select
4+
ps_partkey,
5+
sum(ps_supplycost * ps_availqty) as value
6+
from
7+
partsupp,
8+
supplier,
9+
nation
10+
where
11+
ps_suppkey = s_suppkey
12+
and s_nationkey = n_nationkey
13+
and n_name = 'GERMANY'
14+
group by
15+
ps_partkey having
16+
sum(ps_supplycost * ps_availqty) > (
17+
select
18+
sum(ps_supplycost * ps_availqty) * 0.0001000000
19+
from
20+
partsupp,
21+
supplier,
22+
nation
23+
where
24+
ps_suppkey = s_suppkey
25+
and s_nationkey = n_nationkey
26+
and n_name = 'GERMANY'
27+
)
28+
order by
29+
value desc

src/test/resources/tpch/q12.sql

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
-- using default substitutions
2+
3+
select
4+
l_shipmode,
5+
sum(case
6+
when o_orderpriority = '1-URGENT'
7+
or o_orderpriority = '2-HIGH'
8+
then 1
9+
else 0
10+
end) as high_line_count,
11+
sum(case
12+
when o_orderpriority <> '1-URGENT'
13+
and o_orderpriority <> '2-HIGH'
14+
then 1
15+
else 0
16+
end) as low_line_count
17+
from
18+
orders,
19+
lineitem
20+
where
21+
o_orderkey = l_orderkey
22+
and l_shipmode in ('MAIL', 'SHIP')
23+
and l_commitdate < l_receiptdate
24+
and l_shipdate < l_commitdate
25+
and l_receiptdate >= date '1994-01-01'
26+
and l_receiptdate < date '1994-01-01' + interval '1' year
27+
group by
28+
l_shipmode
29+
order by
30+
l_shipmode

src/test/resources/tpch/q13.sql

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
-- using default substitutions
2+
3+
select
4+
c_count,
5+
count(*) as custdist
6+
from
7+
(
8+
select
9+
c_custkey,
10+
count(o_orderkey) as c_count
11+
from
12+
customer left outer join orders on
13+
c_custkey = o_custkey
14+
and o_comment not like '%special%requests%'
15+
group by
16+
c_custkey
17+
) as c_orders
18+
group by
19+
c_count
20+
order by
21+
custdist desc,
22+
c_count desc

src/test/resources/tpch/q14.sql

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
-- using default substitutions
2+
3+
select
4+
100.00 * sum(case
5+
when p_type like 'PROMO%'
6+
then l_extendedprice * (1 - l_discount)
7+
else 0
8+
end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue
9+
from
10+
lineitem,
11+
part
12+
where
13+
l_partkey = p_partkey
14+
and l_shipdate >= date '1995-09-01'
15+
and l_shipdate < date '1995-09-01' + interval '1' month

0 commit comments

Comments
 (0)