@@ -176,27 +176,31 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
   }
 
   override def serialize(obj: Any): Row = {
-    val row = new GenericMutableRow(4)
     obj match {
       case SparseVector(size, indices, values) =>
+        val row = new GenericMutableRow(4)
         row.setByte(0, 0)
         row.setInt(1, size)
         row.update(2, indices.toSeq)
         row.update(3, values.toSeq)
+        row
       case DenseVector(values) =>
+        val row = new GenericMutableRow(4)
         row.setByte(0, 1)
         row.setNullAt(1)
         row.setNullAt(2)
         row.update(3, values.toSeq)
+        row
+      // TODO: There are bugs in UDT serialization because we don't have a clear separation between
+      // TODO: internal SQL types and language specific types (including UDT). UDT serialize and
+      // TODO: deserialize may get called twice. See SPARK-7186.
+      case row: Row =>
+        row
     }
-    row
   }
 
   override def deserialize(datum: Any): Vector = {
     datum match {
-      // TODO: something wrong with UDT serialization
-      case v: Vector =>
-        v
       case row: Row =>
         require(row.length == 4,
           s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4")
@@ -211,6 +215,11 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
             val values = row.getAs[Iterable[Double]](3).toArray
             new DenseVector(values)
         }
+      // TODO: There are bugs in UDT serialization because we don't have a clear separation between
+      // TODO: internal SQL types and language specific types (including UDT). UDT serialize and
+      // TODO: deserialize may get called twice. See SPARK-7186.
+      case v: Vector =>
+        v
     }
   }
 
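The net effect of the change: `serialize` now allocates its `GenericMutableRow` inside each vector branch rather than up front, and both conversions pass an already-converted value straight through (`serialize` returns a `Row` unchanged, `deserialize` returns a `Vector` unchanged), so the double invocation described in the TODOs no longer crashes while SPARK-7186 stays open.

Below is a minimal, self-contained Scala sketch of that defensive pass-through pattern. The names `Vec`, `Dense`, and `PassThroughUDT` are hypothetical stand-ins, not the Spark classes (`VectorUDT` is `private[spark]` and builds a real `GenericMutableRow`); a plain `Seq[Any]` stands in for the row.

// Stand-in types for illustration only, not Spark's Vector/Row.
sealed trait Vec
final case class Dense(values: Array[Double]) extends Vec

object PassThroughUDT {
  // Accepts either a Vec or an already-serialized row.
  def serialize(obj: Any): Seq[Any] = obj match {
    case Dense(values) =>
      // Build the row inside the branch, mirroring the diff's moved allocation.
      Seq(1.toByte, null, null, values.toSeq)
    case row: Seq[_] =>
      row // already serialized: a duplicate serialize call is a no-op
  }

  // Accepts either a row or an already-deserialized Vec.
  def deserialize(datum: Any): Vec = datum match {
    case row: Seq[_] =>
      Dense(row(3).asInstanceOf[Seq[Double]].toArray)
    case v: Vec =>
      v // already deserialized: a duplicate deserialize call is a no-op
  }
}

// The round trip survives the double calls noted in the TODOs:
//   val v    = Dense(Array(1.0, 2.0))
//   val row  = PassThroughUDT.serialize(PassThroughUDT.serialize(v))
//   val back = PassThroughUDT.deserialize(PassThroughUDT.deserialize(row))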