Skip to content

Commit 12c86df

Browse files
committed
Add tables() to SQLContext to return a DataFrame containing existing tables.
1 parent 44b2311 commit 12c86df

File tree

6 files changed

+247
-0
lines changed

6 files changed

+247
-0
lines changed

python/pyspark/sql/context.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -621,6 +621,24 @@ def table(self, tableName):
621621
"""
622622
return DataFrame(self._ssql_ctx.table(tableName), self)
623623

624+
def tables(self, dbName=None):
    """Returns a DataFrame containing names of tables in the given database.

    If `dbName` is `None`, the database will be the current database.

    The returned DataFrame has two columns, tableName and isTemporary
    (a column with BooleanType indicating if a table is a temporary one or not).

    >>> sqlCtx.registerRDDAsTable(df, "table1")
    >>> df2 = sqlCtx.tables()
    >>> df2.first()
    Row(tableName=u'table1', isTemporary=True)
    """
    if dbName is None:
        # No database given: list tables of the current database.
        return DataFrame(self._ssql_ctx.tables(), self)
    else:
        return DataFrame(self._ssql_ctx.tables(dbName), self)
641+
624642
def cacheTable(self, tableName):
625643
"""Caches the specified table in-memory."""
626644
self._ssql_ctx.cacheTable(tableName)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ trait Catalog {
3434
tableIdentifier: Seq[String],
3535
alias: Option[String] = None): LogicalPlan
3636

37+
/**
 * Returns the names of all tables in the database identified by `databaseIdentifier`,
 * each paired with a flag that is true when the table is temporary.
 */
def getTables(databaseIdentifier: Seq[String]): Seq[(String, Boolean)]
42+
3743
def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit
3844

3945
def unregisterTable(tableIdentifier: Seq[String]): Unit
@@ -60,6 +66,10 @@ trait Catalog {
6066
protected def getDBTable(tableIdent: Seq[String]) : (Option[String], String) = {
6167
(tableIdent.lift(tableIdent.size - 2), tableIdent.last)
6268
}
69+
70+
/** Extracts the database name (the last identifier component), if any. */
protected def getDBName(databaseIdentifier: Seq[String]): Option[String] = {
  databaseIdentifier.lastOption
}
6373
}
6474

6575
class SimpleCatalog(val caseSensitive: Boolean) extends Catalog {
@@ -101,6 +111,12 @@ class SimpleCatalog(val caseSensitive: Boolean) extends Catalog {
101111
// properly qualified with this alias.
102112
alias.map(a => Subquery(a, tableWithQualifiers)).getOrElse(tableWithQualifiers)
103113
}
114+
115+
/** Every table registered in this catalog is temporary, so each entry is flagged true. */
override def getTables(databaseIdentifier: Seq[String]): Seq[(String, Boolean)] = {
  tables.toSeq.map { case (tableName, _) => (tableName, true) }
}
104120
}
105121

106122
/**
@@ -137,6 +153,22 @@ trait OverrideCatalog extends Catalog {
137153
withAlias.getOrElse(super.lookupRelation(tableIdentifier, alias))
138154
}
139155

156+
/**
 * Returns the temporary tables registered through the overrides, followed by the tables
 * reported by the underlying catalog.
 */
abstract override def getTables(databaseIdentifier: Seq[String]): Seq[(String, Boolean)] = {
  val dbName = getDBName(databaseIdentifier)
  // A temporary table is returned when it has no associated database, or when its database
  // matches the requested one. `collect` with a guard replaces the previous type pattern
  // `db: Some[String]`, which relied on erased type information (unchecked warning).
  val temporaryTables = overrides.collect {
    case ((db, tableName), _) if db.isEmpty || db == dbName => (tableName, true)
  }.toSeq

  temporaryTables ++ super.getTables(databaseIdentifier)
}
171+
140172
override def registerTable(
141173
tableIdentifier: Seq[String],
142174
plan: LogicalPlan): Unit = {
@@ -172,6 +204,10 @@ object EmptyCatalog extends Catalog {
172204
throw new UnsupportedOperationException
173205
}
174206

207+
/** The empty catalog holds no tables; listing them is unsupported. */
override def getTables(databaseIdentifier: Seq[String]): Seq[(String, Boolean)] =
  throw new UnsupportedOperationException
210+
175211
def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = {
176212
throw new UnsupportedOperationException
177213
}

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -734,6 +734,24 @@ class SQLContext(@transient val sparkContext: SparkContext)
734734
def table(tableName: String): DataFrame =
735735
DataFrame(this, catalog.lookupRelation(Seq(tableName)))
736736

737+
/**
 * Returns a [[DataFrame]] containing names of existing tables in the given database.
 * The returned DataFrame has two columns, tableName and isTemporary (a column with BooleanType
 * indicating if a table is a temporary one or not).
 */
def tables(databaseName: String): DataFrame = {
  createDataFrame(catalog.getTables(Seq(databaseName))).toDataFrame("tableName", "isTemporary")
}
745+
746+
/**
 * Returns a [[DataFrame]] containing names of existing tables in the current database.
 * The returned DataFrame has two columns, tableName and isTemporary (a column with BooleanType
 * indicating if a table is a temporary one or not).
 */
def tables(): DataFrame = {
  // An empty identifier means "use the current database" (resolved by the catalog).
  createDataFrame(catalog.getTables(Seq.empty[String])).toDataFrame("tableName", "isTemporary")
}
754+
737755
protected[sql] class SparkPlanner extends SparkStrategies {
738756
val sparkContext: SparkContext = self.sparkContext
739757

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql
19+
20+
import org.scalatest.BeforeAndAfterAll
21+
22+
import org.apache.spark.sql.test.TestSQLContext
23+
import org.apache.spark.sql.test.TestSQLContext._
24+
import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType}
25+
26+
class ListTablesSuite extends QueryTest with BeforeAndAfterAll {

  import org.apache.spark.sql.test.TestSQLContext.implicits._

  // A simple two-column DataFrame registered under ten temporary table names below.
  val df =
    sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDataFrame("key", "value")

  override def beforeAll(): Unit = {
    (1 to 10).foreach(i => df.registerTempTable(s"table$i"))
  }

  override def afterAll(): Unit = {
    catalog.unregisterAllTables()
  }

  test("get All Tables") {
    checkAnswer(tables(), (1 to 10).map(i => Row(s"table$i", true)))
  }

  test("getting All Tables with a database name has no impact on returned table names") {
    checkAnswer(tables("DB"), (1 to 10).map(i => Row(s"table$i", true)))
  }

  test("query the returned DataFrame of tables") {
    val tableDF = tables()
    // tables() reports (tableName: nullable String, isTemporary: non-nullable Boolean).
    val schema = StructType(
      StructField("tableName", StringType, true) ::
      StructField("isTemporary", BooleanType, false) :: Nil)
    assert(schema === tableDF.schema)

    checkAnswer(
      tableDF.select("tableName"),
      (1 to 10).map(i => Row(s"table$i"))
    )

    // The returned DataFrame can itself be registered and queried with SQL.
    tableDF.registerTempTable("tables")
    checkAnswer(
      sql("SELECT isTemporary, tableName from tables WHERE isTemporary"),
      (1 to 10).map(i => Row(true, s"table$i"))
    )
    checkAnswer(
      tables().filter("tableName = 'tables'").select("tableName", "isTemporary"),
      Row("tables", true))
    dropTempTable("tables")
  }
}

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,11 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
198198
}
199199
}
200200

201+
/** Lists the metastore tables of the given database (default: the session's current one). */
override def getTables(databaseIdentifier: Seq[String]): Seq[(String, Boolean)] = {
  // Fall back to the session's current database when no database name is supplied.
  val db = getDBName(databaseIdentifier).getOrElse(hive.sessionState.getCurrentDatabase)
  // Tables coming from the Hive metastore are never temporary.
  client.getAllTables(db).map((_, false))
}
205+
201206
/**
202207
* Create table with specified database, table name, table description and schema
203208
* @param databaseName Database Name
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.hive
19+
20+
import org.scalatest.BeforeAndAfterAll
21+
22+
import org.apache.spark.sql.hive.test.TestHive
23+
import org.apache.spark.sql.hive.test.TestHive._
24+
import org.apache.spark.sql.QueryTest
25+
import org.apache.spark.sql.Row
26+
import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType}
27+
28+
class ListTablesSuite extends QueryTest with BeforeAndAfterAll {

  import org.apache.spark.sql.hive.test.TestHive.implicits._

  val sqlContext = TestHive
  // A simple two-column DataFrame registered under several temporary table names below.
  val df =
    sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDataFrame("key", "value")

  override def beforeAll(): Unit = {
    // The catalog in HiveContext is a case insensitive one.
    (1 to 10).foreach(i => catalog.registerTable(Seq(s"Table$i"), df.logicalPlan))
    (1 to 10).foreach(i => catalog.registerTable(Seq("db1", s"db1TempTable$i"), df.logicalPlan))
    (1 to 10).foreach {
      i => sql(s"CREATE TABLE hivetable$i (key int, value string)")
    }
    sql("CREATE DATABASE IF NOT EXISTS db1")
    (1 to 10).foreach {
      i => sql(s"CREATE TABLE db1.db1hivetable$i (key int, value string)")
    }
  }

  override def afterAll(): Unit = {
    catalog.unregisterAllTables()
    (1 to 10).foreach {
      i => sql(s"DROP TABLE IF EXISTS hivetable$i")
    }
    (1 to 10).foreach {
      i => sql(s"DROP TABLE IF EXISTS db1.db1hivetable$i")
    }
    sql("DROP DATABASE IF EXISTS db1")
  }

  test("get All Tables of current database") {
    // We are using default DB.
    val expectedTables =
      (1 to 10).map(i => Row(s"table$i", true)) ++
      (1 to 10).map(i => Row(s"hivetable$i", false))
    checkAnswer(tables(), expectedTables)
  }

  test("getting All Tables with a database name has no impact on returned table names") {
    val expectedTables =
      // We are expecting to see Table1 to Table10 since there is no database associated with them.
      (1 to 10).map(i => Row(s"table$i", true)) ++
      (1 to 10).map(i => Row(s"db1temptable$i", true)) ++
      (1 to 10).map(i => Row(s"db1hivetable$i", false))
    checkAnswer(tables("db1"), expectedTables)
  }

  test("query the returned DataFrame of tables") {
    val tableDF = tables()
    // tables() reports (tableName: nullable String, isTemporary: non-nullable Boolean).
    val schema = StructType(
      StructField("tableName", StringType, true) ::
      StructField("isTemporary", BooleanType, false) :: Nil)
    assert(schema === tableDF.schema)

    checkAnswer(
      tableDF.filter("NOT isTemporary").select("tableName"),
      (1 to 10).map(i => Row(s"hivetable$i"))
    )

    // The returned DataFrame can itself be registered and queried with SQL.
    tableDF.registerTempTable("tables")
    checkAnswer(
      sql("SELECT isTemporary, tableName from tables WHERE isTemporary"),
      (1 to 10).map(i => Row(true, s"table$i"))
    )
    checkAnswer(
      tables().filter("tableName = 'tables'").select("tableName", "isTemporary"),
      Row("tables", true))
    dropTempTable("tables")
  }
}

0 commit comments

Comments
 (0)