17
17
18
18
package org .tensorflow .ndarray ;
19
19
20
+ import java .util .ArrayList ;
20
21
import java .util .Arrays ;
22
+ import java .util .List ;
21
23
22
24
/**
23
25
* The shape of a Tensor or {@link NdArray}.
@@ -74,8 +76,8 @@ public static Shape scalar() {
74
76
* Shape scalar = Shape.of()
75
77
* }</pre>
76
78
*
77
- * @param dimensionSizes number of elements in each dimension of this shape, if any, or
78
- * {@link Shape#UNKNOWN_SIZE} if unknown.
79
+ * @param dimensionSizes number of elements in each dimension of this shape, if any, or {@link
80
+ * Shape#UNKNOWN_SIZE} if unknown.
79
81
* @return a new shape
80
82
*/
81
83
public static Shape of (long ... dimensionSizes ) {
@@ -108,13 +110,34 @@ public long size() {
108
110
* an unknown size, {@link Shape#UNKNOWN_SIZE} is returned.
109
111
*
110
112
* @param i the index of the dimension to get the size for. If this Shape has a known number of
111
- * dimensions, it must be < {@link Shape#numDimensions()}. The index may be negative, in which
112
- * case the position is counted from the end of the shape. E.g.: {@code size(-1)} returns the
113
- * size of the last dimension, {@code size(-2)} the size of the second to last dimension etc.
113
+ * dimensions, it must be < {@link Shape#numDimensions()}. The index may be negative, in
114
+ * which case the position is counted from the end of the shape. E.g.: {@code size(-1)}
115
+ * returns the size of the last dimension, {@code size(-2)} the size of the second to last
116
+ * dimension etc.
114
117
* @return The size of the dimension with the given index if known, {@link Shape#UNKNOWN_SIZE}
115
118
* otherwise.
119
+ * @deprecated Renamed to {@link #get(int)}.
116
120
*/
117
- public long size (int i ) {
121
+ @ Deprecated
122
+ public long size (int i ){
123
+ return get (i );
124
+ }
125
+
126
+ /**
127
+ * The size of the dimension with the given index.
128
+ *
129
+ * <p>If {@link Shape#isUnknown()} is true or the size of the dimension with the given index has
130
+ * an unknown size, {@link Shape#UNKNOWN_SIZE} is returned.
131
+ *
132
+ * @param i the index of the dimension to get the size for. If this Shape has a known number of
133
+ * dimensions, it must be < {@link Shape#numDimensions()}. The index may be negative, in
134
+ * which case the position is counted from the end of the shape. E.g.: {@code size(-1)}
135
+ * returns the size of the last dimension, {@code size(-2)} the size of the second to last
136
+ * dimension etc.
137
+ * @return The size of the dimension with the given index if known, {@link Shape#UNKNOWN_SIZE}
138
+ * otherwise.
139
+ */
140
+ public long get (int i ) {
118
141
if (dimensionSizes == null ) {
119
142
return UNKNOWN_SIZE ;
120
143
} else if (i >= 0 ) {
@@ -177,6 +200,24 @@ public long[] asArray() {
177
200
}
178
201
}
179
202
203
+ /**
204
+ * Returns a defensive copy of the this Shape's axes. Changes to the returned list do not change
205
+ * this Shape's state. Returns null if {@link Shape#isUnknown()} is true.
206
+ */
207
+ public List <Long > toListOrNull () {
208
+ long [] array = asArray ();
209
+ if (array == null ) {
210
+ return null ;
211
+ }
212
+
213
+ List <Long > list = new ArrayList <>(array .length );
214
+ for (long l : array ) {
215
+ list .add (l );
216
+ }
217
+
218
+ return list ;
219
+ }
220
+
180
221
@ Override
181
222
public int hashCode () {
182
223
return dimensionSizes != null ? Arrays .hashCode (dimensionSizes ) : super .hashCode ();
@@ -186,6 +227,7 @@ public int hashCode() {
186
227
* Equals implementation for Shapes. Two Shapes are considered equal iff:
187
228
*
188
229
* <p>
230
+ *
189
231
* <ul>
190
232
* <li>the number of dimensions is defined and equal for both
191
233
* <li>the size of each dimension is defined and equal for both
@@ -236,7 +278,8 @@ public Shape head() {
236
278
* Returns an n-dimensional Shape with the dimensions matching the first n dimensions of this
237
279
* shape
238
280
*
239
- * @param n the number of leading dimensions to get, must be <= than {@link Shape#numDimensions()}
281
+ * @param n the number of leading dimensions to get, must be <= than {@link
282
+ * Shape#numDimensions()}
240
283
* @return an n-dimensional Shape with the first n dimensions matching the first n dimensions of
241
284
* this Shape
242
285
*/
@@ -252,7 +295,9 @@ public Shape take(int n) {
252
295
253
296
/** Returns a new Shape, with this Shape's first dimension removed. */
254
297
public Shape tail () {
255
- if (dimensionSizes .length < 2 ) return Shape .of ();
298
+ if (dimensionSizes .length < 2 ) {
299
+ return Shape .of ();
300
+ }
256
301
return Shape .of (Arrays .copyOfRange (dimensionSizes , 1 , dimensionSizes .length ));
257
302
}
258
303
@@ -276,15 +321,21 @@ public Shape takeLast(int n) {
276
321
}
277
322
278
323
/**
279
- * Return a {@code end - begin} dimensional shape with dimensions matching this Shape from {@code begin} to {@code end}.
324
+ * Return a {@code end - begin} dimensional shape with dimensions matching this Shape from {@code
325
+ * begin} to {@code end}.
326
+ *
280
327
* @param begin Where to start the sub-shape.
281
328
* @param end Where to end the sub-shape, exclusive.
282
329
* @return the sub-shape bounded by begin and end.
283
330
*/
284
- public Shape subShape (int begin , int end ){
331
+ public Shape subShape (int begin , int end ) {
285
332
if (end > numDimensions ()) {
286
333
throw new ArrayIndexOutOfBoundsException (
287
- "End index " + end + " out of bounds: shape only has " + numDimensions () + " dimensions." );
334
+ "End index "
335
+ + end
336
+ + " out of bounds: shape only has "
337
+ + numDimensions ()
338
+ + " dimensions." );
288
339
}
289
340
if (begin < 0 ) {
290
341
throw new ArrayIndexOutOfBoundsException (
@@ -423,7 +474,7 @@ public boolean isCompatibleWith(Shape shape) {
423
474
return false ;
424
475
}
425
476
for (int i = 0 ; i < numDimensions (); i ++) {
426
- if (!isCompatible (size (i ), shape .size (i ))) {
477
+ if (!isCompatible (get (i ), shape .get (i ))) {
427
478
return false ;
428
479
}
429
480
}
0 commit comments