
[SPARK-3414][SQL] Replace LowerCaseSchema with Resolver #2382


Status: Closed · wants to merge 5 commits
@@ -37,6 +37,8 @@ object SimpleAnalyzer extends Analyzer(EmptyCatalog, EmptyFunctionRegistry, true
 class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Boolean)
   extends RuleExecutor[LogicalPlan] with HiveTypeCoercion {
 
+  val resolver = if (caseSensitive) caseSensitiveResolution else caseInsensitiveResolution
+
   // TODO: pass this in as a parameter.
   val fixedPoint = FixedPoint(100)
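The case-handling policy is now captured once, at construction, as a plain function value rather than as an extra rewrite batch. A minimal sketch of what that enables (the object name is invented; EmptyCatalog and EmptyFunctionRegistry are the stubs visible in the hunk header above):

    import org.apache.spark.sql.catalyst.analysis._

    // Analogous to SimpleAnalyzer, but case insensitive: the boolean now selects
    // a Resolver instead of scheduling a lowercasing rewrite batch.
    object CaseInsensitiveAnalyzer
      extends Analyzer(EmptyCatalog, EmptyFunctionRegistry, caseSensitive = false)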


@@ -48,8 +50,6 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
   lazy val batches: Seq[Batch] = Seq(
     Batch("MultiInstanceRelations", Once,
       NewRelationInstances),
-    Batch("CaseInsensitiveAttributeReferences", Once,
-      (if (caseSensitive) Nil else LowercaseAttributeReferences :: Nil) : _*),
     Batch("Resolution", fixedPoint,
       ResolveReferences ::
       ResolveRelations ::
@@ -98,23 +98,6 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
     }
   }
 
-  /**
-   * Makes attribute naming case insensitive by turning all UnresolvedAttributes to lowercase.
-   */
-  object LowercaseAttributeReferences extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-      case UnresolvedRelation(databaseName, name, alias) =>
-        UnresolvedRelation(databaseName, name, alias.map(_.toLowerCase))
-      case Subquery(alias, child) => Subquery(alias.toLowerCase, child)
-      case q: LogicalPlan => q transformExpressions {
-        case s: Star => s.copy(table = s.table.map(_.toLowerCase))
-        case UnresolvedAttribute(name) => UnresolvedAttribute(name.toLowerCase)
-        case Alias(c, name) => Alias(c, name.toLowerCase)()
-        case GetField(c, name) => GetField(c, name.toLowerCase)
-      }
-    }
-  }
 
   /**
    * Replaces [[UnresolvedAttribute]]s with concrete
    * [[catalyst.expressions.AttributeReference AttributeReferences]] from a logical plan node's
@@ -127,7 +110,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
       q transformExpressions {
         case u @ UnresolvedAttribute(name) =>
           // Leave unchanged if resolution fails. Hopefully will be resolved next round.
-          val result = q.resolveChildren(name).getOrElse(u)
+          val result = q.resolveChildren(name, resolver).getOrElse(u)
           logDebug(s"Resolving $u to $result")
           result
       }
@@ -144,7 +127,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
     def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
       case s @ Sort(ordering, p @ Project(projectList, child)) if !s.resolved && p.resolved =>
         val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name })
-        val resolved = unresolved.flatMap(child.resolveChildren)
+        val resolved = unresolved.flatMap(child.resolve(_, resolver))
         val requiredAttributes = AttributeSet(resolved.collect { case a: Attribute => a })
 
         val missingInProject = requiredAttributes -- p.output
@@ -154,6 +137,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
           Sort(ordering,
             Project(projectList ++ missingInProject, child)))
         } else {
+          logDebug(s"Failed to find $missingInProject in ${p.output.mkString(", ")}")
           s // Nothing we can do here. Return original plan.
         }
       case s @ Sort(ordering, a @ Aggregate(grouping, aggs, child)) if !s.resolved && a.resolved =>
@@ -165,7 +149,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
         )
 
         logDebug(s"Grouping expressions: $groupingRelation")
-        val resolved = unresolved.flatMap(groupingRelation.resolve)
+        val resolved = unresolved.flatMap(groupingRelation.resolve(_, resolver))
         val missingInAggs = resolved.filterNot(a.outputSet.contains)
         logDebug(s"Resolved: $resolved Missing in aggs: $missingInAggs")
         if (missingInAggs.nonEmpty) {
@@ -258,22 +242,22 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
       case p @ Project(projectList, child) if containsStar(projectList) =>
         Project(
           projectList.flatMap {
-            case s: Star => s.expand(child.output)
+            case s: Star => s.expand(child.output, resolver)
             case o => o :: Nil
           },
           child)
       case t: ScriptTransformation if containsStar(t.input) =>
         t.copy(
           input = t.input.flatMap {
-            case s: Star => s.expand(t.child.output)
+            case s: Star => s.expand(t.child.output, resolver)
             case o => o :: Nil
           }
         )
       // If the aggregate function argument contains Stars, expand it.
       case a: Aggregate if containsStar(a.aggregateExpressions) =>
         a.copy(
           aggregateExpressions = a.aggregateExpressions.flatMap {
-            case s: Star => s.expand(a.child.output)
+            case s: Star => s.expand(a.child.output, resolver)
             case o => o :: Nil
           }
         )
@@ -290,13 +274,11 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
 /**
  * Removes [[catalyst.plans.logical.Subquery Subquery]] operators from the plan. Subqueries are
  * only required to provide scoping information for attributes and can be removed once analysis is
- * complete. Similarly, this node also removes
- * [[catalyst.plans.logical.LowerCaseSchema LowerCaseSchema]] operators.
+ * complete.
  */
 object EliminateAnalysisOperators extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case Subquery(_, child) => child
-    case LowerCaseSchema(child) => child
   }
 }
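With LowerCaseSchema gone, the rule is left with a single job. Illustratively (relation stands in for any resolved plan):

    // Scoping operators exist only to guide resolution; once attributes are
    // bound they can be stripped without changing semantics:
    //
    //   EliminateAnalysisOperators(Subquery("t", relation))  ==>  relation
    //
    // Previously the rule also had to strip LowerCaseSchema(relation) wrappers.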


@@ -22,4 +22,14 @@ package org.apache.spark.sql.catalyst
  * Analysis consists of translating [[UnresolvedAttribute]]s and [[UnresolvedRelation]]s
  * into fully typed objects using information in a schema [[Catalog]].
  */
-package object analysis
+package object analysis {
+
+  /**
+   * Responsible for resolving which identifiers refer to the same entity. For example, by using
+   * case insensitive equality.
+   */
+  type Resolver = (String, String) => Boolean
+
+  val caseInsensitiveResolution = (a: String, b: String) => a.equalsIgnoreCase(b)
+  val caseSensitiveResolution = (a: String, b: String) => a == b
+}

Review thread (on the Resolver type alias):

Contributor: Resolver is probably too general a name; can we use a more precise name for this?

Author: I think this will actually end up providing more general resolution functionality in the long term. I've added some Scaladoc for clarity, though.
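The contract is deliberately tiny: a Resolver is just a predicate deciding whether two identifiers denote the same thing. A standalone sketch of how the two built-in resolutions compare identifiers (the column names are invented):

    // How the analyzer consults a Resolver: does the identifier the user typed
    // denote this catalog attribute?
    val caseInsensitive: (String, String) => Boolean = _ equalsIgnoreCase _
    val caseSensitive: (String, String) => Boolean = _ == _

    assert(caseInsensitive("employeeName", "EMPLOYEENAME"))
    assert(!caseSensitive("employeeName", "EMPLOYEENAME"))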
@@ -54,6 +54,7 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo
   override def newInstance = this
   override def withNullability(newNullability: Boolean) = this
   override def withQualifiers(newQualifiers: Seq[String]) = this
+  override def withName(newName: String) = UnresolvedAttribute(name)
 
   // Unresolved attributes are transient at compile time and don't get evaluated during execution.
   override def eval(input: Row = null): EvaluatedType =
@@ -97,13 +98,14 @@ case class Star(
   override def newInstance = this
   override def withNullability(newNullability: Boolean) = this
   override def withQualifiers(newQualifiers: Seq[String]) = this
+  override def withName(newName: String) = this
 
-  def expand(input: Seq[Attribute]): Seq[NamedExpression] = {
+  def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = {
     val expandedAttributes: Seq[Attribute] = table match {
       // If there is no table specified, use all input attributes.
       case None => input
       // If there is a table, pick out attributes that are part of this table.
-      case Some(t) => input.filter(_.qualifiers contains t)
+      case Some(t) => input.filter(_.qualifiers.filter(resolver(_, t)).nonEmpty)
     }
     val mappedAttributes = expandedAttributes.map(mapFunction).zip(input).map {
       case (n: NamedExpression, _) => n
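The qualifier check in expand now also goes through the resolver, so SELECT T.* can match a relation aliased t under case-insensitive resolution. A self-contained sketch with a simplified stand-in for Catalyst's attribute type (all names hypothetical):

    // Simplified stand-in for Catalyst's Attribute: a name plus table qualifiers.
    case class Attr(name: String, qualifiers: Seq[String])

    def expandStar(
        input: Seq[Attr],
        table: Option[String],
        resolver: (String, String) => Boolean): Seq[Attr] = table match {
      case None    => input                                                        // plain *
      case Some(t) => input.filter(a => a.qualifiers.exists(q => resolver(q, t)))  // t.*
    }

    val output = Seq(Attr("id", Seq("t")), Attr("name", Seq("t")), Attr("m", Seq("u")))

    // Case insensitive: "T" matches the qualifier "t", so T.* expands to two columns.
    assert(expandStar(output, Some("T"), _ equalsIgnoreCase _).size == 2)
    // Case sensitive: nothing is qualified by an exact "T".
    assert(expandStar(output, Some("T"), _ == _).isEmpty)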
@@ -59,6 +59,7 @@ abstract class Attribute extends NamedExpression {
 
   def withNullability(newNullability: Boolean): Attribute
   def withQualifiers(newQualifiers: Seq[String]): Attribute
+  def withName(newName: String): Attribute
 
   def toAttribute = this
   def newInstance: Attribute
@@ -86,7 +87,6 @@ case class Alias(child: Expression, name: String)
   override def dataType = child.dataType
   override def nullable = child.nullable
 
-
   override def toAttribute = {
     if (resolved) {
       AttributeReference(name, child.dataType, child.nullable)(exprId, qualifiers)
@@ -144,6 +144,14 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea
     }
   }
 
+  override def withName(newName: String): AttributeReference = {
+    if (name == newName) {
+      this
+    } else {
+      AttributeReference(newName, dataType, nullable)(exprId, qualifiers)
+    }
+  }
+
   /**
    * Returns a copy of this [[AttributeReference]] with new qualifiers.
    */
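withName exists so resolution can hand back the spelling the user actually typed without minting a new attribute: exprId and qualifiers carry over, so the re-spelled reference still denotes the same column. A hypothetical illustration, assuming the usual default exprId/qualifiers constructor arguments (the column name is invented):

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.catalyst.types.StringType

    val attr = AttributeReference("name", StringType, nullable = true)()
    val respelled = attr.withName("NAME")

    assert(respelled.exprId == attr.exprId)  // same underlying attribute identity
    assert(respelled.name == "NAME")         // the user's case survives into the output schema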
@@ -17,13 +17,15 @@
 
 package org.apache.spark.sql.catalyst.plans.logical
 
+import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.types.StructType
 import org.apache.spark.sql.catalyst.trees
 
-abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
+abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
   self: Product =>
 
   /**
@@ -75,42 +77,95 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
    * nodes of this LogicalPlan. The attribute is expressed as
    * a string in the following form: `[scope].AttributeName.[nested].[fields]...`.
    */
-  def resolveChildren(name: String): Option[NamedExpression] =
-    resolve(name, children.flatMap(_.output))
+  def resolveChildren(name: String, resolver: Resolver): Option[NamedExpression] =
+    resolve(name, children.flatMap(_.output), resolver)
 
   /**
    * Optionally resolves the given string to a [[NamedExpression]] based on the output of this
    * LogicalPlan. The attribute is expressed as a string in the following form:
    * `[scope].AttributeName.[nested].[fields]...`.
    */
-  def resolve(name: String): Option[NamedExpression] =
-    resolve(name, output)
+  def resolve(name: String, resolver: Resolver): Option[NamedExpression] =
+    resolve(name, output, resolver)
 
   /** Performs attribute resolution given a name and a sequence of possible attributes. */
-  protected def resolve(name: String, input: Seq[Attribute]): Option[NamedExpression] = {
+  protected def resolve(
+      name: String,
+      input: Seq[Attribute],
+      resolver: Resolver): Option[NamedExpression] = {
+
     val parts = name.split("\\.")
+
     // Collect all attributes that are output by this node's children where either the first part
     // matches the name or where the first part matches the scope and the second part matches the
     // name. Return these matches along with any remaining parts, which represent dotted access to
     // struct fields.
     val options = input.flatMap { option =>
       // If the first part of the desired name matches a qualifier for this possible match, drop it.
       val remainingParts =
-        if (option.qualifiers.contains(parts.head) && parts.size > 1) parts.drop(1) else parts
-      if (option.name == remainingParts.head) (option, remainingParts.tail.toList) :: Nil else Nil
+        if (option.qualifiers.find(resolver(_, parts.head)).nonEmpty && parts.size > 1) {
+          parts.drop(1)
+        } else {
+          parts
+        }
+
+      if (resolver(option.name, remainingParts.head)) {
+        // Preserve the case of the user's attribute reference.
+        (option.withName(remainingParts.head), remainingParts.tail.toList) :: Nil
+      } else {
+        Nil
+      }
     }
 
     options.distinct match {
-      case Seq((a, Nil)) => Some(a) // One match, no nested fields, use it.
+      // One match, no nested fields, use it.
+      case Seq((a, Nil)) => Some(a)
+
       // One match, but we also need to extract the requested nested field.
       case Seq((a, nestedFields)) =>
-        Some(Alias(nestedFields.foldLeft(a: Expression)(GetField), nestedFields.last)())
-      case Seq() => None // No matches.
+        val aliased =
+          Alias(
+            resolveNesting(nestedFields, a, resolver),
+            nestedFields.last)() // Preserve the case of the user's field access.
+        Some(aliased)
+
+      // No matches.
+      case Seq() =>
+        logTrace(s"Could not find $name in ${input.mkString(", ")}")
+        None
+
+      // More than one match.
       case ambiguousReferences =>
         throw new TreeNodeException(
           this, s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}")
     }
   }
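To see the matching step in isolation: the name is split on dots, a leading part is dropped if the resolver accepts it as a qualifier, and whatever follows the attribute name itself is treated as struct-field access. A standalone sketch (matchParts and all identifiers are invented):

    def matchParts(
        name: String,
        attrName: String,
        qualifiers: Seq[String],
        resolver: (String, String) => Boolean): Option[List[String]] = {
      val parts = name.split("\\.")
      // Drop the first part if it names a qualifier (e.g. a table alias).
      val remaining =
        if (qualifiers.exists(q => resolver(q, parts.head)) && parts.size > 1) parts.drop(1)
        else parts
      // The next part must be the attribute itself; anything left is a nested field.
      if (resolver(attrName, remaining.head)) Some(remaining.tail.toList) else None
    }

    // "t.a.address" against attribute "a" qualified by "t": the leftover
    // List("address") is handed to resolveNesting, defined below.
    assert(matchParts("t.a.address", "a", Seq("t"), _ equalsIgnoreCase _)
      == Some(List("address")))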

+
+  /**
+   * Given a list of successive nested field accesses, and a base expression, attempt to resolve
+   * the actual field lookups on this expression.
+   */
+  private def resolveNesting(
+      nestedFields: List[String],
+      expression: Expression,
+      resolver: Resolver): Expression = {
+
+    (nestedFields, expression.dataType) match {
+      case (Nil, _) => expression
+      case (requestedField :: rest, StructType(fields)) =>
+        val actualField = fields.filter(f => resolver(f.name, requestedField))
+        actualField match {
+          case Seq() =>
+            sys.error(
+              s"No such struct field $requestedField in ${fields.map(_.name).mkString(", ")}")
+          case Seq(singleMatch) =>
+            resolveNesting(rest, GetField(expression, singleMatch.name), resolver)
+          case multipleMatches =>
+            sys.error(s"Ambiguous reference to fields ${multipleMatches.mkString(", ")}")
+        }
+      case (_, dt) => sys.error(s"Can't access nested field in type $dt")
+    }
+  }
 }

Review thread (on the actualField lookup in resolveNesting):

Contributor: There is a problem here. Currently a.b[1].c.d will be parsed as GetField(GetField(GetItem(Unresolved("a.b"), 1), "c"), "d"), so the case-sensitivity check only happens when resolving Unresolved("a.b") to GetField(Attribute("a"), "b"). Something like SELECT a[0].A.A FROM nested will fail the case-sensitivity check under HQL. I think we should also do this check in GetField.

Author: Hmm, good point. Right now you can't make a SQLContext case insensitive, but when you can, this will be a problem. Maybe you should note this on SPARK-3617.

Author: Oh wait, sorry... Is that how the HiveQL parser will do it too? I'm not a huge fan of moving resolution logic into the expressions. What about a rule that only runs in case-insensitive mode and fixes unresolved GetFields?

Contributor: Yes, this bug exists in HiveQL. I have opened a PR to fix this (adding a rule to fix unresolved GetFields): #2543. Need your comments :)
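Setting that parser caveat aside, the recursion itself is easy to see in isolation. A standalone sketch with simplified stand-ins for Catalyst's type hierarchy (none of these are the real classes):

    sealed trait DType
    case class SField(name: String, dataType: DType)
    case class SStruct(fields: Seq[SField]) extends DType
    case object SInt extends DType

    // Walk successive field accesses, letting the resolver pick the actual
    // (canonically cased) field name at each step.
    def resolveNesting(
        path: List[String],
        dt: DType,
        resolver: (String, String) => Boolean): List[String] = (path, dt) match {
      case (Nil, _) => Nil
      case (requested :: rest, SStruct(fields)) =>
        fields.filter(f => resolver(f.name, requested)) match {
          case Seq(single) => single.name :: resolveNesting(rest, single.dataType, resolver)
          case Seq()       => sys.error(s"No such struct field $requested")
          case many        => sys.error(s"Ambiguous: ${many.map(_.name).mkString(", ")}")
        }
      case (_, other) => sys.error(s"Can't access nested field in type $other")
    }

    val schema = SStruct(Seq(SField("address", SStruct(Seq(SField("city", SInt))))))

    // Case-insensitive access recovers the canonical field names at every level.
    assert(resolveNesting(List("Address", "CITY"), schema, _ equalsIgnoreCase _)
      == List("address", "city"))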


@@ -154,32 +154,6 @@ case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
   override def output = child.output.map(_.withQualifiers(alias :: Nil))
 }
 
-/**
- * Converts the schema of `child` to all lowercase, together with LowercaseAttributeReferences
- * this allows for optional case insensitive attribute resolution. This node can be elided after
- * analysis.
- */
-case class LowerCaseSchema(child: LogicalPlan) extends UnaryNode {
-  protected def lowerCaseSchema(dataType: DataType): DataType = dataType match {
-    case StructType(fields) =>
-      StructType(fields.map(f =>
-        StructField(f.name.toLowerCase(), lowerCaseSchema(f.dataType), f.nullable)))
-    case ArrayType(elemType, containsNull) => ArrayType(lowerCaseSchema(elemType), containsNull)
-    case otherType => otherType
-  }
-
-  override val output = child.output.map {
-    case a: AttributeReference =>
-      AttributeReference(
-        a.name.toLowerCase,
-        lowerCaseSchema(a.dataType),
-        a.nullable)(
-        a.exprId,
-        a.qualifiers)
-    case other => other
-  }
-}
-
 case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan)
   extends UnaryNode {
@@ -246,7 +246,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
    * @group userf
    */
   def registerRDDAsTable(rdd: SchemaRDD, tableName: String): Unit = {
-    catalog.registerTable(None, tableName, rdd.queryExecution.analyzed)
+    catalog.registerTable(None, tableName, rdd.queryExecution.logical)
   }
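Registering the raw logical plan rather than the analyzed one presumably defers analysis to query time, so a registered table is resolved with whatever Resolver the active analyzer carries instead of with attribute casing frozen at registration.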
@@ -380,7 +380,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
   }
 
   test("SPARK-3349 partitioning after limit") {
-    /*
     sql("SELECT DISTINCT n FROM lowerCaseData ORDER BY n DESC")
       .limit(2)
       .registerTempTable("subset1")
@@ -395,7 +394,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
       sql("SELECT * FROM lowerCaseData INNER JOIN subset2 ON subset2.n = lowerCaseData.n"),
       (1, "a", 1) ::
       (2, "b", 2) :: Nil)
-    */
   }
 
   test("mixed-case keywords") {
@@ -244,15 +244,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
 
   /* A catalyst metadata catalog that points to the Hive Metastore. */
   @transient
-  override protected[sql] lazy val catalog = new HiveMetastoreCatalog(this) with OverrideCatalog {
-    override def lookupRelation(
-      databaseName: Option[String],
-      tableName: String,
-      alias: Option[String] = None): LogicalPlan = {
-
-      LowerCaseSchema(super.lookupRelation(databaseName, tableName, alias))
-    }
-  }
+  override protected[sql] lazy val catalog = new HiveMetastoreCatalog(this) with OverrideCatalog
 
   // Note that HiveUDFs will be overridden by functions registered in this context.
   @transient