Skip to content

Commit 004a37f

Browse files
committed
Cast boolean to int
1 parent 151f3b6 commit 004a37f

File tree

3 files changed

+5
-3
lines changed

3 files changed

+5
-3
lines changed

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -988,6 +988,7 @@ private[spark] object SerDe extends Serializable {
988988
val isTransposed = if (m.isTransposed) 1 else 0
989989
ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().put(m.values)
990990

991+
out.write(Opcodes.MARK)
991992
out.write(Opcodes.BININT)
992993
out.write(PickleUtils.integer_to_bytes(m.numRows))
993994
out.write(Opcodes.BININT)
@@ -1008,7 +1009,7 @@ private[spark] object SerDe extends Serializable {
10081009
val n = bytes.length / 8
10091010
val values = new Array[Double](n)
10101011
val order = ByteOrder.nativeOrder()
1011-
val isTransposed = args(3).asInstanceOf[Boolean]
1012+
val isTransposed = if (args(3).asInstanceOf[Int] == 1) true else false
10121013
ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().get(values)
10131014
new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], values, isTransposed)
10141015
}

python/pyspark/mllib/linalg.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -671,7 +671,8 @@ def __init__(self, numRows, numCols, values, isTransposed=False):
671671

672672
def __reduce__(self):
673673
return DenseMatrix, (
674-
self.numRows, self.numCols, self.values.tostring(), self.isTransposed)
674+
self.numRows, self.numCols, self.values.tostring(),
675+
int(self.isTransposed))
675676

676677
def toArray(self):
677678
"""

python/pyspark/mllib/tests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def test_serialize(self):
8585
self._test_serialize(DenseVector(pyarray.array('d', range(10))))
8686
self._test_serialize(SparseVector(4, {1: 1, 3: 2}))
8787
self._test_serialize(SparseVector(3, {}))
88-
# self._test_serialize(DenseMatrix(2, 3, range(6)))
88+
self._test_serialize(DenseMatrix(2, 3, range(6)))
8989

9090
def test_dot(self):
9191
sv = SparseVector(4, {1: 1, 3: 2})

0 commit comments

Comments
 (0)