Skip to content

Commit c7e308b

Browse files
author
Punya Biswal
committed
Support java iterable types in POJOs
1 parent 5e00685 commit c7e308b

File tree

3 files changed

+39
-1
lines changed

3 files changed

+39
-1
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.sql.catalyst
1919

20+
import java.lang.{Iterable => JavaIterable}
2021
import java.util.{Map => JavaMap}
2122

2223
import scala.collection.mutable.HashMap
@@ -49,6 +50,16 @@ object CatalystTypeConverters {
4950
case (s: Seq[_], arrayType: ArrayType) =>
5051
s.map(convertToCatalyst(_, arrayType.elementType))
5152

53+
case (jit: JavaIterable[_], arrayType: ArrayType) => {
54+
val iter = jit.iterator
55+
var listOfItems: List[Any] = List()
56+
while (iter.hasNext) {
57+
val item = iter.next()
58+
listOfItems :+= convertToCatalyst(item, arrayType.elementType)
59+
}
60+
listOfItems
61+
}
62+
5263
case (s: Array[_], arrayType: ArrayType) =>
5364
s.toSeq.map(convertToCatalyst(_, arrayType.elementType))
5465

@@ -124,6 +135,15 @@ object CatalystTypeConverters {
124135
extractOption(item) match {
125136
case a: Array[_] => a.toSeq.map(elementConverter)
126137
case s: Seq[_] => s.map(elementConverter)
138+
case i: JavaIterable[_] => {
139+
val iter = i.iterator
140+
var convertedIterable: List[Any] = List()
141+
while (iter.hasNext) {
142+
val item = iter.next()
143+
convertedIterable :+= elementConverter(item)
144+
}
145+
convertedIterable
146+
}
127147
case null => null
128148
}
129149
}

sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@ private [sql] object JavaTypeInference {
7575
val (dataType, nullable) = inferDataType(typeToken.getComponentType)
7676
(ArrayType(dataType, nullable), true)
7777

78+
case _ if iterableType.isAssignableFrom(typeToken) =>
79+
val (dataType, nullable) = inferDataType(elementType(typeToken))
80+
(ArrayType(dataType, nullable), true)
81+
7882
case _ if mapType.isAssignableFrom(typeToken) =>
7983
val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]]
8084
val mapSupertype = typeToken2.getSupertype(classOf[JMap[_, _]])

sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import java.io.Serializable;
2121
import java.util.Arrays;
22+
import java.util.List;
2223
import java.util.Map;
2324

2425
import com.google.common.collect.ImmutableMap;
@@ -112,6 +113,7 @@ public static class Bean implements Serializable {
112113
private double a = 0.0;
113114
private Integer[] b = new Integer[]{0, 1};
114115
private Map<String, int[]> c = ImmutableMap.of("hello", new int[] { 1, 2 });
116+
private List<String> d = Arrays.asList("floppy", "disk");
115117

116118
public double getA() {
117119
return a;
@@ -124,6 +126,10 @@ public Integer[] getB() {
124126
public Map<String, int[]> getC() {
125127
return c;
126128
}
129+
130+
public List<String> getD() {
131+
return d;
132+
}
127133
}
128134

129135
@Test
@@ -142,7 +148,10 @@ public void testCreateDataFrameFromJavaBeans() {
142148
Assert.assertEquals(
143149
new StructField("c", mapType, true, Metadata.empty()),
144150
schema.apply("c"));
145-
Row first = df.select("a", "b", "c").first();
151+
Assert.assertEquals(
152+
new StructField("d", new ArrayType(DataTypes.StringType, true), true, Metadata.empty()),
153+
schema.apply("d"));
154+
Row first = df.select("a", "b", "c", "d").first();
146155
Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0);
147156
// Now Java lists and maps are converetd to Scala Seq's and Map's. Once we get a Seq below,
148157
// verify that it has the expected length, and contains expected elements.
@@ -155,6 +164,11 @@ public void testCreateDataFrameFromJavaBeans() {
155164
Assert.assertArrayEquals(
156165
bean.getC().get("hello"),
157166
Ints.toArray(JavaConversions.asJavaList(outputBuffer)));
167+
Seq<String> d = first.getAs(3);
168+
Assert.assertEquals(bean.getD().size(), d.length());
169+
for (int i = 0; i < d.length(); i++) {
170+
Assert.assertEquals(bean.getD().get(i), d.apply(i));
171+
}
158172
}
159173

160174
}

0 commit comments

Comments
 (0)