Skip to content

Commit 19530da

Browse files
cloud-fanmarmbrus
authored andcommitted
[SPARK-11926][SQL] unify GetStructField and GetInternalRowField
Author: Wenchen Fan <wenchen@databricks.com> Closes #9909 from cloud-fan/get-struct.
1 parent 52bc25c commit 19530da

File tree

9 files changed

+21
-42
lines changed

9 files changed

+21
-42
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ object ScalaReflection extends ScalaReflection {
130130

131131
/** Returns the current path with a field at ordinal extracted. */
132132
def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path
133-
.map(p => GetInternalRowField(p, ordinal, dataType))
133+
.map(p => GetStructField(p, ordinal))
134134
.getOrElse(BoundReference(ordinal, dataType, false))
135135

136136
/** Returns the current path or `BoundReference`. */

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,12 +201,12 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu
201201
if (attribute.isDefined) {
202202
// This target resolved to an attribute in child. It must be a struct. Expand it.
203203
attribute.get.dataType match {
204-
case s: StructType => {
205-
s.fields.map( f => {
206-
val extract = GetStructField(attribute.get, f, s.getFieldIndex(f.name).get)
204+
case s: StructType => s.zipWithIndex.map {
205+
case (f, i) =>
206+
val extract = GetStructField(attribute.get, i)
207207
Alias(extract, target.get + "." + f.name)()
208-
})
209208
}
209+
210210
case _ => {
211211
throw new AnalysisException("Can only star expand struct data types. Attribute: `" +
212212
target.get + "`")

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ object ExpressionEncoder {
111111
case UnresolvedAttribute(nameParts) =>
112112
assert(nameParts.length == 1)
113113
UnresolvedExtractValue(input, Literal(nameParts.head))
114-
case BoundReference(ordinal, dt, _) => GetInternalRowField(input, ordinal, dt)
114+
case BoundReference(ordinal, dt, _) => GetStructField(input, ordinal)
115115
}
116116
}
117117
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ object RowEncoder {
220220
If(
221221
Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil),
222222
Literal.create(null, externalDataTypeFor(f.dataType)),
223-
constructorFor(GetInternalRowField(input, i, f.dataType)))
223+
constructorFor(GetStructField(input, i)))
224224
}
225225
CreateExternalRow(convertedFields)
226226
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ abstract class Expression extends TreeNode[Expression] {
206206
*/
207207
def prettyString: String = {
208208
transform {
209-
case a: AttributeReference => PrettyAttribute(a.name)
209+
case a: AttributeReference => PrettyAttribute(a.name, a.dataType)
210210
case u: UnresolvedAttribute => PrettyAttribute(u.name)
211211
}.toString
212212
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ object ExtractValue {
5151
case (StructType(fields), NonNullLiteral(v, StringType)) =>
5252
val fieldName = v.toString
5353
val ordinal = findField(fields, fieldName, resolver)
54-
GetStructField(child, fields(ordinal).copy(name = fieldName), ordinal)
54+
GetStructField(child, ordinal, Some(fieldName))
5555

5656
case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) =>
5757
val fieldName = v.toString
@@ -97,18 +97,18 @@ object ExtractValue {
9797
* Returns the value of fields in the Struct `child`.
9898
*
9999
* No need to do type checking since it is handled by [[ExtractValue]].
100-
* TODO: Unify with [[GetInternalRowField]], remove the need to specify a [[StructField]].
100+
*
101+
* Note that we can pass in the field name directly to keep case preserving in `toString`.
102+
* For example, when get field `yEAr` from `<year: int, month: int>`, we should pass in `yEAr`.
101103
*/
102-
case class GetStructField(child: Expression, field: StructField, ordinal: Int)
104+
case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None)
103105
extends UnaryExpression {
104106

105-
override def dataType: DataType = child.dataType match {
106-
case s: StructType => s(ordinal).dataType
107-
// This is a hack to avoid breaking existing code until we remove the need for the struct field
108-
case _ => field.dataType
109-
}
107+
private lazy val field = child.dataType.asInstanceOf[StructType](ordinal)
108+
109+
override def dataType: DataType = field.dataType
110110
override def nullable: Boolean = child.nullable || field.nullable
111-
override def toString: String = s"$child.${field.name}"
111+
override def toString: String = s"$child.${name.getOrElse(field.name)}"
112112

113113
protected override def nullSafeEval(input: Any): Any =
114114
input.asInstanceOf[InternalRow].get(ordinal, field.dataType)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,8 @@ case class AttributeReference(
273273
* A place holder used when printing expressions without debugging information such as the
274274
* expression id or the unresolved indicator.
275275
*/
276-
case class PrettyAttribute(name: String) extends Attribute with Unevaluable {
276+
case class PrettyAttribute(name: String, dataType: DataType = NullType)
277+
extends Attribute with Unevaluable {
277278

278279
override def toString: String = name
279280

@@ -286,7 +287,6 @@ case class PrettyAttribute(name: String) extends Attribute with Unevaluable {
286287
override def qualifiers: Seq[String] = throw new UnsupportedOperationException
287288
override def exprId: ExprId = throw new UnsupportedOperationException
288289
override def nullable: Boolean = throw new UnsupportedOperationException
289-
override def dataType: DataType = NullType
290290
}
291291

292292
object VirtualColumn {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -517,27 +517,6 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression {
517517
}
518518
}
519519

520-
case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataType)
521-
extends UnaryExpression {
522-
523-
override def nullable: Boolean = true
524-
525-
override def eval(input: InternalRow): Any =
526-
throw new UnsupportedOperationException("Only code-generated evaluation is supported")
527-
528-
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
529-
nullSafeCodeGen(ctx, ev, eval => {
530-
s"""
531-
if ($eval.isNullAt($ordinal)) {
532-
${ev.isNull} = true;
533-
} else {
534-
${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)};
535-
}
536-
"""
537-
})
538-
}
539-
}
540-
541520
/**
542521
* Serializes an input object using a generic serializer (Kryo or Java).
543522
* @param kryo if true, use Kryo. Otherwise, use Java.

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
7979
def getStructField(expr: Expression, fieldName: String): GetStructField = {
8080
expr.dataType match {
8181
case StructType(fields) =>
82-
val field = fields.find(_.name == fieldName).get
83-
GetStructField(expr, field, fields.indexOf(field))
82+
val index = fields.indexWhere(_.name == fieldName)
83+
GetStructField(expr, index)
8484
}
8585
}
8686

0 commit comments

Comments
 (0)