@@ -985,6 +985,7 @@ private[spark] object SerDe extends Serializable {
985985 val m : DenseMatrix = obj.asInstanceOf [DenseMatrix ]
986986 val bytes = new Array [Byte ](8 * m.values.size)
987987 val order = ByteOrder .nativeOrder()
988+ val isTransposed = if (m.isTransposed) 1 else 0
988989 ByteBuffer .wrap(bytes).order(order).asDoubleBuffer().put(m.values)
989990
990991 out.write(Opcodes .BININT )
@@ -994,19 +995,22 @@ private[spark] object SerDe extends Serializable {
994995 out.write(Opcodes .BINSTRING )
995996 out.write(PickleUtils .integer_to_bytes(bytes.length))
996997 out.write(bytes)
997- out.write(Opcodes .TUPLE3 )
998+ out.write(Opcodes .BININT )
999+ out.write(PickleUtils .integer_to_bytes(isTransposed))
1000+ out.write(Opcodes .TUPLE )
9981001 }
9991002
10001003 def construct (args : Array [Object ]): Object = {
1001- if (args.length != 3 ) {
1002- throw new PickleException (" should be 3 " )
1004+ if (args.length != 4 ) {
1005+ throw new PickleException (" should be 4 " )
10031006 }
10041007 val bytes = getBytes(args(2 ))
10051008 val n = bytes.length / 8
10061009 val values = new Array [Double ](n)
10071010 val order = ByteOrder .nativeOrder()
1011+ val isTransposed = args(3 ).asInstanceOf [Boolean ]
10081012 ByteBuffer .wrap(bytes).order(order).asDoubleBuffer().get(values)
1009- new DenseMatrix (args(0 ).asInstanceOf [Int ], args(1 ).asInstanceOf [Int ], values)
1013+ new DenseMatrix (args(0 ).asInstanceOf [Int ], args(1 ).asInstanceOf [Int ], values, isTransposed )
10101014 }
10111015 }
10121016
0 commit comments