Skip to content

Commit dac3139

Browse files
committed
Viewing arrays with different shapes
1 parent 05202d9 commit dac3139

20 files changed

+112
-17
lines changed

ndarray/src/main/java/org/tensorflow/ndarray/BooleanNdArray.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ public interface BooleanNdArray extends NdArray<Boolean> {
6868
*/
6969
BooleanNdArray setBoolean(boolean value, long... coordinates);
7070

71+
@Override
72+
BooleanNdArray withShape(Shape shape);
73+
7174
@Override
7275
BooleanNdArray slice(Index... indices);
7376

ndarray/src/main/java/org/tensorflow/ndarray/ByteNdArray.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ public interface ByteNdArray extends NdArray<Byte> {
6868
*/
6969
ByteNdArray setByte(byte value, long... coordinates);
7070

71+
@Override
72+
ByteNdArray withShape(Shape shape);
73+
7174
@Override
7275
ByteNdArray slice(Index... indices);
7376

ndarray/src/main/java/org/tensorflow/ndarray/DoubleNdArray.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ default DoubleStream streamOfDoubles() {
8383
return StreamSupport.stream(scalars().spliterator(), false).mapToDouble(DoubleNdArray::getDouble);
8484
}
8585

86+
@Override
87+
DoubleNdArray withShape(Shape shape);
88+
8689
@Override
8790
DoubleNdArray slice(Index... indices);
8891

ndarray/src/main/java/org/tensorflow/ndarray/FloatNdArray.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ public interface FloatNdArray extends NdArray<Float> {
6868
*/
6969
FloatNdArray setFloat(float value, long... coordinates);
7070

71+
@Override
72+
FloatNdArray withShape(Shape shape);
73+
7174
@Override
7275
FloatNdArray slice(Index... coordinates);
7376

ndarray/src/main/java/org/tensorflow/ndarray/IntNdArray.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ default IntStream streamOfInts() {
8383
return StreamSupport.stream(scalars().spliterator(), false).mapToInt(IntNdArray::getInt);
8484
}
8585

86+
@Override
87+
IntNdArray withShape(Shape shape);
88+
8689
@Override
8790
IntNdArray slice(Index... indices);
8891

ndarray/src/main/java/org/tensorflow/ndarray/LongNdArray.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ default LongStream streamOfLongs() {
8383
return StreamSupport.stream(scalars().spliterator(), false).mapToLong(LongNdArray::getLong);
8484
}
8585

86+
@Override
87+
LongNdArray withShape(Shape shape);
88+
8689
@Override
8790
LongNdArray slice(Index... indices);
8891

ndarray/src/main/java/org/tensorflow/ndarray/NdArray.java

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@
1616
*/
1717
package org.tensorflow.ndarray;
1818

19+
import org.tensorflow.ndarray.buffer.DataBuffer;
20+
import org.tensorflow.ndarray.index.Index;
21+
1922
import java.util.function.BiConsumer;
2023
import java.util.function.Consumer;
2124
import java.util.stream.Stream;
2225
import java.util.stream.StreamSupport;
2326

24-
import org.tensorflow.ndarray.buffer.DataBuffer;
25-
import org.tensorflow.ndarray.index.Index;
26-
2727
/**
2828
* A data structure of N-dimensions.
2929
*
@@ -101,6 +101,32 @@ public interface NdArray<T> extends Shaped {
101101
*/
102102
NdArraySequence<? extends NdArray<T>> scalars();
103103

104+
/**
105+
* Returns a new N-dimensional view of this array with the given {@code shape}.
106+
*
107+
* <p>The provided {@code shape} must comply to the following characteristics:
108+
* <ul>
109+
* <li>new shape is known (i.e. has no unknown dimension)</li>
110+
* <li>new shape size is equal to the size of the current shape (i.e. same number of elements)</li>
111+
* </ul>
112+
* For example,
113+
* <pre>{@code
114+
* NdArrays.ofInts(Shape.scalar()).withShape(Shape.of(1, 1)); // ok
115+
* NdArrays.ofInts(Shape.of(2, 3).withShape(Shape.of(3, 2)); // ok
116+
* NdArrays.ofInts(Shape.scalar()).withShape(Shape.of(1, 2)); // not ok, sizes are different (1 != 2)
117+
* NdArrays.ofInts(Shape.of(2, 3)).withShape(Shape.unknown()); // not ok, new shape unknown
118+
* }</pre>
119+
*
120+
* <p>Any changes applied to the returned view affect the data of this array as well, as there
121+
* is no copy involved.
122+
*
123+
* @param shape the new shape to apply
124+
* @return a new array viewing the data according to the new shape, or this array if shapes are the same
125+
* @throws IllegalArgumentException if the provided {@code shape} is not compliant
126+
* @throws UnsupportedOperationException if this array does not support this operation
127+
*/
128+
NdArray<T> withShape(Shape shape);
129+
104130
/**
105131
* Creates a multi-dimensional view (or slice) of this array by mapping one or more dimensions
106132
* to the given index selectors.

ndarray/src/main/java/org/tensorflow/ndarray/ShortNdArray.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ public interface ShortNdArray extends NdArray<Short> {
6868
*/
6969
ShortNdArray setShort(short value, long... coordinates);
7070

71+
@Override
72+
ShortNdArray withShape(Shape shape);
73+
7174
@Override
7275
ShortNdArray slice(Index... coordinates);
7376

ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/AbstractDenseNdArray.java

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import org.tensorflow.ndarray.NdArray;
2020
import org.tensorflow.ndarray.NdArraySequence;
21+
import org.tensorflow.ndarray.Shape;
2122
import org.tensorflow.ndarray.impl.AbstractNdArray;
2223
import org.tensorflow.ndarray.impl.dimension.RelativeDimensionalSpace;
2324
import org.tensorflow.ndarray.impl.sequence.FastElementSequence;
@@ -43,18 +44,29 @@ public NdArraySequence<U> elements(int dimensionIdx) {
4344
DimensionalSpace elemDims = dimensions().from(dimensionIdx + 1);
4445
try {
4546
DataBufferWindow<? extends DataBuffer<T>> elemWindow = buffer().window(elemDims.physicalSize());
46-
U element = instantiate(elemWindow.buffer(), elemDims);
47+
U element = instantiateView(elemWindow.buffer(), elemDims);
4748
return new FastElementSequence(this, dimensionIdx, element, elemWindow);
4849
} catch (UnsupportedOperationException e) {
4950
// If buffer windows are not supported, fallback to slicing (and slower) sequence
5051
return new SlicingElementSequence<>(this, dimensionIdx, elemDims);
5152
}
5253
}
5354

55+
@Override
56+
public U withShape(Shape shape) {
57+
if (shape == this.shape()) {
58+
return (U)this;
59+
}
60+
if (shape == null || shape.isUnknown() || shape.size() != this.shape().size()) {
61+
throw new IllegalArgumentException("Shape " + shape + " cannot be used to reshape ndarray of shape " + this.shape());
62+
}
63+
return instantiateView(buffer(), DimensionalSpace.create(shape));
64+
}
65+
5466
@Override
5567
public U slice(long position, DimensionalSpace sliceDimensions) {
5668
DataBuffer<T> sliceBuffer = buffer().slice(position, sliceDimensions.physicalSize());
57-
return instantiate(sliceBuffer, sliceDimensions);
69+
return instantiateView(sliceBuffer, sliceDimensions);
5870
}
5971

6072
@Override
@@ -147,7 +159,7 @@ protected AbstractDenseNdArray(DimensionalSpace dimensions) {
147159

148160
abstract protected DataBuffer<T> buffer();
149161

150-
abstract U instantiate(DataBuffer<T> buffer, DimensionalSpace dimensions);
162+
abstract U instantiateView(DataBuffer<T> buffer, DimensionalSpace dimensions);
151163

152164
long positionOf(long[] coords, boolean isValue) {
153165
if (coords == null || coords.length == 0) {

ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/BooleanDenseNdArray.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ protected BooleanDenseNdArray(BooleanDataBuffer buffer, Shape shape) {
7373
}
7474

7575
@Override
76-
BooleanDenseNdArray instantiate(DataBuffer<Boolean> buffer, DimensionalSpace dimensions) {
76+
BooleanDenseNdArray instantiateView(DataBuffer<Boolean> buffer, DimensionalSpace dimensions) {
7777
return new BooleanDenseNdArray((BooleanDataBuffer)buffer, dimensions);
7878
}
7979

0 commit comments

Comments
 (0)