[SPARK-16853][SQL] fixes encoder error in DataSet typed select #14474

Closed · wants to merge 3 commits
4 changes: 3 additions & 1 deletion project/MimaExcludes.scala
@@ -38,7 +38,9 @@ object MimaExcludes {
   lazy val v21excludes = v20excludes ++ {
     Seq(
       // [SPARK-16199][SQL] Add a method to list the referenced columns in data source Filter
-      ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.Filter.references")
+      ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.Filter.references"),
+      // [SPARK-16853][SQL] Fixes encoder error in DataSet typed select
+      ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.Dataset.select")
     )
   }

@@ -169,6 +169,10 @@ object ExpressionEncoder {
       ClassTag(cls))
   }

+  // Tuple1
+  def tuple[T](e: ExpressionEncoder[T]): ExpressionEncoder[Tuple1[T]] =
+    tuple(Seq(e)).asInstanceOf[ExpressionEncoder[Tuple1[T]]]
+
   def tuple[T1, T2](
       e1: ExpressionEncoder[T1],
       e2: ExpressionEncoder[T2]): ExpressionEncoder[(T1, T2)] =
Expand Down
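The new arity-1 `tuple` overload above reuses the existing `Seq`-based `tuple` builder and narrows its erased result with a cast. A minimal plain-Scala sketch of that delegation pattern (no Spark here; `TupleBuilder`, `tuple`, and `tuple1` are hypothetical names for illustration only):

```scala
// Hypothetical plain-Scala mirror of the diff's delegation pattern:
// an arity-1 overload reuses a generic Seq-based builder, then casts
// the widened result back to the precise Tuple1 type.
object TupleBuilder {
  // generic builder: returns the erased supertype Product
  def tuple(elems: Seq[Any]): Product = elems match {
    case Seq(a)    => Tuple1(a)
    case Seq(a, b) => (a, b)
    case _         => throw new IllegalArgumentException("unsupported arity")
  }

  // arity-1 convenience overload, delegating and casting, as in the diff
  def tuple1[T](e: T): Tuple1[T] =
    tuple(Seq(e)).asInstanceOf[Tuple1[T]]
}

object TupleBuilderDemo {
  def main(args: Array[String]): Unit = {
    val t = TupleBuilder.tuple1("id")
    assert(t == Tuple1("id")) // the cast recovers the precise tuple type
    println(t)
  }
}
```

The cast is safe by construction because the generic builder always produces a tuple of the requested arity; the overload only exists to give callers a precisely typed result.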
20 changes: 11 additions & 9 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1061,15 +1061,17 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   @Experimental
-  def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = {
-    new Dataset[U1](
-      sparkSession,
-      Project(
-        c1.withInputType(
-          exprEnc.deserializer,
-          logicalPlan.output).named :: Nil,
-        logicalPlan),
-      implicitly[Encoder[U1]])
+  def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = {
+    implicit val encoder = c1.encoder
+    val project = Project(c1.withInputType(exprEnc.deserializer, logicalPlan.output).named :: Nil,
+      logicalPlan)
+
+    if (encoder.flat) {
+      new Dataset[U1](sparkSession, project, encoder)
+    } else {
+      // Flattens inner fields of U1
+      new Dataset[Tuple1[U1]](sparkSession, project, ExpressionEncoder.tuple(encoder)).map(_._1)
+    }
   }

/**
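The rewritten `select` keeps the direct single-`Dataset` path only when the encoder is flat; otherwise it projects into a `Tuple1[U1]` using `ExpressionEncoder.tuple(encoder)` and unwraps with `.map(_._1)`. A plain-Scala sketch of that wrap-then-unwrap round trip (no Spark here; `selectNonFlat` is a hypothetical stand-in for the projection step):

```scala
// Plain-Scala sketch of the non-flat branch: carry each case-class value
// as a Tuple1[U1] (a single-field "struct"), then flatten back with _._1.
case class ClassData(a: String, b: Int)

object SelectSketch {
  // hypothetical stand-in for "project via a Tuple1 carrier, then unwrap"
  def selectNonFlat[U1](rows: Seq[U1]): Seq[U1] =
    rows.map(Tuple1(_)).map(_._1)

  def main(args: Array[String]): Unit = {
    val data = Seq(ClassData("a", 1), ClassData("b", 2))
    assert(selectNonFlat(data) == data) // the round trip preserves values
    println(selectNonFlat(data))
  }
}
```

The round trip is value-preserving, which is why the extra `map` is safe; in Spark the `Tuple1` layer exists only to give the non-flat encoder a top-level struct to bind its inner fields to.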
11 changes: 11 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -184,6 +184,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     2, 3, 4)
   }

+  test("SPARK-16853: select, case class and tuple") {

[Review comment from a Contributor]
how about "typed select that returns case class or tuple"?

[Reply from the Author]
@cloud-fan This follows the existing test case naming; for example, the current name for select 2 is

    test("select 2, primitive and tuple") {

https://github.com/apache/spark/pull/14474/files#diff-3836bd1fc5ae9c8b8ac4bdb2b9944159L196

+    val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
+    checkDataset(
+      ds.select(expr("struct(_2, _2)").as[(Int, Int)]): Dataset[(Int, Int)],
+      (1, 1), (2, 2), (3, 3))
+
+    checkDataset(
+      ds.select(expr("named_struct('a', _1, 'b', _2)").as[ClassData]): Dataset[ClassData],
+      ClassData("a", 1), ClassData("b", 2), ClassData("c", 3))
+  }

   test("select 2") {
     val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
     checkDataset(
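The regression test above expects `struct(_2, _2)` to duplicate the `Int` field and `named_struct('a', _1, 'b', _2)` to rebuild a `ClassData` from each pair. A plain-Scala analogue of those expected mappings (no Spark here; an ordinary collection `map` stands in for the SQL struct expressions):

```scala
// Plain-Scala analogue of what the regression test expects.
case class ClassData(a: String, b: Int)

object ExpectedResults {
  def main(args: Array[String]): Unit = {
    val ds = Seq(("a", 1), ("b", 2), ("c", 3))

    // struct(_2, _2).as[(Int, Int)] duplicates the second field
    val tuples = ds.map { case (_, n) => (n, n) }
    assert(tuples == Seq((1, 1), (2, 2), (3, 3)))

    // named_struct('a', _1, 'b', _2).as[ClassData] rebuilds a case class
    val classes = ds.map { case (s, n) => ClassData(s, n) }
    assert(classes == Seq(ClassData("a", 1), ClassData("b", 2), ClassData("c", 3)))

    println(tuples)
    println(classes)
  }
}
```

Before this fix, the non-flat (`ClassData`) variant was the failing case, since the old `select` decoded the case class with a deserializer that did not account for the struct wrapper.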