27
27
import java .util .Arrays ;
28
28
import java .util .List ;
29
29
30
- /**
31
- * Various methods for processing with Shapes and Operands
32
- */
30
+ /** Various methods for processing with Shapes and Operands */
33
31
public class ShapeUtils {
34
32
35
33
/**
@@ -82,12 +80,12 @@ public static <T extends TNumber> long[] getLongArray(Scope scope, Operand<T> di
82
80
Operand <TInt64 > ldims = (Operand <TInt64 >) dims ;
83
81
ldims .asOutput ().data ().scalars ().forEach (s -> result .add (s .getLong ()));
84
82
} else if (dType .equals (TUint8 .DTYPE )) {
85
- @ SuppressWarnings ("unchecked" )
86
- Operand <TUint8 > udims = (Operand <TUint8 >) dims ;
83
+ @ SuppressWarnings ("unchecked" )
84
+ Operand <TUint8 > udims = (Operand <TUint8 >) dims ;
87
85
udims .asOutput ().data ().scalars ().forEach (s -> result .add (s .getObject ().longValue ()));
88
- } else { // shouldn't happen
89
- throw new IllegalArgumentException ("the data type must be an integer type" );
90
- }
86
+ } else { // shouldn't happen
87
+ throw new IllegalArgumentException ("the data type must be an integer type" );
88
+ }
91
89
92
90
} else {
93
91
try (Session session = new Session ((Graph ) scope .env ())) {
@@ -96,17 +94,17 @@ public static <T extends TNumber> long[] getLongArray(Scope scope, Operand<T> di
96
94
session .runner ().fetch (dims ).run ().get (0 ).expect (TInt32 .DTYPE )) {
97
95
tensorResult .data ().scalars ().forEach (s -> result .add ((long ) s .getInt ()));
98
96
}
99
- } else if (dType .equals (TInt64 .DTYPE )){
97
+ } else if (dType .equals (TInt64 .DTYPE )) {
100
98
try (Tensor <TInt64 > tensorResult =
101
99
session .runner ().fetch (dims ).run ().get (0 ).expect (TInt64 .DTYPE )) {
102
100
tensorResult .data ().scalars ().forEach (s -> result .add (s .getLong ()));
103
101
}
104
- }else if (dType .equals (TUint8 .DTYPE )){
102
+ } else if (dType .equals (TUint8 .DTYPE )) {
105
103
try (Tensor <TUint8 > tensorResult =
106
- session .runner ().fetch (dims ).run ().get (0 ).expect (TUint8 .DTYPE )) {
104
+ session .runner ().fetch (dims ).run ().get (0 ).expect (TUint8 .DTYPE )) {
107
105
tensorResult .data ().scalars ().forEach (s -> result .add (s .getObject ().longValue ()));
108
106
}
109
- } else { // shouldn't happen
107
+ } else { // shouldn't happen
110
108
throw new IllegalArgumentException ("the data type must be an integer type" );
111
109
}
112
110
}
@@ -156,6 +154,14 @@ public static <T extends TNumber> Shape getShape(Tensor<T> tensor) {
156
154
* Shape.unknown()</code> is compatible with <code>Shape(4, 4)</code>, but <code>Shape(32, 784)
157
155
* </code> is not compatible with <code>Shape(4, 4)</code>.
158
156
*
157
+ * <p>Compatibility is not the same as broadcasting. Compatible shapes must have the same number
158
+ * of dimensions and for each dimension pair, one dimension has to equal the other dimensions or
159
+ * at least one of the dimensions in the pair has to be UNKNOWN_SIZE.
160
+ *
161
+ * <p>Broadcasting allows different dimensions, but paired dimensions have to either be equal, or
162
+ * one dimension must be 1. If one shape has less dimensions than another shape, the smaller shape
163
+ * is "stretched" with dimensions of 1. See {@link org.tensorflow.op.Ops#broadcastTo}.
164
+ *
159
165
* @param a The first shape
160
166
* @param b The second shape
161
167
* @return true, if the two shapes are compatible.
@@ -175,7 +181,7 @@ public static boolean isCompatibleWith(Shape a, Shape b) {
175
181
}
176
182
177
183
/**
178
- * Determines if a shape is an unknown shape as provided in <cade >Shape.unknown()</code>.
184
+ * Determines if a shape is an unknown shape as provided in <code >Shape.unknown()</code>.
179
185
*
180
186
* @param a the shape to test.
181
187
* @return true if the shape is an unknown shape
@@ -186,7 +192,8 @@ public static boolean isUnknownShape(Shape a) {
186
192
187
193
/**
188
194
* Reduces the shape by eliminating trailing Dimensions.
189
- * <p>The last dimension, specified by axis, will be a product of all remaining dimensions</p>
195
+ *
196
+ * <p>The last dimension, specified by axis, will be a product of all remaining dimensions
190
197
*
191
198
* @param shape the shape to squeeze
192
199
* @param axis the axis to squeeze
@@ -198,14 +205,12 @@ public static Shape reduce(Shape shape, int axis) {
198
205
axis = shape .numDimensions () + axis ;
199
206
}
200
207
long [] array = shape .asArray ();
201
- if (array == null )
202
- return Shape .unknown ();
208
+ if (array == null ) return Shape .unknown ();
203
209
long [] newArray = new long [axis ];
204
210
System .arraycopy (array , 0 , newArray , 0 , axis - 1 );
205
211
long prod = array [axis - 1 ];
206
212
for (int i = axis ; i < array .length ; i ++) {
207
- if (array [i ] != Shape .UNKNOWN_SIZE )
208
- prod *= array [i ];
213
+ if (array [i ] != Shape .UNKNOWN_SIZE ) prod *= array [i ];
209
214
}
210
215
newArray [axis - 1 ] = prod ;
211
216
return Shape .of (newArray );
0 commit comments