Skip to content

Commit d2aed4a

Browse files
authored
Kotlin friendly names (Shape.get)
1 parent 3649959 commit d2aed4a

File tree

6 files changed

+97
-37
lines changed

6 files changed

+97
-37
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
*.iml
2+
.idea
23
target

ndarray/src/main/java/org/tensorflow/ndarray/Shape.java

Lines changed: 63 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717

1818
package org.tensorflow.ndarray;
1919

20+
import java.util.ArrayList;
2021
import java.util.Arrays;
22+
import java.util.List;
2123

2224
/**
2325
* The shape of a Tensor or {@link NdArray}.
@@ -74,8 +76,8 @@ public static Shape scalar() {
7476
* Shape scalar = Shape.of()
7577
* }</pre>
7678
*
77-
* @param dimensionSizes number of elements in each dimension of this shape, if any, or
78-
* {@link Shape#UNKNOWN_SIZE} if unknown.
79+
* @param dimensionSizes number of elements in each dimension of this shape, if any, or {@link
80+
* Shape#UNKNOWN_SIZE} if unknown.
7981
* @return a new shape
8082
*/
8183
public static Shape of(long... dimensionSizes) {
@@ -108,13 +110,34 @@ public long size() {
108110
* an unknown size, {@link Shape#UNKNOWN_SIZE} is returned.
109111
*
110112
* @param i the index of the dimension to get the size for. If this Shape has a known number of
111-
* dimensions, it must be &lt; {@link Shape#numDimensions()}. The index may be negative, in which
112-
* case the position is counted from the end of the shape. E.g.: {@code size(-1)} returns the
113-
* size of the last dimension, {@code size(-2)} the size of the second to last dimension etc.
113+
* dimensions, it must be &lt; {@link Shape#numDimensions()}. The index may be negative, in
114+
* which case the position is counted from the end of the shape. E.g.: {@code size(-1)}
115+
* returns the size of the last dimension, {@code size(-2)} the size of the second to last
116+
* dimension etc.
114117
* @return The size of the dimension with the given index if known, {@link Shape#UNKNOWN_SIZE}
115118
* otherwise.
119+
* @deprecated Renamed to {@link #get(int)}.
116120
*/
117-
public long size(int i) {
121+
@Deprecated
122+
public long size(int i){
123+
return get(i);
124+
}
125+
126+
/**
127+
* The size of the dimension with the given index.
128+
*
129+
* <p>If {@link Shape#isUnknown()} is true or the size of the dimension with the given index has
130+
* an unknown size, {@link Shape#UNKNOWN_SIZE} is returned.
131+
*
132+
* @param i the index of the dimension to get the size for. If this Shape has a known number of
133+
* dimensions, it must be &lt; {@link Shape#numDimensions()}. The index may be negative, in
134+
* which case the position is counted from the end of the shape. E.g.: {@code size(-1)}
135+
* returns the size of the last dimension, {@code size(-2)} the size of the second to last
136+
* dimension etc.
137+
* @return The size of the dimension with the given index if known, {@link Shape#UNKNOWN_SIZE}
138+
* otherwise.
139+
*/
140+
public long get(int i) {
118141
if (dimensionSizes == null) {
119142
return UNKNOWN_SIZE;
120143
} else if (i >= 0) {
@@ -177,6 +200,24 @@ public long[] asArray() {
177200
}
178201
}
179202

203+
/**
204+
* Returns a defensive copy of the this Shape's axes. Changes to the returned list do not change
205+
* this Shape's state. Returns null if {@link Shape#isUnknown()} is true.
206+
*/
207+
public List<Long> toListOrNull() {
208+
long[] array = asArray();
209+
if (array == null) {
210+
return null;
211+
}
212+
213+
List<Long> list = new ArrayList<>(array.length);
214+
for (long l : array) {
215+
list.add(l);
216+
}
217+
218+
return list;
219+
}
220+
180221
@Override
181222
public int hashCode() {
182223
return dimensionSizes != null ? Arrays.hashCode(dimensionSizes) : super.hashCode();
@@ -186,6 +227,7 @@ public int hashCode() {
186227
* Equals implementation for Shapes. Two Shapes are considered equal iff:
187228
*
188229
* <p>
230+
*
189231
* <ul>
190232
* <li>the number of dimensions is defined and equal for both
191233
* <li>the size of each dimension is defined and equal for both
@@ -236,7 +278,8 @@ public Shape head() {
236278
* Returns an n-dimensional Shape with the dimensions matching the first n dimensions of this
237279
* shape
238280
*
239-
* @param n the number of leading dimensions to get, must be &lt;= than {@link Shape#numDimensions()}
281+
* @param n the number of leading dimensions to get, must be &lt;= than {@link
282+
* Shape#numDimensions()}
240283
* @return an n-dimensional Shape with the first n dimensions matching the first n dimensions of
241284
* this Shape
242285
*/
@@ -252,7 +295,9 @@ public Shape take(int n) {
252295

253296
/** Returns a new Shape, with this Shape's first dimension removed. */
254297
public Shape tail() {
255-
if (dimensionSizes.length < 2) return Shape.of();
298+
if (dimensionSizes.length < 2) {
299+
return Shape.of();
300+
}
256301
return Shape.of(Arrays.copyOfRange(dimensionSizes, 1, dimensionSizes.length));
257302
}
258303

@@ -276,15 +321,21 @@ public Shape takeLast(int n) {
276321
}
277322

278323
/**
279-
* Return a {@code end - begin} dimensional shape with dimensions matching this Shape from {@code begin} to {@code end}.
324+
* Return a {@code end - begin} dimensional shape with dimensions matching this Shape from {@code
325+
* begin} to {@code end}.
326+
*
280327
* @param begin Where to start the sub-shape.
281328
* @param end Where to end the sub-shape, exclusive.
282329
* @return the sub-shape bounded by begin and end.
283330
*/
284-
public Shape subShape(int begin, int end){
331+
public Shape subShape(int begin, int end) {
285332
if (end > numDimensions()) {
286333
throw new ArrayIndexOutOfBoundsException(
287-
"End index " + end + " out of bounds: shape only has " + numDimensions() + " dimensions.");
334+
"End index "
335+
+ end
336+
+ " out of bounds: shape only has "
337+
+ numDimensions()
338+
+ " dimensions.");
288339
}
289340
if (begin < 0) {
290341
throw new ArrayIndexOutOfBoundsException(
@@ -423,7 +474,7 @@ public boolean isCompatibleWith(Shape shape) {
423474
return false;
424475
}
425476
for (int i = 0; i < numDimensions(); i++) {
426-
if (!isCompatible(size(i), shape.size(i))) {
477+
if (!isCompatible(get(i), shape.get(i))) {
427478
return false;
428479
}
429480
}

ndarray/src/main/java/org/tensorflow/ndarray/StdArrays.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3798,9 +3798,9 @@ private static int[] computeArrayDims(NdArray<?> ndArray, int expectedRank) {
37983798
}
37993799
int[] arrayShape = new int[expectedRank];
38003800
for (int i = 0; i < expectedRank; ++i) {
3801-
long dimSize = shape.size(i);
3801+
long dimSize = shape.get(i);
38023802
if (dimSize > Integer.MAX_VALUE) {
3803-
throw new IllegalArgumentException("Dimension " + i + " is too large to fit in a standard array (" + shape.size(i) + ")");
3803+
throw new IllegalArgumentException("Dimension " + i + " is too large to fit in a standard array (" + shape.get(i) + ")");
38043804
}
38053805
arrayShape[i] = (int)dimSize;
38063806
}

ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ public static DimensionalSpace create(Shape shape) {
2828

2929
// Start from the last dimension, where all elements are continuous
3030
for (int i = dimensions.length - 1, elementSize = 1; i >= 0; --i) {
31-
dimensions[i] = new Axis(shape.size(i), elementSize);
31+
dimensions[i] = new Axis(shape.get(i), elementSize);
3232
elementSize *= dimensions[i].numElements();
3333
}
3434
return new DimensionalSpace(dimensions, shape);
@@ -189,7 +189,9 @@ public long positionOf(long[] coords) {
189189
return position;
190190
}
191191

192-
/** Succinct description of the shape meant for debugging. */
192+
/**
193+
* Succinct description of the shape meant for debugging.
194+
*/
193195
@Override
194196
public String toString() {
195197
return Arrays.toString(dimensions);

ndarray/src/test/java/org/tensorflow/ndarray/NdArrayTestBase.java

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@
2424
import static org.tensorflow.ndarray.index.Indices.at;
2525
import static org.tensorflow.ndarray.index.Indices.even;
2626
import static org.tensorflow.ndarray.index.Indices.flip;
27-
import static org.tensorflow.ndarray.index.Indices.sliceFrom;
2827
import static org.tensorflow.ndarray.index.Indices.odd;
2928
import static org.tensorflow.ndarray.index.Indices.range;
3029
import static org.tensorflow.ndarray.index.Indices.seq;
30+
import static org.tensorflow.ndarray.index.Indices.sliceFrom;
3131
import static org.tensorflow.ndarray.index.Indices.sliceTo;
3232

3333
import java.nio.BufferOverflowException;
@@ -132,15 +132,15 @@ public void iterateElements() {
132132
long value = 0L;
133133
for (NdArray<T> matrix : matrix3d.elements(0)) {
134134
assertEquals(2L, matrix.shape().numDimensions());
135-
assertEquals(4L, matrix.shape().size(0));
136-
assertEquals(5L, matrix.shape().size(1));
135+
assertEquals(4L, matrix.shape().get(0));
136+
assertEquals(5L, matrix.shape().get(1));
137137

138138
for (NdArray<T> vector : matrix.elements(0)) {
139-
assertEquals(1L, vector.shape().numDimensions()) ;
140-
assertEquals(5L, vector.shape().size(0));
139+
assertEquals(1L, vector.shape().numDimensions());
140+
assertEquals(5L, vector.shape().get(0));
141141

142142
for (NdArray<T> scalar : vector.scalars()) {
143-
assertEquals(0L, scalar.shape().numDimensions()) ;
143+
assertEquals(0L, scalar.shape().numDimensions());
144144
scalar.setObject(valueOf(value++));
145145
try {
146146
scalar.elements(0);
@@ -162,7 +162,7 @@ public void iterateElements() {
162162
@Test
163163
public void slices() {
164164
NdArray<T> matrix3d = allocate(Shape.of(5, 4, 5));
165-
165+
166166
T val100 = valueOf(100L);
167167
matrix3d.setObject(val100, 1, 0, 0);
168168
T val101 = valueOf(101L);
@@ -318,8 +318,8 @@ public void equalsAndHashCode() {
318318
NdArray<T> array4 = allocate(Shape.of(1, 2, 2));
319319

320320
@SuppressWarnings("unchecked")
321-
T[][][] values = (T[][][])(new Object[][][] {
322-
{ { valueOf(0L), valueOf(1L) }, { valueOf(2L), valueOf(0L) } }
321+
T[][][] values = (T[][][]) (new Object[][][]{
322+
{{valueOf(0L), valueOf(1L)}, {valueOf(2L), valueOf(0L)}}
323323
});
324324

325325
StdArrays.copyTo(values[0], array1);

ndarray/src/test/java/org/tensorflow/ndarray/ShapeTest.java

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,32 +16,38 @@
1616
*/
1717
package org.tensorflow.ndarray;
1818

19-
import org.junit.jupiter.api.Test;
19+
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
20+
import static org.junit.jupiter.api.Assertions.assertEquals;
21+
import static org.junit.jupiter.api.Assertions.assertFalse;
22+
import static org.junit.jupiter.api.Assertions.assertNotEquals;
23+
import static org.junit.jupiter.api.Assertions.assertNotNull;
24+
import static org.junit.jupiter.api.Assertions.assertTrue;
25+
import static org.junit.jupiter.api.Assertions.fail;
2026

21-
import static org.junit.jupiter.api.Assertions.*;
27+
import org.junit.jupiter.api.Test;
2228

2329
public class ShapeTest {
2430

2531
@Test
2632
public void allKnownDimensions() {
2733
Shape shape = Shape.of(5, 4, 5);
2834
assertEquals(3, shape.numDimensions());
29-
assertEquals(5, shape.size(0));
30-
assertEquals(4, shape.size(1));
31-
assertEquals(5, shape.size(2));
35+
assertEquals(5, shape.get(0));
36+
assertEquals(4, shape.get(1));
37+
assertEquals(5, shape.get(2));
3238
assertEquals(100, shape.size());
33-
assertArrayEquals(new long[] {5, 4, 5}, shape.asArray());
39+
assertArrayEquals(new long[]{5, 4, 5}, shape.asArray());
3440
try {
35-
shape.size(3);
41+
shape.get(3);
3642
fail();
3743
} catch (IndexOutOfBoundsException e) {
3844
// as expected
3945
}
40-
assertEquals(5, shape.size(-1));
41-
assertEquals(4, shape.size(-2));
42-
assertEquals(5, shape.size(-3));
46+
assertEquals(5, shape.get(-1));
47+
assertEquals(4, shape.get(-2));
48+
assertEquals(5, shape.get(-3));
4349
try {
44-
shape.size(-4);
50+
shape.get(-4);
4551
fail();
4652
} catch (IndexOutOfBoundsException e) {
4753
// as expected
@@ -133,7 +139,7 @@ public void testShapeModification() {
133139
long[] internalShape = one.asArray();
134140
assertNotNull(internalShape);
135141
internalShape[0] = 42L;
136-
assertEquals(2L, one.size(0));
142+
assertEquals(2L, one.get(0));
137143
}
138144

139145
@Test

0 commit comments

Comments
 (0)