Skip to content

Commit 914a374

Browse files
committed
fill with map.
1 parent 185c67e commit 914a374

File tree

3 files changed

+137
-27
lines changed

3 files changed

+137
-27
lines changed

python/pyspark/sql/dataframe.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -727,9 +727,13 @@ def dropna(self, how='any', thresh=None, subset=None):
727727
return DataFrame(self._jdf.na().drop(thresh, cols), self.sql_ctx)
728728

729729
def fillna(self, value, subset=None):
730-
"""Fill null values.
730+
"""Replace null values.
731731
732-
:param value: int, long, float, or string. Value to replace null values with.
732+
:param value: int, long, float, string, or dict.
733+
Value to replace null values with.
734+
If the value is a dict, then `subset` is ignored and `value` must be a mapping
735+
from column name (string) to replacement value. The replacement value must be
736+
an int, long, float, or string.
733737
:param subset: optional list of column names to consider.
734738
Columns specified in subset that do not have matching data type are ignored.
735739
For example, if `value` is a string, and subset contains a non-string column,
@@ -741,14 +745,24 @@ def fillna(self, value, subset=None):
741745
5 50 Bob
742746
50 50 Tom
743747
50 50 null
748+
749+
>>> df4.fillna({'age': 50, 'name': 'unknown'}).show()
750+
age height name
751+
10 80 Alice
752+
5 null Bob
753+
50 null Tom
754+
50 null unknown
744755
"""
745-
if not isinstance(value, (float, int, long, basestring)):
746-
raise ValueError("value should be a float, int, long, or string")
756+
if not isinstance(value, (float, int, long, basestring, dict)):
757+
raise ValueError("value should be a float, int, long, string, or dict")
747758

748759
if isinstance(value, (int, long)):
749760
value = float(value)
750761

751-
if subset is None:
762+
if isinstance(value, dict):
763+
value = MapConverter().convert(value, self.sql_ctx._sc._gateway._gateway_client)
764+
return DataFrame(self._jdf.na().fill(value), self.sql_ctx)
765+
elif subset is None:
752766
return DataFrame(self._jdf.na().fill(value), self.sql_ctx)
753767
else:
754768
if isinstance(subset, basestring):

sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala

Lines changed: 95 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717

1818
package org.apache.spark.sql
1919

20+
import java.{lang => jl}
21+
22+
import scala.collection.JavaConversions._
23+
2024
import org.apache.spark.sql.catalyst.expressions._
2125
import org.apache.spark.sql.functions._
2226
import org.apache.spark.sql.types._
@@ -48,8 +52,8 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
4852
def drop(threshold: Int, cols: Array[String]): DataFrame = drop(threshold, cols.toSeq)
4953

5054
/**
51-
* Returns a new [[DataFrame ]] that drops rows containing less than `threshold` non-null
52-
* values in the specified columns.
55+
* (Scala-specific) Returns a new [[DataFrame]] that drops rows containing less than
56+
* `threshold` non-null values in the specified columns.
5357
*/
5458
def drop(threshold: Int, cols: Seq[String]): DataFrame = {
5559
// Filtering condition -- drop rows that have less than `threshold` non-null,
@@ -75,22 +79,15 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
7579
def fill(value: Double, cols: Array[String]): DataFrame = fill(value, cols.toSeq)
7680

7781
/**
78-
* Returns a new [[DataFrame ]] that replaces null values in specified numeric columns.
79-
* If a specified column is not a numeric column, it is ignored.
82+
* (Scala-specific) Returns a new [[DataFrame]] that replaces null values in specified
83+
* numeric columns. If a specified column is not a numeric column, it is ignored.
8084
*/
8185
def fill(value: Double, cols: Seq[String]): DataFrame = {
8286
val columnEquals = df.sqlContext.analyzer.resolver
8387
val projections = df.schema.fields.map { f =>
8488
// Only fill if the column is part of the cols list.
85-
if (cols.exists(col => columnEquals(f.name, col))) {
86-
f.dataType match {
87-
case _: DoubleType =>
88-
coalesce(df.col(f.name), lit(value)).as(f.name)
89-
case typ: NumericType =>
90-
coalesce(df.col(f.name), lit(value).cast(typ)).as(f.name)
91-
case _ =>
92-
df.col(f.name)
93-
}
89+
if (f.dataType.isInstanceOf[NumericType] && cols.exists(col => columnEquals(f.name, col))) {
90+
fillDouble(f, value)
9491
} else {
9592
df.col(f.name)
9693
}
@@ -105,24 +102,100 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
105102
def fill(value: String, cols: Array[String]): DataFrame = fill(value, cols.toSeq)
106103

107104
/**
108-
* Returns a new [[DataFrame ]] that replaces null values in specified string columns.
109-
* If a specified column is not a string column, it is ignored.
105+
* (Scala-specific) Returns a new [[DataFrame]] that replaces null values in
106+
* specified string columns. If a specified column is not a string column, it is ignored.
110107
*/
111108
def fill(value: String, cols: Seq[String]): DataFrame = {
112109
val columnEquals = df.sqlContext.analyzer.resolver
113110
val projections = df.schema.fields.map { f =>
114111
// Only fill if the column is part of the cols list.
115-
if (cols.exists(col => columnEquals(f.name, col))) {
116-
f.dataType match {
117-
case _: StringType =>
118-
coalesce(df.col(f.name), lit(value)).as(f.name)
119-
case _ =>
120-
df.col(f.name)
121-
}
112+
if (f.dataType.isInstanceOf[StringType] && cols.exists(col => columnEquals(f.name, col))) {
113+
fillString(f, value)
122114
} else {
123115
df.col(f.name)
124116
}
125117
}
126118
df.select(projections : _*)
127119
}
120+
121+
/**
122+
* Returns a new [[DataFrame]] that replaces null values.
123+
*
124+
* The key of the map is the column name, and the value of the map is the replacement value.
125+
* The value must be one of the following types: `Integer`, `Long`, `Float`, `Double`, `String`.
126+
*
127+
* For example, the following replaces null values in column "A" with string "unknown", and
128+
* null values in column "B" with numeric value 1.0.
129+
* {{{
130+
* import com.google.common.collect.ImmutableMap;
131+
* df.na.fill(ImmutableMap.<String, Object>builder()
132+
* .put("A", "unknown")
133+
* .put("B", 1.0)
134+
* .build());
135+
* }}}
136+
*/
137+
def fill(valueMap: java.util.Map[String, Any]): DataFrame = fill(valueMap.toSeq)
138+
139+
/**
140+
* (Scala-specific) Returns a new [[DataFrame]] that replaces null values.
141+
*
142+
* The key of the map is the column name, and the value of the map is the replacement value.
143+
* The value must be one of the following types: `Int`, `Long`, `Float`, `Double`, `String`.
144+
*
145+
* For example, the following replaces null values in column "A" with string "unknown", and
146+
* null values in column "B" with numeric value 1.0.
147+
* {{{
148+
* df.na.fill(Map(
149+
* "A" -> "unknown",
150+
* "B" -> 1.0
151+
* ))
152+
* }}}
153+
*/
154+
def fill(valueMap: Map[String, Any]): DataFrame = fill(valueMap.toSeq)
155+
156+
private def fill(valueMap: Seq[(String, Any)]): DataFrame = {
157+
// Error handling
158+
valueMap.foreach { case (colName, replaceValue) =>
159+
// Check column name exists
160+
df.resolve(colName)
161+
162+
// Check data type
163+
replaceValue match {
164+
case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long | _: String =>
165+
// This is good
166+
case _ => throw new IllegalArgumentException(
167+
s"Does not support value type ${replaceValue.getClass.getName} ($replaceValue).")
168+
}
169+
}
170+
171+
val columnEquals = df.sqlContext.analyzer.resolver
172+
val pairs = valueMap.toSeq
173+
174+
val projections = df.schema.fields.map { f =>
175+
pairs.find { case (k, _) => columnEquals(k, f.name) }.map { case (_, v) =>
176+
v match {
177+
case v: jl.Float => fillDouble(f, v.toDouble)
178+
case v: jl.Double => fillDouble(f, v)
179+
case v: jl.Long => fillDouble(f, v.toDouble)
180+
case v: jl.Integer => fillDouble(f, v.toDouble)
181+
case v: String => fillString(f, v)
182+
}
183+
}.getOrElse(df.col(f.name))
184+
}
185+
df.select(projections : _*)
186+
}
187+
188+
/**
189+
* Returns a [[Column]] expression that replaces null values in `col` with `replacement`.
190+
*/
191+
private def fillDouble(col: StructField, replacement: Double): Column = {
192+
coalesce(df.col(col.name), lit(replacement).cast(col.dataType)).as(col.name)
193+
}
194+
195+
/**
196+
* Returns a [[Column]] expression that replaces null values in `col` with `replacement`.
197+
*/
198+
private def fillString(col: StructField, replacement: String): Column = {
199+
coalesce(df.col(col.name), lit(replacement).cast(col.dataType)).as(col.name)
200+
}
128201
}

sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.sql
1919

20+
import scala.collection.JavaConversions._
21+
2022
import org.apache.spark.sql.test.TestSQLContext.implicits._
2123

2224

@@ -111,4 +113,25 @@ class DataFrameNaFunctionsSuite extends QueryTest {
111113
Seq[(String, String)]((null, null)).toDF("col1", "col2").na.fill("test", "col1" :: Nil),
112114
Row("test", null))
113115
}
116+
117+
test("fill with map") {
118+
val df = Seq[(String, String, java.lang.Long, java.lang.Double)](
119+
(null, null, null, null)).toDF("a", "b", "c", "d")
120+
checkAnswer(
121+
df.na.fill(Map(
122+
"a" -> "test",
123+
"c" -> 1,
124+
"d" -> 2.2
125+
)),
126+
Row("test", null, 1, 2.2))
127+
128+
// Test Java version
129+
checkAnswer(
130+
df.na.fill(mapAsJavaMap(Map(
131+
"a" -> "test",
132+
"c" -> 1,
133+
"d" -> 2.2
134+
))),
135+
Row("test", null, 1, 2.2))
136+
}
114137
}

0 commit comments

Comments
 (0)