1717
1818package edu .berkeley .cs .rise .opaque .benchmark
1919
20+ import scala .io .Source
21+
2022import org .apache .spark .sql .DataFrame
21- import org .apache .spark .sql .functions ._
2223import org .apache .spark .sql .types ._
2324import org .apache .spark .sql .SQLContext
2425
26+ import edu .berkeley .cs .rise .opaque .Utils
27+
2528object TPCH {
29+
/** Names of the TPC-H tables handled by this object, in the order they are cached by `TPCH.ensureCached`. */
val tableNames: Seq[String] =
  Seq("part", "supplier", "lineitem", "partsupp", "orders", "nation", "region", "customer")
31+
2632 def part (
27- sqlContext : SQLContext , securityLevel : SecurityLevel , size : String , numPartitions : Int )
33+ sqlContext : SQLContext , size : String )
2834 : DataFrame =
29- securityLevel.applyTo(
3035 sqlContext.read.schema(
3136 StructType (Seq (
3237 StructField (" p_partkey" , IntegerType ),
@@ -41,12 +46,10 @@ object TPCH {
4146 .format(" csv" )
4247 .option(" delimiter" , " |" )
4348 .load(s " ${Benchmark .dataDir}/tpch/ $size/part.tbl " )
44- .repartition(numPartitions))
4549
4650 def supplier (
47- sqlContext : SQLContext , securityLevel : SecurityLevel , size : String , numPartitions : Int )
51+ sqlContext : SQLContext , size : String )
4852 : DataFrame =
49- securityLevel.applyTo(
5053 sqlContext.read.schema(
5154 StructType (Seq (
5255 StructField (" s_suppkey" , IntegerType ),
@@ -59,12 +62,10 @@ object TPCH {
5962 .format(" csv" )
6063 .option(" delimiter" , " |" )
6164 .load(s " ${Benchmark .dataDir}/tpch/ $size/supplier.tbl " )
62- .repartition(numPartitions))
6365
6466 def lineitem (
65- sqlContext : SQLContext , securityLevel : SecurityLevel , size : String , numPartitions : Int )
67+ sqlContext : SQLContext , size : String )
6668 : DataFrame =
67- securityLevel.applyTo(
6869 sqlContext.read.schema(
6970 StructType (Seq (
7071 StructField (" l_orderkey" , IntegerType ),
@@ -86,12 +87,10 @@ object TPCH {
8687 .format(" csv" )
8788 .option(" delimiter" , " |" )
8889 .load(s " ${Benchmark .dataDir}/tpch/ $size/lineitem.tbl " )
89- .repartition(numPartitions))
9090
9191 def partsupp (
92- sqlContext : SQLContext , securityLevel : SecurityLevel , size : String , numPartitions : Int )
92+ sqlContext : SQLContext , size : String )
9393 : DataFrame =
94- securityLevel.applyTo(
9594 sqlContext.read.schema(
9695 StructType (Seq (
9796 StructField (" ps_partkey" , IntegerType ),
@@ -102,12 +101,10 @@ object TPCH {
102101 .format(" csv" )
103102 .option(" delimiter" , " |" )
104103 .load(s " ${Benchmark .dataDir}/tpch/ $size/partsupp.tbl " )
105- .repartition(numPartitions))
106104
107105 def orders (
108- sqlContext : SQLContext , securityLevel : SecurityLevel , size : String , numPartitions : Int )
106+ sqlContext : SQLContext , size : String )
109107 : DataFrame =
110- securityLevel.applyTo(
111108 sqlContext.read.schema(
112109 StructType (Seq (
113110 StructField (" o_orderkey" , IntegerType ),
@@ -122,12 +119,10 @@ object TPCH {
122119 .format(" csv" )
123120 .option(" delimiter" , " |" )
124121 .load(s " ${Benchmark .dataDir}/tpch/ $size/orders.tbl " )
125- .repartition(numPartitions))
126122
127123 def nation (
128- sqlContext : SQLContext , securityLevel : SecurityLevel , size : String , numPartitions : Int )
124+ sqlContext : SQLContext , size : String )
129125 : DataFrame =
130- securityLevel.applyTo(
131126 sqlContext.read.schema(
132127 StructType (Seq (
133128 StructField (" n_nationkey" , IntegerType ),
@@ -137,60 +132,85 @@ object TPCH {
137132 .format(" csv" )
138133 .option(" delimiter" , " |" )
139134 .load(s " ${Benchmark .dataDir}/tpch/ $size/nation.tbl " )
140- .repartition(numPartitions))
141-
142-
143- private def tpch9EncryptedDFs (
144- sqlContext : SQLContext , securityLevel : SecurityLevel , size : String , numPartitions : Int )
145- : (DataFrame , DataFrame , DataFrame , DataFrame , DataFrame , DataFrame ) = {
146- val partDF = part(sqlContext, securityLevel, size, numPartitions)
147- val supplierDF = supplier(sqlContext, securityLevel, size, numPartitions)
148- val lineitemDF = lineitem(sqlContext, securityLevel, size, numPartitions)
149- val partsuppDF = partsupp(sqlContext, securityLevel, size, numPartitions)
150- val ordersDF = orders(sqlContext, securityLevel, size, numPartitions)
151- val nationDF = nation(sqlContext, securityLevel, size, numPartitions)
152- (partDF, supplierDF, lineitemDF, partsuppDF, ordersDF, nationDF)
135+
/**
 * Reads the TPC-H `region` table as a DataFrame.
 *
 * The data is loaded from `${Benchmark.dataDir}/tpch/$size/region.tbl` as
 * `|`-delimited CSV with an explicit three-column schema.
 *
 * @param sqlContext context whose reader is used to load the file
 * @param size       name of the data-set subdirectory under `tpch/`
 * @return the unencrypted, un-repartitioned table
 */
def region(
    sqlContext: SQLContext, size: String)
  : DataFrame = {
  val regionSchema = StructType(Seq(
    StructField("r_regionkey", IntegerType),
    StructField("r_name", StringType),
    StructField("r_comment", StringType)))

  sqlContext.read.schema(regionSchema)
    .format("csv")
    .option("delimiter", "|")
    .load(s"${Benchmark.dataDir}/tpch/$size/region.tbl")
}
147+
/**
 * Reads the TPC-H `customer` table as a DataFrame.
 *
 * The data is loaded from `${Benchmark.dataDir}/tpch/$size/customer.tbl` as
 * `|`-delimited CSV with an explicit eight-column schema.
 *
 * @param sqlContext context whose reader is used to load the file
 * @param size       name of the data-set subdirectory under `tpch/`
 * @return the unencrypted, un-repartitioned table
 */
def customer(
    sqlContext: SQLContext, size: String)
  : DataFrame = {
  val customerSchema = StructType(Seq(
    StructField("c_custkey", IntegerType),
    StructField("c_name", StringType),
    StructField("c_address", StringType),
    StructField("c_nationkey", IntegerType),
    StructField("c_phone", StringType),
    StructField("c_acctbal", FloatType),
    StructField("c_mktsegment", StringType),
    StructField("c_comment", StringType)))

  sqlContext.read.schema(customerSchema)
    .format("csv")
    .option("delimiter", "|")
    .load(s"${Benchmark.dataDir}/tpch/$size/customer.tbl")
}
164+
/**
 * Builds the table-name -> DataFrame mapping for all eight TPC-H tables.
 *
 * Each entry is produced by the loader method of the same name in this
 * object, called with the given `sqlContext` and `size`.
 *
 * @param sqlContext context passed to every table loader
 * @param size       data-set subdirectory passed to every table loader
 * @return an immutable map keyed by table name
 */
def generateMap(
    sqlContext: SQLContext, size: String)
  : Map[String, DataFrame] = {
  // The original had a stray `,` after the closing parenthesis of Map(...);
  // it has been dropped so the method body is a single plain expression.
  Map(
    "part" -> part(sqlContext, size),
    "supplier" -> supplier(sqlContext, size),
    "lineitem" -> lineitem(sqlContext, size),
    "partsupp" -> partsupp(sqlContext, size),
    "orders" -> orders(sqlContext, size),
    "nation" -> nation(sqlContext, size),
    "region" -> region(sqlContext, size),
    "customer" -> customer(sqlContext, size)
  )
}
178+
/**
 * Creates a `TPCH` instance with its tables loaded.
 *
 * Constructs the instance, assigns this object's `tableNames` and the map
 * from `generateMap` to it, then calls its `ensureCached()` before
 * returning it.
 *
 * @param sqlContext context used to load the tables and stored on the instance
 * @param size       data-set subdirectory under `tpch/`
 * @return the populated instance
 */
def apply(sqlContext: SQLContext, size: String): TPCH = {
  val instance = new TPCH(sqlContext, size)
  instance.tableNames = tableNames
  instance.nameToDF = generateMap(sqlContext, size)
  instance.ensureCached()
  instance
}
186+ }
187+
188+ class TPCH (val sqlContext : SQLContext , val size : String ) {
189+
190+ var tableNames : Seq [String ] = Seq ()
191+ var nameToDF : Map [String , DataFrame ] = Map ()
192+
/**
 * Passes every known table to `Utils.ensureCached`, in both plain and
 * `Encrypted` form.
 *
 * Tables are visited in `tableNames` order; a name with no entry in
 * `nameToDF` is silently skipped.
 * NOTE(review): the exact caching semantics live in `Utils.ensureCached`,
 * which is not visible here.
 */
def ensureCached(): Unit = {
  for {
    name <- tableNames
    df <- nameToDF.get(name) // Option generator: skips names without a DataFrame
  } {
    Utils.ensureCached(df)
    Utils.ensureCached(Encrypted.applyTo(df))
  }
}
201+
/**
 * Registers every table in `nameToDF` as a temporary view under its table name.
 *
 * Each DataFrame is first repartitioned to `numPartitions` and then passed
 * through `securityLevel.applyTo`. An existing view with the same name is
 * replaced.
 *
 * @param securityLevel transformation applied to each repartitioned table
 * @param numPartitions target partition count for every table
 */
def setupViews(securityLevel: SecurityLevel, numPartitions: Int): Unit = {
  nameToDF.foreach { case (name, df) =>
    val prepared = securityLevel.applyTo(df.repartition(numPartitions))
    prepared.createOrReplaceTempView(name)
  }
}
154207
155- /** TPC-H query 9 - Product Type Profit Measure Query */
156- def tpch9 (
157- sqlContext : SQLContext ,
158- securityLevel : SecurityLevel ,
159- size : String ,
160- numPartitions : Int ,
161- quantityThreshold : Option [Int ] = None ) : DataFrame = {
162- import sqlContext .implicits ._
163- val (partDF, supplierDF, lineitemDF, partsuppDF, ordersDF, nationDF) =
164- tpch9EncryptedDFs(sqlContext, securityLevel, size, numPartitions)
165-
166- val df =
167- ordersDF.select($" o_orderkey" , year($" o_orderdate" ).as(" o_year" )) // 6. orders
168- .join(
169- (nationDF// 4. nation
170- .join(
171- supplierDF // 3. supplier
172- .join(
173- partDF // 1. part
174- .filter($" p_name" .contains(" maroon" ))
175- .select($" p_partkey" )
176- .join(partsuppDF, $" p_partkey" === $" ps_partkey" ), // 2. partsupp
177- $" ps_suppkey" === $" s_suppkey" ),
178- $" s_nationkey" === $" n_nationkey" ))
179- .join(
180- // 5. lineitem
181- quantityThreshold match {
182- case Some (q) => lineitemDF.filter($" l_quantity" > lit(q))
183- case None => lineitemDF
184- },
185- $" s_suppkey" === $" l_suppkey" && $" p_partkey" === $" l_partkey" ),
186- $" l_orderkey" === $" o_orderkey" )
187- .select(
188- $" n_name" ,
189- $" o_year" ,
190- ($" l_extendedprice" * (lit(1 ) - $" l_discount" ) - $" ps_supplycost" * $" l_quantity" )
191- .as(" amount" ))
192- .groupBy(" n_name" , " o_year" ).agg(sum($" amount" ).as(" sum_profit" ))
193-
194- df
195- }
/**
 * Runs TPC-H query number `queryNumber` and returns the resulting DataFrame.
 *
 * First (re)registers all tables as temp views via `setupViews`, then reads
 * the SQL text from `$OPAQUE_HOME/src/test/resources/tpch/q<queryNumber>.sql`
 * (falling back to `.` when `OPAQUE_HOME` is unset) and submits it through
 * `sqlContext.sparkSession.sql`.
 *
 * @param queryNumber   number used to pick the `q<N>.sql` file
 * @param securityLevel forwarded to `setupViews`
 * @param sqlContext    context whose session runs the query (note: this
 *                      parameter shadows the class field of the same name)
 * @param numPartitions forwarded to `setupViews`
 * @return the DataFrame produced by the SQL statement
 */
def query(queryNumber: Int, securityLevel: SecurityLevel, sqlContext: SQLContext, numPartitions: Int): DataFrame = {
  setupViews(securityLevel, numPartitions)

  val queryLocation = sys.env.getOrElse("OPAQUE_HOME", ".") + "/src/test/resources/tpch/"
  val querySource = Source.fromFile(queryLocation + s"q$queryNumber.sql")
  // Close the file handle even if reading fails; the original never closed it.
  val sqlStr =
    try querySource.getLines().mkString("\n")
    finally querySource.close()

  sqlContext.sparkSession.sql(sqlStr)
}
196216}
0 commit comments