Skip to content

Commit 3261888

Browse files
committed
Update JavaDoc to highlight difference between compatible shapes and broadcastable shapes.
1 parent 33530bb commit 3261888

File tree

1 file changed

+23
-18
lines changed
  • tensorflow-framework/src/main/java/org/tensorflow/framework/utils

1 file changed

+23
-18
lines changed

tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java

+23-18
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,7 @@
2727
import java.util.Arrays;
2828
import java.util.List;
2929

30-
/**
31-
* Various methods for processing with Shapes and Operands
32-
*/
30+
/** Utility methods for working with Shapes and Operands. */
3331
public class ShapeUtils {
3432

3533
/**
@@ -82,12 +80,12 @@ public static <T extends TNumber> long[] getLongArray(Scope scope, Operand<T> di
8280
Operand<TInt64> ldims = (Operand<TInt64>) dims;
8381
ldims.asOutput().data().scalars().forEach(s -> result.add(s.getLong()));
8482
} else if (dType.equals(TUint8.DTYPE)) {
85-
@SuppressWarnings("unchecked")
86-
Operand<TUint8> udims = (Operand<TUint8>) dims;
83+
@SuppressWarnings("unchecked")
84+
Operand<TUint8> udims = (Operand<TUint8>) dims;
8785
udims.asOutput().data().scalars().forEach(s -> result.add(s.getObject().longValue()));
88-
}else { // shouldn't happen
89-
throw new IllegalArgumentException("the data type must be an integer type");
90-
}
86+
} else { // shouldn't happen
87+
throw new IllegalArgumentException("the data type must be an integer type");
88+
}
9189

9290
} else {
9391
try (Session session = new Session((Graph) scope.env())) {
@@ -96,17 +94,17 @@ public static <T extends TNumber> long[] getLongArray(Scope scope, Operand<T> di
9694
session.runner().fetch(dims).run().get(0).expect(TInt32.DTYPE)) {
9795
tensorResult.data().scalars().forEach(s -> result.add((long) s.getInt()));
9896
}
99-
} else if (dType.equals(TInt64.DTYPE)){
97+
} else if (dType.equals(TInt64.DTYPE)) {
10098
try (Tensor<TInt64> tensorResult =
10199
session.runner().fetch(dims).run().get(0).expect(TInt64.DTYPE)) {
102100
tensorResult.data().scalars().forEach(s -> result.add(s.getLong()));
103101
}
104-
}else if (dType.equals(TUint8.DTYPE)){
102+
} else if (dType.equals(TUint8.DTYPE)) {
105103
try (Tensor<TUint8> tensorResult =
106-
session.runner().fetch(dims).run().get(0).expect(TUint8.DTYPE)) {
104+
session.runner().fetch(dims).run().get(0).expect(TUint8.DTYPE)) {
107105
tensorResult.data().scalars().forEach(s -> result.add(s.getObject().longValue()));
108106
}
109-
}else { // shouldn't happen
107+
} else { // shouldn't happen
110108
throw new IllegalArgumentException("the data type must be an integer type");
111109
}
112110
}
@@ -156,6 +154,14 @@ public static <T extends TNumber> Shape getShape(Tensor<T> tensor) {
156154
* Shape.unknown()</code> is compatible with <code>Shape(4, 4)</code>, but <code>Shape(32, 784)
157155
* </code> is not compatible with <code>Shape(4, 4)</code>.
158156
*
157+
* <p>Compatibility is not the same as broadcasting. Compatible shapes must have the same number
158+
* of dimensions and, for each dimension pair, one dimension has to equal the other dimension or
159+
* at least one of the dimensions in the pair has to be UNKNOWN_SIZE.
160+
*
161+
* <p>Broadcasting allows different dimensions, but paired dimensions have to either be equal, or
162+
* one dimension must be 1. If one shape has fewer dimensions than the other, the smaller shape
163+
* is "stretched" with dimensions of 1. See {@link org.tensorflow.op.Ops#broadcastTo}.
164+
*
159165
* @param a The first shape
160166
* @param b The second shape
161167
* @return true, if the two shapes are compatible.
@@ -175,7 +181,7 @@ public static boolean isCompatibleWith(Shape a, Shape b) {
175181
}
176182

177183
/**
178-
* Determines if a shape is an unknown shape as provided in <cade>Shape.unknown()</code>.
184+
* Determines if a shape is an unknown shape as provided in <code>Shape.unknown()</code>.
179185
*
180186
* @param a the shape to test.
181187
* @return true if the shape is an unknown shape
@@ -186,7 +192,8 @@ public static boolean isUnknownShape(Shape a) {
186192

187193
/**
188194
* Reduces the shape by eliminating trailing Dimensions.
189-
* <p>The last dimension, specified by axis, will be a product of all remaining dimensions</p>
195+
*
196+
* <p>The last dimension, specified by axis, will be the product of all remaining dimensions.
190197
*
191198
* @param shape the shape to squeeze
192199
* @param axis the axis to squeeze
@@ -198,14 +205,12 @@ public static Shape reduce(Shape shape, int axis) {
198205
axis = shape.numDimensions() + axis;
199206
}
200207
long[] array = shape.asArray();
201-
if(array == null)
202-
return Shape.unknown();
208+
if (array == null) return Shape.unknown();
203209
long[] newArray = new long[axis];
204210
System.arraycopy(array, 0, newArray, 0, axis - 1);
205211
long prod = array[axis - 1];
206212
for (int i = axis; i < array.length; i++) {
207-
if(array[i] != Shape.UNKNOWN_SIZE)
208-
prod *= array[i];
213+
if (array[i] != Shape.UNKNOWN_SIZE) prod *= array[i];
209214
}
210215
newArray[axis - 1] = prod;
211216
return Shape.of(newArray);

0 commit comments

Comments
 (0)