Skip to content

Commit 9564ebb

Browse files
author
vidmantas zemleris
committed
[SPARK-6994][SQL] Add fieldIndex to schema (StructType)
1 parent 327ebf0 commit 9564ebb

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,6 +1025,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
10251025

10261026
private lazy val fieldNamesSet: Set[String] = fieldNames.toSet
10271027
private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap
1028+
private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap
10281029

10291030
/**
10301031
* Extracts a [[StructField]] of the given name. If the [[StructType]] object does not
@@ -1049,6 +1050,14 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
10491050
StructType(fields.filter(f => names.contains(f.name)))
10501051
}
10511052

1053+
/**
1054+
* Returns index of a given field
1055+
*/
1056+
def fieldIndex(name: String): Int = {
1057+
nameToIndex.getOrElse(name,
1058+
throw new IllegalArgumentException(s"""Field "$name" does not exist."""))
1059+
}
1060+
10521061
protected[sql] def toAttributes: Seq[AttributeReference] =
10531062
map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
10541063

sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,19 @@ class DataTypeSuite extends FunSuite {
5656
}
5757
}
5858

59+
test("extract field index from a StructType") {
60+
val struct = StructType(
61+
StructField("a", LongType) ::
62+
StructField("b", FloatType) :: Nil)
63+
64+
assert(struct.fieldIndex("a") === 0)
65+
assert(struct.fieldIndex("b") === 1)
66+
67+
intercept[IllegalArgumentException] {
68+
struct.fieldIndex("non_existent")
69+
}
70+
}
71+
5972
def checkDataTypeJsonRepr(dataType: DataType): Unit = {
6073
test(s"JSON - $dataType") {
6174
assert(DataType.fromJson(dataType.json) === dataType)

0 commit comments

Comments
 (0)