[SPARK-49966][SQL] Codegen Support for JsonToStructs(from_json) #48466

Closed · wants to merge 2 commits
JsonToStructsEvaluator.scala (new file)
@@ -0,0 +1,81 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql.catalyst.expressions.json

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.ExprUtils
import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils
import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
import org.apache.spark.sql.catalyst.util.{FailFastMode, FailureSafeParser, PermissiveMode}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType, VariantType}
import org.apache.spark.unsafe.types.UTF8String

class JsonToStructsEvaluator(
    options: Map[String, String],
    nullableSchema: DataType,
    nameOfCorruptRecord: String,
    timeZoneId: Option[String],
    variantAllowDuplicateKeys: Boolean) extends Serializable {

  // This converts parsed rows to the desired output by the given schema.
  @transient
  private lazy val converter = nullableSchema match {
    case _: StructType =>
      (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next() else null
    case _: ArrayType =>
      (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next().getArray(0) else null
    case _: MapType =>
      (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next().getMap(0) else null
  }

  @transient
  private lazy val parser = {
    val parsedOptions = new JSONOptions(options, timeZoneId.get, nameOfCorruptRecord)
    val mode = parsedOptions.parseMode
    if (mode != PermissiveMode && mode != FailFastMode) {
      throw QueryCompilationErrors.parseModeUnsupportedError("from_json", mode)
    }
    val (parserSchema, actualSchema) = nullableSchema match {
      case s: StructType =>
        ExprUtils.verifyColumnNameOfCorruptRecord(s, parsedOptions.columnNameOfCorruptRecord)
        (s, StructType(s.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)))
      case other =>
        (StructType(Array(StructField("value", other))), other)
    }

    val rawParser = new JacksonParser(actualSchema, parsedOptions, allowArrayAsStructs = false)
    val createParser = CreateJacksonParser.utf8String _

    new FailureSafeParser[UTF8String](
      input => rawParser.parse(input, createParser, identity[UTF8String]),
      mode,
      parserSchema,
      parsedOptions.columnNameOfCorruptRecord)
  }

  final def evaluate(json: UTF8String): Any = {
    if (json == null) return null
    nullableSchema match {
      case _: VariantType =>
        VariantExpressionEvalUtils.parseJson(json,
          allowDuplicateKeys = variantAllowDuplicateKeys)
      case _ =>
        converter(parser.parse(json))
    }
  }
}
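For context, here is a minimal sketch of how the new evaluator could be exercised on its own. The schema, options, and input below are illustrative assumptions, not part of this change:

import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String

// Illustrative setup: a single-int-field schema, the default corrupt-record
// column name, and a UTC session time zone.
val evaluator = new JsonToStructsEvaluator(
  options = Map.empty,
  nullableSchema = StructType(StructField("a", IntegerType) :: Nil),
  nameOfCorruptRecord = "_corrupt_record",
  timeZoneId = Some("UTC"),
  variantAllowDuplicateKeys = false)

// Yields an InternalRow with field a = 1. A null input short-circuits to
// null inside evaluate(); malformed input is handled per the parse mode.
val result = evaluator.evaluate(UTF8String.fromString("""{"a": 1}"""))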
jsonExpressions.scala
@@ -30,9 +30,8 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.catalyst.expressions.json.JsonExpressionUtils
import org.apache.spark.sql.catalyst.expressions.json.{JsonExpressionUtils, JsonToStructsEvaluator}
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils
import org.apache.spark.sql.catalyst.json._
import org.apache.spark.sql.catalyst.trees.TreePattern.{JSON_TO_STRUCT, TreePattern}
import org.apache.spark.sql.catalyst.util._
@@ -639,15 +638,14 @@ case class JsonToStructs(
    variantAllowDuplicateKeys: Boolean = SQLConf.get.getConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS))
  extends UnaryExpression
  with TimeZoneAwareExpression
  with CodegenFallback
  with ExpectsInputTypes
  with NullIntolerant
  with QueryErrorsBase {

  // The JSON input data might be missing certain fields. We force the nullability
  // of the user-provided schema to avoid data corruptions. In particular, the parquet-mr encoder
  // can generate incorrect files if values are missing in columns declared as non-nullable.
  val nullableSchema = schema.asNullable
  private val nullableSchema: DataType = schema.asNullable

  override def nullable: Boolean = true

@@ -680,53 +678,36 @@
      messageParameters = Map("schema" -> toSQLType(nullableSchema)))
  }

  // This converts parsed rows to the desired output by the given schema.
  @transient
  lazy val converter = nullableSchema match {
    case _: StructType =>
      (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next() else null
    case _: ArrayType =>
      (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next().getArray(0) else null
    case _: MapType =>
      (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next().getMap(0) else null
  }

  val nameOfCorruptRecord = SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD)
  @transient lazy val parser = {
    val parsedOptions = new JSONOptions(options, timeZoneId.get, nameOfCorruptRecord)
    val mode = parsedOptions.parseMode
    if (mode != PermissiveMode && mode != FailFastMode) {
      throw QueryCompilationErrors.parseModeUnsupportedError("from_json", mode)
    }
    val (parserSchema, actualSchema) = nullableSchema match {
      case s: StructType =>
        ExprUtils.verifyColumnNameOfCorruptRecord(s, parsedOptions.columnNameOfCorruptRecord)
        (s, StructType(s.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)))
      case other =>
        (StructType(Array(StructField("value", other))), other)
    }

    val rawParser = new JacksonParser(actualSchema, parsedOptions, allowArrayAsStructs = false)
    val createParser = CreateJacksonParser.utf8String _

    new FailureSafeParser[UTF8String](
      input => rawParser.parse(input, createParser, identity[UTF8String]),
      mode,
      parserSchema,
      parsedOptions.columnNameOfCorruptRecord)
  }

  override def dataType: DataType = nullableSchema

  override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
    copy(timeZoneId = Option(timeZoneId))

  override def nullSafeEval(json: Any): Any = nullableSchema match {
    case _: VariantType =>
      VariantExpressionEvalUtils.parseJson(json.asInstanceOf[UTF8String],
        allowDuplicateKeys = variantAllowDuplicateKeys)
    case _ =>
      converter(parser.parse(json.asInstanceOf[UTF8String]))
  @transient
  private val nameOfCorruptRecord = SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD)

  @transient
  private lazy val evaluator = new JsonToStructsEvaluator(
    options, nullableSchema, nameOfCorruptRecord, timeZoneId, variantAllowDuplicateKeys)

  override def nullSafeEval(json: Any): Any = evaluator.evaluate(json.asInstanceOf[UTF8String])

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
Contributor:
is it possible to use Invoke with Literal(new JsonToStructsEvaluator(...), ObjectType(...)) to rewrite the expression?

Contributor Author:
Let me investigate it.
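For reference, a rough sketch of what that suggestion could look like (this sketches the reviewer's idea using Catalyst's Invoke and ObjectType; it is not the approach this commit takes):

import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.types.ObjectType

// Wrap the evaluator instance in a Literal of ObjectType, then call its
// evaluate method through Invoke, so the existing object-expression codegen
// produces the call instead of a hand-written doGenCode.
val evaluatorLiteral = Literal.create(
  new JsonToStructsEvaluator(
    options, nullableSchema, nameOfCorruptRecord, timeZoneId, variantAllowDuplicateKeys),
  ObjectType(classOf[JsonToStructsEvaluator]))

val invokeExpr = Invoke(
  evaluatorLiteral,
  "evaluate",
  dataType,
  Seq(child),
  Seq(child.dataType))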

    val refEvaluator = ctx.addReferenceObj("evaluator", evaluator)
    val eval = child.genCode(ctx)
    val resultType = CodeGenerator.boxedType(dataType)
    val resultTerm = ctx.freshName("result")
    ev.copy(code =
      code"""
         |${eval.code}
         |$resultType $resultTerm = ($resultType) $refEvaluator.evaluate(
         |  ${eval.isNull} ? null : ${eval.value});
Member:
Why do you need this check? seems like evaluate() does this.

Contributor Author:
Yes, it's redundant. I have already removed it.

Contributor Author (@panbingkun, Oct 16, 2024):
The generated code for an example query looks roughly like this:

Before:

/* 031 */       boolean localtablescan_isNull_0 = localtablescan_row_0.isNullAt(0);
/* 032 */       UTF8String localtablescan_value_0 = localtablescan_isNull_0 ?
/* 033 */       null : (localtablescan_row_0.getUTF8String(0));
/* 034 */       InternalRow project_result_0 = (InternalRow) ((org.apache.spark.sql.catalyst.expressions.json.JsonToStructsEvaluator) references[1] /* evaluator */).evaluate(
/* 035 */         localtablescan_isNull_0 ? null : localtablescan_value_0);
/* 036 */       boolean project_isNull_0 = project_result_0 == null;
/* 037 */       InternalRow project_value_0 = null;
/* 038 */       if (!project_isNull_0) {
/* 039 */         project_value_0 = project_result_0;
/* 040 */       }
/* 041 */       project_mutableStateArray_0[0].reset();
/* 042 */
/* 043 */       project_mutableStateArray_0[0].zeroOutNullBytes();

After:

/* 031 */       boolean localtablescan_isNull_0 = localtablescan_row_0.isNullAt(0);
/* 032 */       UTF8String localtablescan_value_0 = localtablescan_isNull_0 ?
/* 033 */       null : (localtablescan_row_0.getUTF8String(0));
/* 034 */       InternalRow project_result_0 = (InternalRow) ((org.apache.spark.sql.catalyst.expressions.json.JsonToStructsEvaluator) references[1] /* evaluator */).evaluate(localtablescan_value_0);
/* 035 */       boolean project_isNull_0 = project_result_0 == null;
/* 036 */       InternalRow project_value_0 = null;
/* 037 */       if (!project_isNull_0) {
/* 038 */         project_value_0 = project_result_0;
/* 039 */       }
/* 040 */       project_mutableStateArray_0[0].reset();
/* 041 */
/* 042 */       project_mutableStateArray_0[0].zeroOutNullBytes();

The extra null check in the "Before" version is obviously unnecessary (localtablescan_value_0 is already null when localtablescan_isNull_0 is true), so it has been removed.

         |boolean ${ev.isNull} = $resultTerm == null;
         |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
         |if (!${ev.isNull}) {
         |  ${ev.value} = $resultTerm;
         |}
         |""".stripMargin)
  }

  override def inputTypes: Seq[AbstractDataType] = StringTypeWithCaseAccentSensitivity :: Nil