Skip to content

Commit be23fbf

Browse files
committed
test new builders
1 parent 6516637 commit be23fbf

File tree

2 files changed

+47
-7
lines changed

2 files changed

+47
-7
lines changed

src/main/java/io/bioimage/modelrunner/tensorflow/v2/api030/shm/ShmBuilder.java

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import java.nio.ByteBuffer;
2929
import java.util.Arrays;
3030

31+
import org.tensorflow.ndarray.buffer.ByteDataBuffer;
3132
import org.tensorflow.types.TFloat32;
3233
import org.tensorflow.types.TFloat64;
3334
import org.tensorflow.types.TInt32;
@@ -101,7 +102,19 @@ private static void buildFromTensorUByte(TUint8 tensor, String memoryName) throw
101102
+ " is too big. Max number of elements per ubyte output tensor supported: " + Integer.MAX_VALUE / 1);
102103
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
103104
ByteBuffer buff = shma.getDataBufferNoHeader();
104-
tensor.asRawTensor().data().read(buff.array(), 0, buff.capacity());
105+
long tt = System.currentTimeMillis();
106+
ByteDataBuffer tensorData = tensor.asRawTensor().data();
107+
for (int i = 0; i < buff.capacity(); i ++) {
108+
buff.put(tensorData.getByte(i));
109+
}
110+
System.out.println("TIME 1: " + (System.currentTimeMillis() - tt) / 1000);
111+
buff.rewind();
112+
tt = System.currentTimeMillis();
113+
byte[] flat = new byte[buff.capacity()];
114+
ByteBuffer buff2 = ByteBuffer.wrap(flat);
115+
tensor.asRawTensor().data().read(flat, 0, buff.capacity());
116+
shma.setBuffer(buff2);
117+
System.out.println("TIME 2: " + (System.currentTimeMillis() - tt) / 1000);
105118
if (PlatformDetection.isWindows()) shma.close();
106119
}
107120

@@ -127,7 +140,19 @@ private static void buildFromTensorFloat(TFloat32 tensor, String memoryName) thr
127140

128141
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new FloatType(), false, true);
129142
ByteBuffer buff = shma.getDataBufferNoHeader();
130-
tensor.asRawTensor().data().read(buff.array(), 0, buff.capacity());
143+
long tt = System.currentTimeMillis();
144+
ByteDataBuffer tensorData = tensor.asRawTensor().data();
145+
for (int i = 0; i < buff.capacity(); i ++) {
146+
buff.put(tensorData.getByte(i));
147+
}
148+
System.out.println("TIME 1: " + (System.currentTimeMillis() - tt) / 1000);
149+
buff.rewind();
150+
tt = System.currentTimeMillis();
151+
byte[] flat = new byte[buff.capacity()];
152+
ByteBuffer buff2 = ByteBuffer.wrap(flat);
153+
tensor.asRawTensor().data().read(flat, 0, buff.capacity());
154+
shma.setBuffer(buff2);
155+
System.out.println("TIME 2: " + (System.currentTimeMillis() - tt) / 1000);
131156
if (PlatformDetection.isWindows()) shma.close();
132157
}
133158

src/main/java/io/bioimage/modelrunner/tensorflow/v2/api030/shm/TensorBuilder.java

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,10 @@ private static TUint8 buildUByte(SharedMemoryArray tensor)
9797
if (!tensor.isNumpyFormat())
9898
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
9999
ByteBuffer buff = tensor.getDataBufferNoHeader();
100-
ByteDataBuffer dataBuffer = RawDataBufferFactory.create(buff.array(), false);
100+
byte[] flat = new byte[buff.capacity()];
101+
buff.get(flat);
102+
buff.rewind();
103+
ByteDataBuffer dataBuffer = RawDataBufferFactory.create(flat, false);
101104
TUint8 ndarray = Tensor.of(TUint8.class, Shape.of(ogShape), dataBuffer);
102105
return ndarray;
103106
}
@@ -112,7 +115,10 @@ private static TInt32 buildInt(SharedMemoryArray tensor)
112115
if (!tensor.isNumpyFormat())
113116
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
114117
ByteBuffer buff = tensor.getDataBufferNoHeader();
115-
IntDataBuffer dataBuffer = RawDataBufferFactory.create(buff.asIntBuffer().array(), false);
118+
int[] flat = new int[buff.capacity() / 4];
119+
buff.asIntBuffer().get(flat);
120+
buff.rewind();
121+
IntDataBuffer dataBuffer = RawDataBufferFactory.create(flat, false);
116122
TInt32 ndarray = TInt32.tensorOf(Shape.of(ogShape),
117123
dataBuffer);
118124
return ndarray;
@@ -128,7 +134,10 @@ private static TInt64 buildLong(SharedMemoryArray tensor)
128134
if (!tensor.isNumpyFormat())
129135
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
130136
ByteBuffer buff = tensor.getDataBufferNoHeader();
131-
LongDataBuffer dataBuffer = RawDataBufferFactory.create(buff.asLongBuffer().array(), false);
137+
long[] flat = new long[buff.capacity() / 8];
138+
buff.asLongBuffer().get(flat);
139+
buff.rewind();
140+
LongDataBuffer dataBuffer = RawDataBufferFactory.create(flat, false);
132141
TInt64 ndarray = TInt64.tensorOf(Shape.of(ogShape),
133142
dataBuffer);
134143
return ndarray;
@@ -144,7 +153,10 @@ private static TFloat32 buildFloat(SharedMemoryArray tensor)
144153
if (!tensor.isNumpyFormat())
145154
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
146155
ByteBuffer buff = tensor.getDataBufferNoHeader();
147-
FloatDataBuffer dataBuffer = RawDataBufferFactory.create(buff.asFloatBuffer().array(), false);
156+
float[] flat = new float[buff.capacity() / 4];
157+
buff.asFloatBuffer().get(flat);
158+
buff.rewind();
159+
FloatDataBuffer dataBuffer = RawDataBufferFactory.create(flat, false);
148160
TFloat32 ndarray = TFloat32.tensorOf(Shape.of(ogShape), dataBuffer);
149161
return ndarray;
150162
}
@@ -159,7 +171,10 @@ private static TFloat64 buildDouble(SharedMemoryArray tensor)
159171
if (!tensor.isNumpyFormat())
160172
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
161173
ByteBuffer buff = tensor.getDataBufferNoHeader();
162-
DoubleDataBuffer dataBuffer = RawDataBufferFactory.create(buff.asDoubleBuffer().array(), false);
174+
double[] flat = new double[buff.capacity() / 8];
175+
buff.asDoubleBuffer().get(flat);
176+
buff.rewind();
177+
DoubleDataBuffer dataBuffer = RawDataBufferFactory.create(flat, false);
163178
TFloat64 ndarray = TFloat64.tensorOf(Shape.of(ogShape), dataBuffer);
164179
return ndarray;
165180
}

0 commit comments

Comments
 (0)