Skip to content

[SPARK-3299][SQL]Public API in SQLContext to list tables #4547

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions python/pyspark/sql/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,40 @@ def table(self, tableName):
"""
return DataFrame(self._ssql_ctx.table(tableName), self)

def tables(self, dbName=None):
    """Returns a DataFrame containing names of tables in the given database.

    If `dbName` is not specified, the current database will be used.

    The returned DataFrame has two columns, tableName and isTemporary
    (a column with BooleanType indicating if a table is a temporary one or not).

    >>> sqlCtx.registerRDDAsTable(df, "table1")
    >>> df2 = sqlCtx.tables()
    >>> df2.filter("tableName = 'table1'").first()
    Row(tableName=u'table1', isTemporary=True)
    """
    # Call the zero-argument overload on the wrapped context when no database
    # name was supplied; otherwise forward the name.
    jctx = self._ssql_ctx
    jdf = jctx.tables() if dbName is None else jctx.tables(dbName)
    return DataFrame(jdf, self)

def tableNames(self, dbName=None):
    """Returns a list of names of tables in the database `dbName`.

    If `dbName` is not specified, the current database will be used.

    >>> sqlCtx.registerRDDAsTable(df, "table1")
    >>> "table1" in sqlCtx.tableNames()
    True
    >>> "table1" in sqlCtx.tableNames("db")
    True
    """
    # Pick the overload on the wrapped context: no argument when dbName is
    # omitted, otherwise pass the name through.
    if dbName is None:
        names = self._ssql_ctx.tableNames()
    else:
        names = self._ssql_ctx.tableNames(dbName)
    # Materialize whatever iterable the wrapped context returns as a plain list.
    return list(names)

def cacheTable(self, tableName):
"""Caches the specified table in-memory."""
self._ssql_ctx.cacheTable(tableName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ trait Catalog {
tableIdentifier: Seq[String],
alias: Option[String] = None): LogicalPlan

/**
 * Returns tuples of (tableName, isTemporary) for all tables in the given database.
 * isTemporary is a Boolean value indicating whether a table is temporary or not.
 *
 * @param databaseName the database to list; what `None` selects is left to each implementation
 */
def getTables(databaseName: Option[String]): Seq[(String, Boolean)]

// Associates `plan` with `tableIdentifier` in this catalog.
// NOTE(review): callers in the test suites pass either Seq(table) or Seq(database, table) —
// confirm the accepted identifier shapes against the concrete implementations.
def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit

// Removes the entry for `tableIdentifier`; presumably the inverse of registerTable — verify
// against the concrete implementations.
def unregisterTable(tableIdentifier: Seq[String]): Unit
Expand Down Expand Up @@ -101,6 +107,12 @@ class SimpleCatalog(val caseSensitive: Boolean) extends Catalog {
// properly qualified with this alias.
alias.map(a => Subquery(a, tableWithQualifiers)).getOrElse(tableWithQualifiers)
}

override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = {
  // `databaseName` is not consulted here: every entry of `tables` is returned, and each one
  // is flagged as temporary.
  tables.toSeq.map { case (tableName, _) => (tableName, true) }
}
}

/**
Expand Down Expand Up @@ -137,6 +149,27 @@ trait OverrideCatalog extends Catalog {
withAlias.getOrElse(super.lookupRelation(tableIdentifier, alias))
}

/**
 * Lists the temporary tables held in `overrides` that are visible for `databaseName`, followed
 * by whatever the parent catalog reports for the same database.
 *
 * A temporary table is included when it has no associated database, or when its database equals
 * the requested one (the requested name is lower-cased first if the catalog is not case
 * sensitive).
 *
 * @param databaseName database to list, or `None`; forwarded unchanged to the parent catalog
 * @return (tableName, isTemporary) pairs; entries coming from `overrides` are flagged `true`
 */
abstract override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = {
  // Only the requested name is normalised here.
  // NOTE(review): this assumes the keys of `overrides` were normalised the same way when they
  // were registered — confirm in registerTable.
  val dbName = if (caseSensitive) databaseName else databaseName.map(_.toLowerCase)

  val temporaryTables = overrides.collect {
    // If a temporary table does not have an associated database, we should return its name.
    case ((None, tableName), _) => (tableName, true)
    // If a temporary table does have an associated database, we should return it if the database
    // matches the given database name.
    case ((Some(db), tableName), _) if dbName == Some(db) => (tableName, true)
  }.toSeq

  temporaryTables ++ super.getTables(databaseName)
}

override def registerTable(
tableIdentifier: Seq[String],
plan: LogicalPlan): Unit = {
Expand Down Expand Up @@ -172,6 +205,10 @@ object EmptyCatalog extends Catalog {
throw new UnsupportedOperationException
}

// This catalog does not support listing tables: the call always fails.
override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] =
  throw new UnsupportedOperationException

// This catalog does not support registering tables: the call always fails.
def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit =
  throw new UnsupportedOperationException
Expand Down
36 changes: 36 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,42 @@ class SQLContext(@transient val sparkContext: SparkContext)
def table(tableName: String): DataFrame = {
  // Resolve the name as a single-part identifier, then wrap the resulting plan.
  val relation = catalog.lookupRelation(Seq(tableName))
  DataFrame(this, relation)
}

/**
* Returns a [[DataFrame]] containing names of existing tables in the current database.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the javadoc is actually wrong. this one should say "current database", and next one should say "given database"

* The returned DataFrame has two columns, tableName and isTemporary (a boolean
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i'd just say a boolean

* indicating if a table is a temporary one or not).
*/
def tables(): DataFrame = {
  // No database name is handed to the catalog for this overload.
  val listing = catalog.getTables(None)
  createDataFrame(listing).toDataFrame("tableName", "isTemporary")
}

/**
 * Returns a [[DataFrame]] containing names of existing tables in the given database.
 * The returned DataFrame has two columns, tableName and isTemporary (a boolean indicating
 * if a table is a temporary one or not).
 *
 * @param databaseName name of the database handed to the catalog
 */
def tables(databaseName: String): DataFrame = {
  val listing = catalog.getTables(Some(databaseName))
  createDataFrame(listing).toDataFrame("tableName", "isTemporary")
}

/**
* Returns the names of all tables in the current database as an array.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Returns the names of all tables in the current database as an array?

*/
def tableNames(): Array[String] = {
  // Keep only the name component of each (tableName, isTemporary) pair.
  val listing = catalog.getTables(None)
  listing.map(_._1).toArray
}

/**
 * Returns the names of all tables in the given database as an array.
 *
 * @param databaseName name of the database handed to the catalog
 */
def tableNames(databaseName: String): Array[String] = {
  // Keep only the name component of each (tableName, isTemporary) pair.
  val listing = catalog.getTables(Some(databaseName))
  listing.map(_._1).toArray
}

protected[sql] class SparkPlanner extends SparkStrategies {
val sparkContext: SparkContext = self.sparkContext

Expand Down
76 changes: 76 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType}

/**
 * Exercises `tables()` and `tables(databaseName)` on the test SQL context, using one temporary
 * table that is registered before, and unregistered after, every test.
 */
class ListTablesSuite extends QueryTest with BeforeAndAfter {

  import org.apache.spark.sql.test.TestSQLContext.implicits._

  val df =
    sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDataFrame("key", "value")

  before {
    df.registerTempTable("ListTablesSuiteTable")
  }

  after {
    catalog.unregisterTable(Seq("ListTablesSuiteTable"))
  }

  test("get all tables") {
    checkAnswer(
      tables().filter("tableName = 'ListTablesSuiteTable'"),
      Row("ListTablesSuiteTable", true))

    catalog.unregisterTable(Seq("ListTablesSuiteTable"))
    assert(tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
  }

  test("getting all Tables with a database name has no impact on returned table names") {
    checkAnswer(
      tables("DB").filter("tableName = 'ListTablesSuiteTable'"),
      Row("ListTablesSuiteTable", true))

    catalog.unregisterTable(Seq("ListTablesSuiteTable"))
    // Query through the database-name overload again so that both halves of this test cover
    // the code path named in its title.
    assert(tables("DB").filter("tableName = 'ListTablesSuiteTable'").count() === 0)
  }

  test("query the returned DataFrame of tables") {
    val tableDF = tables()
    val schema = StructType(
      StructField("tableName", StringType, true) ::
      StructField("isTemporary", BooleanType, false) :: Nil)
    assert(schema === tableDF.schema)

    tableDF.registerTempTable("tables")
    try {
      checkAnswer(
        sql("SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"),
        Row(true, "ListTablesSuiteTable")
      )
      checkAnswer(
        tables().filter("tableName = 'tables'").select("tableName", "isTemporary"),
        Row("tables", true))
    } finally {
      // Drop the helper table even when an assertion above fails, so it cannot leak into
      // later tests.
      dropTempTable("tables")
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,11 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
}
}

override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = {
  // Fall back to the session's current database when no database is named.
  val targetDb = databaseName match {
    case Some(name) => name
    case None => hive.sessionState.getCurrentDatabase
  }
  // Every table name reported by the client is paired with `false` (not temporary).
  client.getAllTables(targetDb).map { name => (name, false) }
}

/**
* Create table with specified database, table name, table description and schema
* @param databaseName Database Name
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.hive

import org.scalatest.BeforeAndAfterAll

import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.Row

/**
 * Checks table listing against the Hive test context. The fixture consists of two tables
 * registered directly with the catalog (one of them under a database name) and two tables
 * created through SQL (one of them inside the database ListTablesSuiteDB).
 */
class ListTablesSuite extends QueryTest with BeforeAndAfterAll {

  import org.apache.spark.sql.hive.test.TestHive.implicits._

  val df =
    sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDataFrame("key", "value")

  override def beforeAll(): Unit = {
    // The catalog in HiveContext is a case insensitive one.
    catalog.registerTable(Seq("ListTablesSuiteTable"), df.logicalPlan)
    catalog.registerTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable"), df.logicalPlan)
    // Statements run in this order: the database must exist before its table is created.
    Seq(
      "CREATE TABLE HiveListTablesSuiteTable (key int, value string)",
      "CREATE DATABASE IF NOT EXISTS ListTablesSuiteDB",
      "CREATE TABLE ListTablesSuiteDB.HiveInDBListTablesSuiteTable (key int, value string)"
    ).foreach(statement => sql(statement))
  }

  override def afterAll(): Unit = {
    catalog.unregisterTable(Seq("ListTablesSuiteTable"))
    catalog.unregisterTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable"))
    Seq(
      "DROP TABLE IF EXISTS HiveListTablesSuiteTable",
      "DROP TABLE IF EXISTS ListTablesSuiteDB.HiveInDBListTablesSuiteTable",
      "DROP DATABASE IF EXISTS ListTablesSuiteDB"
    ).foreach(statement => sql(statement))
  }

  test("get all tables of current database") {
    // We are using default DB.
    val listing = tables()

    val tempWithoutDb = listing.filter("tableName = 'listtablessuitetable'")
    checkAnswer(tempWithoutDb, Row("listtablessuitetable", true))

    val tempInOtherDb = listing.filter("tableName = 'indblisttablessuitetable'")
    assert(tempInOtherDb.count() === 0)

    val hiveInDefaultDb = listing.filter("tableName = 'hivelisttablessuitetable'")
    checkAnswer(hiveInDefaultDb, Row("hivelisttablessuitetable", false))

    val hiveInOtherDb = listing.filter("tableName = 'hiveindblisttablessuitetable'")
    assert(hiveInOtherDb.count() === 0)
  }

  test("getting all tables with a database name") {
    val listing = tables("ListTablesSuiteDB")

    val tempWithoutDb = listing.filter("tableName = 'listtablessuitetable'")
    checkAnswer(tempWithoutDb, Row("listtablessuitetable", true))

    val tempInDb = listing.filter("tableName = 'indblisttablessuitetable'")
    checkAnswer(tempInDb, Row("indblisttablessuitetable", true))

    val hiveInDefaultDb = listing.filter("tableName = 'hivelisttablessuitetable'")
    assert(hiveInDefaultDb.count() === 0)

    val hiveInDb = listing.filter("tableName = 'hiveindblisttablessuitetable'")
    checkAnswer(hiveInDb, Row("hiveindblisttablessuitetable", false))
  }
}
}