Skip to content

[SPARK-11926][SQL] unify GetStructField and GetInternalRowField #9909

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ object ScalaReflection extends ScalaReflection {

/** Returns the current path with a field at ordinal extracted. */
def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path
.map(p => GetInternalRowField(p, ordinal, dataType))
.map(p => GetStructField(p, ordinal))
.getOrElse(BoundReference(ordinal, dataType, false))

/** Returns the current path or `BoundReference`. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,12 +201,12 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu
if (attribute.isDefined) {
// This target resolved to an attribute in child. It must be a struct. Expand it.
attribute.get.dataType match {
case s: StructType => {
s.fields.map( f => {
val extract = GetStructField(attribute.get, f, s.getFieldIndex(f.name).get)
case s: StructType => s.zipWithIndex.map {
case (f, i) =>
val extract = GetStructField(attribute.get, i)
Alias(extract, target.get + "." + f.name)()
})
}

case _ => {
throw new AnalysisException("Can only star expand struct data types. Attribute: `" +
target.get + "`")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ object ExpressionEncoder {
case UnresolvedAttribute(nameParts) =>
assert(nameParts.length == 1)
UnresolvedExtractValue(input, Literal(nameParts.head))
case BoundReference(ordinal, dt, _) => GetInternalRowField(input, ordinal, dt)
case BoundReference(ordinal, dt, _) => GetStructField(input, ordinal)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ object RowEncoder {
If(
Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil),
Literal.create(null, externalDataTypeFor(f.dataType)),
constructorFor(GetInternalRowField(input, i, f.dataType)))
constructorFor(GetStructField(input, i)))
}
CreateExternalRow(convertedFields)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ abstract class Expression extends TreeNode[Expression] {
*/
def prettyString: String = {
transform {
case a: AttributeReference => PrettyAttribute(a.name)
case a: AttributeReference => PrettyAttribute(a.name, a.dataType)
case u: UnresolvedAttribute => PrettyAttribute(u.name)
}.toString
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ object ExtractValue {
case (StructType(fields), NonNullLiteral(v, StringType)) =>
val fieldName = v.toString
val ordinal = findField(fields, fieldName, resolver)
GetStructField(child, fields(ordinal).copy(name = fieldName), ordinal)
GetStructField(child, ordinal, Some(fieldName))

case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) =>
val fieldName = v.toString
Expand Down Expand Up @@ -97,18 +97,18 @@ object ExtractValue {
* Returns the value of fields in the Struct `child`.
*
* No need to do type checking since it is handled by [[ExtractValue]].
* TODO: Unify with [[GetInternalRowField]], remove the need to specify a [[StructField]].
*
* Note that the field name can be passed in directly so that `toString` preserves its case.
* For example, when getting field `yEAr` from `<year: int, month: int>`, we should pass in `yEAr`.
*/
case class GetStructField(child: Expression, field: StructField, ordinal: Int)
case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe comment that the name is only for debugging?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not for debugging; it's there to preserve the case of the requested field name. For example, if we try to get field yEAr from <year: int, month: int>, we should pass yEAr to GetStructField. This is also consistent with the original behavior.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, then document that :)

extends UnaryExpression {

override def dataType: DataType = child.dataType match {
case s: StructType => s(ordinal).dataType
// This is a hack to avoid breaking existing code until we remove the need for the struct field
case _ => field.dataType
}
private lazy val field = child.dataType.asInstanceOf[StructType](ordinal)

override def dataType: DataType = field.dataType
override def nullable: Boolean = child.nullable || field.nullable
override def toString: String = s"$child.${field.name}"
override def toString: String = s"$child.${name.getOrElse(field.name)}"

protected override def nullSafeEval(input: Any): Any =
input.asInstanceOf[InternalRow].get(ordinal, field.dataType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,8 @@ case class AttributeReference(
* A placeholder used when printing expressions without debugging information such as the
* expression id or the unresolved indicator.
*/
case class PrettyAttribute(name: String) extends Attribute with Unevaluable {
case class PrettyAttribute(name: String, dataType: DataType = NullType)
extends Attribute with Unevaluable {

override def toString: String = name

Expand All @@ -286,7 +287,6 @@ case class PrettyAttribute(name: String) extends Attribute with Unevaluable {
override def qualifiers: Seq[String] = throw new UnsupportedOperationException
override def exprId: ExprId = throw new UnsupportedOperationException
override def nullable: Boolean = throw new UnsupportedOperationException
override def dataType: DataType = NullType
}

object VirtualColumn {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -517,27 +517,6 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression {
}
}

/**
 * Expression that reads the field at position `ordinal` from the value produced by `child`.
 *
 * Only code-generated evaluation is available: `eval` always throws.
 *
 * @param child expression whose generated value is queried with `isNullAt(ordinal)`
 * @param ordinal position of the field to read
 * @param dataType type forwarded to `ctx.getValue` when reading the field
 */
case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataType)
extends UnaryExpression {

// Always reported as nullable, regardless of the child expression.
override def nullable: Boolean = true

// Interpreted evaluation is deliberately unsupported; evaluation must go through `genCode`.
override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported")

// Emits code that marks the result null when the field at `ordinal` is null, and otherwise
// assigns the value obtained from `ctx.getValue` to the result.
// NOTE(review): `nullSafeCodeGen` presumably skips this snippet when the child itself is
// null -- confirm against `UnaryExpression`.
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
nullSafeCodeGen(ctx, ev, eval => {
s"""
if ($eval.isNullAt($ordinal)) {
${ev.isNull} = true;
} else {
${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)};
}
"""
})
}
}

/**
* Serializes an input object using a generic serializer (Kryo or Java).
* @param kryo if true, use Kryo. Otherwise, use Java.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
def getStructField(expr: Expression, fieldName: String): GetStructField = {
expr.dataType match {
case StructType(fields) =>
val field = fields.find(_.name == fieldName).get
GetStructField(expr, field, fields.indexOf(field))
val index = fields.indexWhere(_.name == fieldName)
GetStructField(expr, index)
}
}

Expand Down