Skip to content

Commit efa824a

Browse files
HyukjinKwon and Jackey Lee
authored and committed
[SPARK-24709][SQL][FOLLOW-UP] Make schema_of_json's input json as literal only
## What changes were proposed in this pull request? The main purpose of `schema_of_json` is the usage of combination with `from_json` (to make up the leak of schema inference) which takes its schema only as literal; however, currently `schema_of_json` allows JSON input as non-literal expressions (e.g., column). This was mistakenly allowed - we don't have to take usages other than the main purpose into account for now. This PR makes a followup to only allow literals for `schema_of_json`'s JSON input. We can allow non-literal expressions later when it's needed or there is a use case for it. ## How was this patch tested? Unit tests were added. Closes apache#22775 from HyukjinKwon/SPARK-25447-followup. Lead-authored-by: hyukjinkwon <gurwls223@apache.org> Co-authored-by: Hyukjin Kwon <gurwls223@apache.org> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 5392fe5 commit efa824a

File tree

6 files changed

+87
-24
lines changed

6 files changed

+87
-24
lines changed

python/pyspark/sql/functions.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2335,30 +2335,32 @@ def to_json(col, options={}):
23352335

23362336
@ignore_unicode_prefix
23372337
@since(2.4)
2338-
def schema_of_json(col, options={}):
2338+
def schema_of_json(json, options={}):
23392339
"""
2340-
Parses a column containing a JSON string and infers its schema in DDL format.
2340+
Parses a JSON string and infers its schema in DDL format.
23412341
2342-
:param col: string column in json format
2342+
:param json: a JSON string or a string literal containing a JSON string.
23432343
:param options: options to control parsing. accepts the same options as the JSON datasource
23442344
23452345
.. versionchanged:: 3.0
23462346
It accepts `options` parameter to control schema inferring.
23472347
2348-
>>> from pyspark.sql.types import *
2349-
>>> data = [(1, '{"a": 1}')]
2350-
>>> df = spark.createDataFrame(data, ("key", "value"))
2351-
>>> df.select(schema_of_json(df.value).alias("json")).collect()
2352-
[Row(json=u'struct<a:bigint>')]
2348+
>>> df = spark.range(1)
23532349
>>> df.select(schema_of_json(lit('{"a": 0}')).alias("json")).collect()
23542350
[Row(json=u'struct<a:bigint>')]
2355-
>>> schema = schema_of_json(lit('{a: 1}'), {'allowUnquotedFieldNames':'true'})
2351+
>>> schema = schema_of_json('{a: 1}', {'allowUnquotedFieldNames':'true'})
23562352
>>> df.select(schema.alias("json")).collect()
23572353
[Row(json=u'struct<a:bigint>')]
23582354
"""
2355+
if isinstance(json, basestring):
2356+
col = _create_column_from_literal(json)
2357+
elif isinstance(json, Column):
2358+
col = _to_java_column(json)
2359+
else:
2360+
raise TypeError("schema argument should be a column or string")
23592361

23602362
sc = SparkContext._active_spark_context
2361-
jc = sc._jvm.functions.schema_of_json(_to_java_column(col), options)
2363+
jc = sc._jvm.functions.schema_of_json(col, options)
23622364
return Column(jc)
23632365

23642366

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -742,14 +742,18 @@ case class StructsToJson(
742742
case class SchemaOfJson(
743743
child: Expression,
744744
options: Map[String, String])
745-
extends UnaryExpression with String2StringExpression with CodegenFallback {
745+
extends UnaryExpression with CodegenFallback {
746746

747747
def this(child: Expression) = this(child, Map.empty[String, String])
748748

749749
def this(child: Expression, options: Expression) = this(
750750
child = child,
751751
options = ExprUtils.convertToMapData(options))
752752

753+
override def dataType: DataType = StringType
754+
755+
override def nullable: Boolean = false
756+
753757
@transient
754758
private lazy val jsonOptions = new JSONOptions(options, "UTC")
755759

@@ -760,8 +764,17 @@ case class SchemaOfJson(
760764
factory
761765
}
762766

763-
override def convert(v: UTF8String): UTF8String = {
764-
val dt = Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, v)) { parser =>
767+
@transient
768+
private lazy val json = child.eval().asInstanceOf[UTF8String]
769+
770+
override def checkInputDataTypes(): TypeCheckResult = child match {
771+
case Literal(s, StringType) if s != null => super.checkInputDataTypes()
772+
case _ => TypeCheckResult.TypeCheckFailure(
773+
s"The input json should be a string literal and not null; however, got ${child.sql}.")
774+
}
775+
776+
override def eval(v: InternalRow): Any = {
777+
val dt = Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, json)) { parser =>
765778
parser.nextToken()
766779
inferField(parser, jsonOptions)
767780
}
@@ -776,7 +789,7 @@ object JsonExprUtils {
776789
def evalSchemaExpr(exp: Expression): DataType = exp match {
777790
case Literal(s, StringType) => DataType.fromDDL(s.toString)
778791
case e @ SchemaOfJson(_: Literal, _) =>
779-
val ddlSchema = e.eval().asInstanceOf[UTF8String]
792+
val ddlSchema = e.eval(EmptyRow).asInstanceOf[UTF8String]
780793
DataType.fromDDL(ddlSchema.toString)
781794
case e => throw new AnalysisException(
782795
"Schema should be specified in DDL format as a string literal" +

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3626,28 +3626,38 @@ object functions {
36263626
}
36273627

36283628
/**
3629-
* Parses a column containing a JSON string and infers its schema.
3629+
* Parses a JSON string and infers its schema in DDL format.
36303630
*
3631-
* @param e a string column containing JSON data.
3631+
* @param json a JSON string.
36323632
*
36333633
* @group collection_funcs
36343634
* @since 2.4.0
36353635
*/
3636-
def schema_of_json(e: Column): Column = withExpr(new SchemaOfJson(e.expr))
3636+
def schema_of_json(json: String): Column = schema_of_json(lit(json))
36373637

36383638
/**
3639-
* Parses a column containing a JSON string and infers its schema using options.
3639+
* Parses a JSON string and infers its schema in DDL format.
36403640
*
3641-
* @param e a string column containing JSON data.
3641+
* @param json a string literal containing a JSON string.
3642+
*
3643+
* @group collection_funcs
3644+
* @since 2.4.0
3645+
*/
3646+
def schema_of_json(json: Column): Column = withExpr(new SchemaOfJson(json.expr))
3647+
3648+
/**
3649+
* Parses a JSON string and infers its schema in DDL format using options.
3650+
*
3651+
* @param json a string column containing JSON data.
36423652
* @param options options to control how the json is parsed. accepts the same options and the
36433653
* json data source. See [[DataFrameReader#json]].
36443654
* @return a column with string literal containing schema in DDL format.
36453655
*
36463656
* @group collection_funcs
36473657
* @since 3.0.0
36483658
*/
3649-
def schema_of_json(e: Column, options: java.util.Map[String, String]): Column = {
3650-
withExpr(SchemaOfJson(e.expr, options.asScala.toMap))
3659+
def schema_of_json(json: Column, options: java.util.Map[String, String]): Column = {
3660+
withExpr(SchemaOfJson(json.expr, options.asScala.toMap))
36513661
}
36523662

36533663
/**

sql/core/src/test/resources/sql-tests/inputs/json-functions.sql

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,4 +55,8 @@ select to_json(array(array(1, 2, 3), array(4)));
5555
-- infer schema of json literal using options
5656
select schema_of_json('{"c1":1}', map('primitivesAsString', 'true'));
5757
select schema_of_json('{"c1":01, "c2":0.1}', map('allowNumericLeadingZeros', 'true', 'prefersDecimal', 'true'));
58-
58+
select schema_of_json(null);
59+
CREATE TEMPORARY VIEW jsonTable(jsonField, a) AS SELECT * FROM VALUES ('{"a": 1, "b": 2}', 'a');
60+
SELECT schema_of_json(jsonField) FROM jsonTable;
61+
-- Clean up
62+
DROP VIEW IF EXISTS jsonTable;

sql/core/src/test/resources/sql-tests/results/json-functions.sql.out

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
-- Automatically generated by SQLQueryTestSuite
2-
-- Number of queries: 38
2+
-- Number of queries: 42
33

44

55
-- !query 0
@@ -318,3 +318,37 @@ select schema_of_json('{"c1":01, "c2":0.1}', map('allowNumericLeadingZeros', 'tr
318318
struct<schema_of_json({"c1":01, "c2":0.1}):string>
319319
-- !query 37 output
320320
struct<c1:bigint,c2:decimal(1,1)>
321+
322+
323+
-- !query 38
324+
select schema_of_json(null)
325+
-- !query 38 schema
326+
struct<>
327+
-- !query 38 output
328+
org.apache.spark.sql.AnalysisException
329+
cannot resolve 'schema_of_json(NULL)' due to data type mismatch: The input json should be a string literal and not null; however, got NULL.; line 1 pos 7
330+
331+
332+
-- !query 39
333+
CREATE TEMPORARY VIEW jsonTable(jsonField, a) AS SELECT * FROM VALUES ('{"a": 1, "b": 2}', 'a')
334+
-- !query 39 schema
335+
struct<>
336+
-- !query 39 output
337+
338+
339+
340+
-- !query 40
341+
SELECT schema_of_json(jsonField) FROM jsonTable
342+
-- !query 40 schema
343+
struct<>
344+
-- !query 40 output
345+
org.apache.spark.sql.AnalysisException
346+
cannot resolve 'schema_of_json(jsontable.`jsonField`)' due to data type mismatch: The input json should be a string literal and not null; however, got jsontable.`jsonField`.; line 1 pos 7
347+
348+
349+
-- !query 41
350+
DROP VIEW IF EXISTS jsonTable
351+
-- !query 41 schema
352+
struct<>
353+
-- !query 41 output
354+

sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
395395

396396
test("SPARK-24709: infers schemas of json strings and pass them to from_json") {
397397
val in = Seq("""{"a": [1, 2, 3]}""").toDS()
398-
val out = in.select(from_json('value, schema_of_json(lit("""{"a": [1]}"""))) as "parsed")
398+
val out = in.select(from_json('value, schema_of_json("""{"a": [1]}""")) as "parsed")
399399
val expected = StructType(StructField(
400400
"parsed",
401401
StructType(StructField(

0 commit comments

Comments (0)