@@ -17,6 +17,10 @@
 
 package org.apache.spark.sql
 
+import java.{lang => jl}
+
+import scala.collection.JavaConversions._
+
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
@@ -48,8 +52,8 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
   def drop(threshold: Int, cols: Array[String]): DataFrame = drop(threshold, cols.toSeq)
 
   /**
-   * Returns a new [[DataFrame]] that drops rows containing less than `threshold` non-null
-   * values in the specified columns.
+   * (Scala-specific) Returns a new [[DataFrame]] that drops rows containing less than
+   * `threshold` non-null values in the specified columns.
    */
   def drop(threshold: Int, cols: Seq[String]): DataFrame = {
     // Filtering condition -- drop rows that have less than `threshold` non-null,
@@ -75,22 +79,15 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
   def fill(value: Double, cols: Array[String]): DataFrame = fill(value, cols.toSeq)
 
   /**
-   * Returns a new [[DataFrame]] that replaces null values in specified numeric columns.
-   * If a specified column is not a numeric column, it is ignored.
+   * (Scala-specific) Returns a new [[DataFrame]] that replaces null values in specified
+   * numeric columns. If a specified column is not a numeric column, it is ignored.
    */
   def fill(value: Double, cols: Seq[String]): DataFrame = {
     val columnEquals = df.sqlContext.analyzer.resolver
     val projections = df.schema.fields.map { f =>
       // Only fill if the column is part of the cols list.
-      if (cols.exists(col => columnEquals(f.name, col))) {
-        f.dataType match {
-          case _: DoubleType =>
-            coalesce(df.col(f.name), lit(value)).as(f.name)
-          case typ: NumericType =>
-            coalesce(df.col(f.name), lit(value).cast(typ)).as(f.name)
-          case _ =>
-            df.col(f.name)
-        }
+      if (f.dataType.isInstanceOf[NumericType] && cols.exists(col => columnEquals(f.name, col))) {
+        fillDouble(f, value)
       } else {
         df.col(f.name)
       }
@@ -105,24 +102,100 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
   def fill(value: String, cols: Array[String]): DataFrame = fill(value, cols.toSeq)
 
   /**
-   * Returns a new [[DataFrame]] that replaces null values in specified string columns.
-   * If a specified column is not a string column, it is ignored.
+   * (Scala-specific) Returns a new [[DataFrame]] that replaces null values in
+   * specified string columns. If a specified column is not a string column, it is ignored.
    */
   def fill(value: String, cols: Seq[String]): DataFrame = {
     val columnEquals = df.sqlContext.analyzer.resolver
     val projections = df.schema.fields.map { f =>
       // Only fill if the column is part of the cols list.
-      if (cols.exists(col => columnEquals(f.name, col))) {
-        f.dataType match {
-          case _: StringType =>
-            coalesce(df.col(f.name), lit(value)).as(f.name)
-          case _ =>
-            df.col(f.name)
-        }
+      if (f.dataType.isInstanceOf[StringType] && cols.exists(col => columnEquals(f.name, col))) {
+        fillString(f, value)
       } else {
         df.col(f.name)
       }
     }
     df.select(projections : _*)
   }
+
+  /**
+   * Returns a new [[DataFrame]] that replaces null values.
+   *
+   * The key of the map is the column name, and the value of the map is the replacement value.
+   * The value must be of the following type: `Integer`, `Long`, `Float`, `Double`, `String`.
+   *
+   * For example, the following replaces null values in column "A" with string "unknown", and
+   * null values in column "B" with numeric value 1.0.
+   * {{{
+   *   import com.google.common.collect.ImmutableMap;
+   *   df.na.fill(ImmutableMap.<String, Object>builder()
+   *     .put("A", "unknown")
+   *     .put("B", 1.0)
+   *     .build());
+   * }}}
+   */
+  def fill(valueMap: java.util.Map[String, Any]): DataFrame = fill(valueMap.toSeq)
+
+  /**
+   * (Scala-specific) Returns a new [[DataFrame]] that replaces null values.
+   *
+   * The key of the map is the column name, and the value of the map is the replacement value.
+   * The value must be of the following type: `Int`, `Long`, `Float`, `Double`, `String`.
+   *
+   * For example, the following replaces null values in column "A" with string "unknown", and
+   * null values in column "B" with numeric value 1.0.
+   * {{{
+   *   df.na.fill(Map(
+   *     "A" -> "unknown",
+   *     "B" -> 1.0
+   *   ))
+   * }}}
+   */
+  def fill(valueMap: Map[String, Any]): DataFrame = fill(valueMap.toSeq)
+
+  private def fill(valueMap: Seq[(String, Any)]): DataFrame = {
+    // Error handling
+    valueMap.foreach { case (colName, replaceValue) =>
+      // Check column name exists
+      df.resolve(colName)
+
+      // Check data type
+      replaceValue match {
+        case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long | _: String =>
+          // This is good
+        case _ => throw new IllegalArgumentException(
+          s"Does not support value type ${replaceValue.getClass.getName} ($replaceValue).")
+      }
+    }
+
+    val columnEquals = df.sqlContext.analyzer.resolver
+    val pairs = valueMap.toSeq
+
+    val projections = df.schema.fields.map { f =>
+      pairs.find { case (k, _) => columnEquals(k, f.name) }.map { case (_, v) =>
+        v match {
+          case v: jl.Float => fillDouble(f, v.toDouble)
+          case v: jl.Double => fillDouble(f, v)
+          case v: jl.Long => fillDouble(f, v.toDouble)
+          case v: jl.Integer => fillDouble(f, v.toDouble)
+          case v: String => fillString(f, v)
+        }
+      }.getOrElse(df.col(f.name))
+    }
+    df.select(projections : _*)
+  }
+
+  /**
+   * Returns a [[Column]] expression that replaces null value in `col` with `replacement`.
+   */
+  private def fillDouble(col: StructField, replacement: Double): Column = {
+    coalesce(df.col(col.name), lit(replacement).cast(col.dataType)).as(col.name)
+  }
+
+  /**
+   * Returns a [[Column]] expression that replaces null value in `col` with `replacement`.
+   */
+  private def fillString(col: StructField, replacement: String): Column = {
+    coalesce(df.col(col.name), lit(replacement).cast(col.dataType)).as(col.name)
+  }
 }
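
For reference, a minimal end-to-end sketch of how the Map-based `fill` overload added in this diff would be exercised. The column names ("A", "B"), the `Record` case class, the sample rows, and the `demo` helper are illustrative assumptions, not part of this patch; only `df.na.fill(Map(...))` comes from the change itself.

// Hypothetical usage sketch of the new na.fill(Map(...)) overload.
// Assumes an existing SQLContext; the data and column names are made up.
import org.apache.spark.sql.SQLContext

case class Record(A: String, B: Option[Double])

def demo(sqlContext: SQLContext): Unit = {
  // A string column "A" and a nullable double column "B", with nulls in the second row.
  val df = sqlContext.createDataFrame(Seq(
    Record("foo", Some(3.0)),
    Record(null, None)
  ))

  // Replace null strings in "A" with "unknown" and null doubles in "B" with 1.0,
  // using the Scala-specific Map overload introduced above.
  val filled = df.na.fill(Map(
    "A" -> "unknown",
    "B" -> 1.0
  ))

  filled.show()
  // Expected, per the new semantics: row 2 becomes ("unknown", 1.0); row 1 is untouched.
}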