
Commit d2e8b43

Update the code based on review feedback
1 parent ca5e7f4 commit d2e8b43

9 files changed, +59 -41 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 27 additions & 13 deletions
@@ -474,46 +474,60 @@ class Analyzer(
   object ImplicitGenerate extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan transform {
       case Project(Seq(Alias(g: Generator, name)), child) =>
-        Generate(g, join = false, outer = false, child, qualifier = None, name :: Nil, Nil)
+        Generate(g, join = false, outer = false,
+          qualifier = None, UnresolvedAttribute(name) :: Nil, child)
       case Project(Seq(MultiAlias(g: Generator, names)), child) =>
-        Generate(g, join = false, outer = false, child, qualifier = None, names, Nil)
+        Generate(g, join = false, outer = false,
+          qualifier = None, names.map(UnresolvedAttribute(_)), child)
     }
   }
 
+  /**
+   * Resolves the Generate operator. If output names are specified, we take them;
+   * otherwise we provide default names, following the same rule as Hive.
+   */
   object ResolveGenerate extends Rule[LogicalPlan] {
     // Construct the output attributes for the generator.
     // The output attribute names can be either specified or auto generated.
     private def makeGeneratorOutput(
         generator: Generator,
-        attributeNames: Seq[String],
-        qualifier: Option[String]): Array[Attribute] = {
+        generatorOutput: Seq[Attribute]): Seq[Attribute] = {
       val elementTypes = generator.elementTypes
 
-      val raw = if (attributeNames.size == elementTypes.size) {
-        attributeNames.zip(elementTypes).map {
-          case (n, (t, nullable)) => AttributeReference(n, t, nullable)()
+      if (generatorOutput.size == elementTypes.size) {
+        generatorOutput.zip(elementTypes).map {
+          case (a, (t, nullable)) if !a.resolved =>
+            AttributeReference(a.name, t, nullable)()
+          case (a, _) => a
         }
-      } else {
+      } else if (generatorOutput.length == 0) {
         elementTypes.zipWithIndex.map {
           // keep the default column names as Hive does: _c0, _c1, ..., _cN
          case ((t, nullable), i) => AttributeReference(s"_c$i", t, nullable)()
        }
+      } else {
+        throw new AnalysisException(
+          s"""
+             |The number of aliases supplied in the AS clause does not match
+             |the number of columns output by the UDTF: expected
+             |${elementTypes.size} aliases but got ${generatorOutput.size}
+           """.stripMargin)
      }
-
-      qualifier.map(q => raw.map(_.withQualifiers(q :: Nil))).getOrElse(raw).toArray[Attribute]
    }
 
    def apply(plan: LogicalPlan): LogicalPlan = plan transform {
      case p: Generate if !p.child.resolved || !p.generator.resolved => p
      case p: Generate if p.resolved == false =>
        // if the generator output names are not specified, we will use the default ones.
-        val gOutput = makeGeneratorOutput(p.generator, p.attributeNames, p.qualifier)
        Generate(
-          p.generator, p.join, p.outer, p.child, p.qualifier, gOutput.map(_.name), gOutput)
+          p.generator,
+          join = p.join,
+          outer = p.outer,
+          p.qualifier,
+          makeGeneratorOutput(p.generator, p.generatorOutput), p.child)
    }
  }
-
 }
 
 /**

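As a reading aid for the ResolveGenerate rule above, here is a standalone Scala sketch (illustrative only, not code from this commit; plain strings and (typeName, nullable) pairs stand in for Catalyst's Attribute and DataType) of the three cases makeGeneratorOutput distinguishes: aliases matching the generator arity, no aliases (Hive-style _c0, _c1, ... defaults), and a count mismatch that raises an error.

// Standalone sketch: element types are modeled as (typeName, nullable) pairs and
// output attributes as plain strings, instead of Catalyst's DataType and Attribute.
object GeneratorOutputSketch {
  def makeOutput(elementTypes: Seq[(String, Boolean)], aliases: Seq[String]): Seq[String] = {
    if (aliases.size == elementTypes.size) {
      aliases                                  // take the user-supplied aliases as-is
    } else if (aliases.isEmpty) {
      elementTypes.indices.map(i => s"_c$i")   // default Hive-style column names
    } else {
      throw new IllegalArgumentException(
        s"expected ${elementTypes.size} aliases but got ${aliases.size}")
    }
  }

  def main(args: Array[String]): Unit = {
    val types = Seq(("int", false), ("string", true))
    println(makeOutput(types, Seq("key", "value")))  // List(key, value)
    println(makeOutput(types, Nil))                  // Vector(_c0, _c1)
  }
}
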
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 2 additions & 2 deletions
@@ -289,8 +289,8 @@ package object dsl {
         generator: Generator,
         join: Boolean = false,
         outer: Boolean = false,
-        alias: Option[String] = None): Generate =
-      Generate(generator, join, outer, logicalPlan, alias)
+        alias: Option[String] = None): LogicalPlan =
+      Generate(generator, join = join, outer = outer, alias, Nil, logicalPlan)
 
     def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan =
       InsertIntoTable(

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala

Lines changed: 4 additions & 1 deletion
@@ -42,12 +42,15 @@ abstract class Generator extends Expression {
 
   override type EvaluatedType = TraversableOnce[Row]
 
-  override def dataType: DataType = ???
+  // TODO: ideally we should return ArrayType(StructType); however, we don't keep
+  // the output field names in the Generator.
+  override def dataType: DataType = throw new UnsupportedOperationException
 
   override def nullable: Boolean = false
 
   /**
    * The output element data types, in the form of Seq[(DataType, Nullable)].
+   * TODO: we probably need to add more information, like metadata etc.
    */
   def elementTypes: Seq[(DataType, Boolean)]
 

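The elementTypes contract documented above can be pictured with a standalone toy sketch (not Spark code; ToyGenerator, the String type names, and ExplodeInts are invented for illustration): a generator declares one (type, nullable) pair per output column and emits rows matching that shape.

// Toy sketch: ToyGenerator mirrors the elementTypes contract with (typeName, nullable)
// pairs instead of Catalyst's (DataType, Boolean), and rows as plain Seq[Any].
object ElementTypesSketch {
  trait ToyGenerator {
    def elementTypes: Seq[(String, Boolean)]
    def eval(input: Seq[Any]): Seq[Seq[Any]]
  }

  // Explodes the first column, assumed to hold a sequence of integers,
  // producing one single-column row per element.
  object ExplodeInts extends ToyGenerator {
    def elementTypes: Seq[(String, Boolean)] = Seq(("int", false))
    def eval(input: Seq[Any]): Seq[Seq[Any]] =
      input.head.asInstanceOf[Seq[Int]].map(i => Seq(i))
  }

  def main(args: Array[String]): Unit = {
    println(ExplodeInts.elementTypes)              // List((int,false))
    println(ExplodeInts.eval(Seq(Seq(1, 2, 3))))   // List(List(1), List(2), List(3))
  }
}
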
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 1 addition & 1 deletion
@@ -486,7 +486,7 @@ object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelp
       if (pushDown.nonEmpty) {
         val pushDownPredicate = pushDown.reduce(And)
         val withPushdown = Generate(g.generator, join = g.join, outer = g.outer,
-          Filter(pushDownPredicate, g.child), g.qualifier, g.attributeNames, g.gOutput)
+          g.qualifier, g.generatorOutput, Filter(pushDownPredicate, g.child))
         stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown)
       } else {
         filter

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 11 additions & 11 deletions
@@ -45,35 +45,35 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
  *             it.
  * @param outer when true, each input row will be output at least once, even if the output of the
  *              given `generator` is empty. `outer` has no effect when `join` is false.
- * @param child Children logical plan node
  * @param qualifier Qualifier for the attributes of generator(UDTF)
- * @param attributeNames the column names for the generator(UDTF), will be _c0, _c1 .. _cN if
- *                       leave as default (empty)
- * @param gOutput The output of Generator.
+ * @param generatorOutput The output schema of the Generator.
+ * @param child Children logical plan node
  */
 case class Generate(
     generator: Generator,
     join: Boolean,
     outer: Boolean,
-    child: LogicalPlan,
-    qualifier: Option[String] = None,
-    attributeNames: Seq[String] = Nil,
-    gOutput: Seq[Attribute] = Nil)
+    qualifier: Option[String],
+    generatorOutput: Seq[Attribute],
+    child: LogicalPlan)
   extends UnaryNode {
 
   override lazy val resolved: Boolean = {
     generator.resolved &&
       childrenResolved &&
-      attributeNames.length > 0 &&
-      gOutput.map(_.name) == attributeNames
+      !generatorOutput.exists(!_.resolved)
   }
 
   // we don't want the gOutput to be taken as part of the expressions
   // as that will cause exceptions like unresolved attributes etc.
   override def expressions: Seq[Expression] = generator :: Nil
 
   def output: Seq[Attribute] = {
-    if (join) child.output ++ gOutput else gOutput
+    val withoutQualifier = if (join) child.output ++ generatorOutput else generatorOutput
+
+    qualifier.map(q =>
+      withoutQualifier.map(_.withQualifiers(q :: Nil))
+    ).getOrElse(withoutQualifier)
   }
 }
 

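As a reading aid for the new output method above, here is a minimal standalone sketch (illustrative names only, not Catalyst code) of its intent: the generator columns, concatenated after the child output when join is true, are re-qualified as a whole whenever a qualifier such as a LATERAL VIEW alias is present.

// Sketch: an attribute is modeled as (qualifiers, name) instead of Catalyst's Attribute.
object GenerateOutputSketch {
  type Attr = (Seq[String], String)

  def output(join: Boolean, qualifier: Option[String],
             childOutput: Seq[Attr], generatorOutput: Seq[Attr]): Seq[Attr] = {
    val withoutQualifier = if (join) childOutput ++ generatorOutput else generatorOutput
    // When a qualifier (e.g. a LATERAL VIEW alias) is present, re-qualify every output attribute.
    qualifier
      .map(q => withoutQualifier.map { case (_, name) => (Seq(q), name) })
      .getOrElse(withoutQualifier)
  }

  def main(args: Array[String]): Unit = {
    val child: Seq[Attr] = Seq((Nil, "id"))
    val gen: Seq[Attr] = Seq((Nil, "col"))
    println(output(join = true, qualifier = Some("lv"), child, gen))
    // List((List(lv),id), (List(lv),col))
    println(output(join = false, qualifier = None, child, gen))
    // List((List(),col))
  }
}
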
sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 5 additions & 3 deletions
@@ -34,7 +34,7 @@ import org.apache.spark.api.python.SerDeUtil
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedRelation, ResolvedStar}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, ResolvedStar}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -719,7 +719,8 @@ class DataFrame private[sql](
       f.andThen(_.map(CatalystTypeConverters.convertToCatalyst(_, schema).asInstanceOf[Row]))
     val generator = UserDefinedGenerator(elementTypes, rowFunction, input.map(_.expr))
 
-    Generate(generator, join = true, outer = false, logicalPlan, qualifier = None, names, Nil)
+    Generate(generator, join = true, outer = false,
+      qualifier = None, names.map(UnresolvedAttribute(_)), logicalPlan)
   }
 
   /**
@@ -745,7 +746,8 @@ class DataFrame private[sql](
     }
     val generator = UserDefinedGenerator(elementTypes, rowFunction, apply(inputColumn).expr :: Nil)
 
-    Generate(generator, join = true, outer = false, logicalPlan, qualifier = None, names, Nil)
+    Generate(generator, join = true, outer = false,
+      qualifier = None, names.map(UnresolvedAttribute(_)), logicalPlan)
   }
 
   /////////////////////////////////////////////////////////////////////////////

sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala

Lines changed: 3 additions & 2 deletions
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._
  * output of each into a new stream of rows. This operation is similar to a `flatMap` in functional
  * programming with one important additional feature, which allows the input rows to be joined with
  * their output.
+ * @param generator the generator expression
  * @param join when true, each output row is implicitly joined with the input tuple that produced
  *             it.
  * @param outer when true, each input row will be output at least once, even if the output of the
@@ -39,8 +40,8 @@ case class Generate(
     generator: Generator,
     join: Boolean,
     outer: Boolean,
-    child: SparkPlan,
-    output: Seq[Attribute])
+    output: Seq[Attribute],
+    child: SparkPlan)
   extends UnaryNode {
 
   val boundGenerator = BindReferences.bindReference(generator, child.output)

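To make the flatMap analogy in the doc comment above concrete, here is a standalone Scala sketch (illustrative only, not the physical operator; null-padding for the outer case is an assumption about the intended semantics) of what join and outer control: with join, the input row is prepended to each generated row, and with outer, an input row whose generator output is empty is still emitted once.

// Sketch: an input row is (id, items); the "generator" explodes items into one row each.
object GenerateSemanticsSketch {
  def generate[A](rows: Seq[(String, Seq[A])], join: Boolean, outer: Boolean): Seq[Seq[Any]] = {
    rows.flatMap { case (id, items) =>
      val generated: Seq[Seq[Any]] = items.map(item => Seq(item))
      if (!join) generated                                               // generator output only
      else if (generated.nonEmpty) generated.map(Seq(id, items) ++ _)    // join input with each output row
      else if (outer) Seq(Seq(id, items, null))                          // outer: keep rows that generate nothing
      else Nil
    }
  }

  def main(args: Array[String]): Unit = {
    val rows = Seq("a" -> Seq(1, 2), "b" -> Seq.empty[Int])
    println(generate(rows, join = true, outer = false))
    // List(List(a, List(1, 2), 1), List(a, List(1, 2), 2))
    println(generate(rows, join = true, outer = true))
    // List(List(a, List(1, 2), 1), List(a, List(1, 2), 2), List(b, List(), null))
  }
}
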
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 2 additions & 2 deletions
@@ -304,9 +304,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.Except(planLater(left), planLater(right)) :: Nil
       case logical.Intersect(left, right) =>
         execution.Intersect(planLater(left), planLater(right)) :: Nil
-      case g @ logical.Generate(generator, join, outer, child, _, _, _) =>
+      case g @ logical.Generate(generator, join, outer, _, _, child) =>
         execution.Generate(
-          generator, join = join, outer = outer, planLater(child), g.output) :: Nil
+          generator, join = join, outer = outer, g.output, planLater(child)) :: Nil
       case logical.OneRowRelation =>
         execution.PhysicalRDD(Nil, singleRowRdd) :: Nil
       case logical.Repartition(expressions, child) =>

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala

Lines changed: 4 additions & 6 deletions
@@ -730,10 +730,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
           generator,
           join = true,
           outer = false,
-          withWhere,
           Some(alias.toLowerCase),
-          attributes,
-          Nil)
+          attributes.map(UnresolvedAttribute(_)),
+          withWhere)
       }.getOrElse(withWhere)
 
     // The projection of the query can either be a normal projection, an aggregation
@@ -841,10 +840,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
           generator,
           join = true,
           outer = isOuter.nonEmpty,
-          nodeToRelation(relationClause),
           Some(alias.toLowerCase),
-          attributes,
-          Nil)
+          attributes.map(UnresolvedAttribute(_)),
+          nodeToRelation(relationClause))
 
     /* All relations, possibly with aliases or sampling clauses. */
     case Token("TOK_TABREF", clauses) =>
