Skip to content

Commit 151f3b6

Browse files
committed
[WIP] Add isTransposed to pickle DenseMatrix
1 parent cc0b90a commit 151f3b6

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -985,6 +985,7 @@ private[spark] object SerDe extends Serializable {
985985
val m: DenseMatrix = obj.asInstanceOf[DenseMatrix]
986986
val bytes = new Array[Byte](8 * m.values.size)
987987
val order = ByteOrder.nativeOrder()
988+
val isTransposed = if (m.isTransposed) 1 else 0
988989
ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().put(m.values)
989990

990991
out.write(Opcodes.BININT)
@@ -994,19 +995,22 @@ private[spark] object SerDe extends Serializable {
994995
out.write(Opcodes.BINSTRING)
995996
out.write(PickleUtils.integer_to_bytes(bytes.length))
996997
out.write(bytes)
997-
out.write(Opcodes.TUPLE3)
998+
out.write(Opcodes.BININT)
999+
out.write(PickleUtils.integer_to_bytes(isTransposed))
1000+
out.write(Opcodes.TUPLE)
9981001
}
9991002

10001003
def construct(args: Array[Object]): Object = {
1001-
if (args.length != 3) {
1002-
throw new PickleException("should be 3")
1004+
if (args.length != 4) {
1005+
throw new PickleException("should be 4")
10031006
}
10041007
val bytes = getBytes(args(2))
10051008
val n = bytes.length / 8
10061009
val values = new Array[Double](n)
10071010
val order = ByteOrder.nativeOrder()
1011+
val isTransposed = args(3).asInstanceOf[Boolean]
10081012
ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().get(values)
1009-
new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], values)
1013+
new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], values, isTransposed)
10101014
}
10111015
}
10121016

0 commit comments

Comments
 (0)