Skip to content

Commit 921e077

Browse files
authored
Merge pull request #1 from rnett/feature/tensor-datatostring
Data type tests, wrap strings in quotes
2 parents 98d320d + 6f8dedc commit 921e077

File tree

2 files changed

+50
-18
lines changed

2 files changed

+50
-18
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensors.java

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
import java.util.Iterator;
55
import java.util.List;
66
import java.util.StringJoiner;
7-
import org.tensorflow.Tensor;
87
import org.tensorflow.ndarray.NdArray;
98
import org.tensorflow.ndarray.Shape;
9+
import org.tensorflow.proto.framework.DataType;
1010

1111
/**
1212
* Tensor helper methods.
@@ -58,42 +58,57 @@ public static String toString(Tensor tensor, Integer maxWidth) {
5858
}
5959
return String.valueOf(iterator.next().getObject());
6060
}
61-
return toString(iterator, shape, 0, maxWidth);
61+
return toString(iterator, tensor.dataType(), shape, 0, maxWidth);
6262
}
6363

6464
/**
65-
* @param iterator an iterator over the scalars
66-
* @param shape the shape of the tensor
67-
* @param maxWidth the maximum width of the output in characters ({@code null} if unlimited).
68-
* This limit may surpassed if the first or last element are too long.
65+
* Convert an element of a tensor to string, in a way that may depend on the data type.
66+
*
67+
* @param dtype the tensor's data type
68+
* @param data the element
69+
* @return the element's string representation
70+
*/
71+
private static String elementToString(DataType dtype, Object data) {
72+
if (dtype == DataType.DT_STRING) {
73+
return '"' + data.toString() + '"';
74+
} else {
75+
return data.toString();
76+
}
77+
}
78+
79+
/**
80+
* @param iterator an iterator over the scalars
81+
* @param shape the shape of the tensor
82+
* @param maxWidth the maximum width of the output in characters ({@code null} if unlimited). This limit may surpassed
83+
* if the first or last element are too long.
6984
* @param dimension the current dimension being processed
7085
* @return the String representation of the tensor data at {@code dimension}
7186
*/
72-
private static String toString(Iterator<? extends NdArray<?>> iterator, Shape shape,
87+
private static String toString(Iterator<? extends NdArray<?>> iterator, DataType dtype, Shape shape,
7388
int dimension, Integer maxWidth) {
7489
if (dimension < shape.numDimensions() - 1) {
7590
StringJoiner joiner = new StringJoiner("\n", indent(dimension) + "[\n",
7691
"\n" + indent(dimension) + "]");
7792
for (long i = 0, size = shape.size(dimension); i < size; ++i) {
78-
String element = toString(iterator, shape, dimension + 1, maxWidth);
93+
String element = toString(iterator, dtype, shape, dimension + 1, maxWidth);
7994
joiner.add(element);
8095
}
8196
return joiner.toString();
8297
}
8398
if (maxWidth == null) {
8499
StringJoiner joiner = new StringJoiner(", ", indent(dimension) + "[", "]");
85100
for (long i = 0, size = shape.size(dimension); i < size; ++i) {
86-
String element = iterator.next().getObject().toString();
87-
joiner.add(element);
101+
Object element = iterator.next().getObject();
102+
joiner.add(elementToString(dtype, element));
88103
}
89104
return joiner.toString();
90105
}
91106
List<Integer> lengths = new ArrayList<>();
92107
StringJoiner joiner = new StringJoiner(", ", indent(dimension) + "[", "]");
93108
int lengthBefore = "]".length();
94109
for (long i = 0, size = shape.size(dimension); i < size; ++i) {
95-
String element = iterator.next().getObject().toString();
96-
joiner.add(element);
110+
Object element = iterator.next().getObject();
111+
joiner.add(elementToString(dtype, element));
97112
int addedLength = joiner.length() - lengthBefore;
98113
lengths.add(addedLength);
99114
lengthBefore += addedLength;

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorTest.java

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -544,26 +544,26 @@ public void gracefullyFailCreationFromNullArrayForStringTensor() {
544544

545545
@Test
546546
public void dataToString() {
547-
try (TInt32 t = TInt32.tensorOf(StdArrays.ndCopyOf(new int[]{3, 0, 1}))) {
547+
try (TInt32 t = TInt32.vectorOf(3, 0, 1)) {
548548
String actual = t.dataToString();
549549
assertEquals("[3, 0, 1]", actual);
550550
}
551-
try (TInt32 t = TInt32.tensorOf(StdArrays.ndCopyOf(new int[]{3, 0, 1}))) {
551+
try (TInt32 t = TInt32.vectorOf(3, 0, 1)) {
552552
String actual = t.dataToString(Tensor.maxWidth(5));
553553
// Cannot remove first or last element
554554
assertEquals("[3, 0, 1]", actual);
555555
}
556-
try (TInt32 t = TInt32.tensorOf(StdArrays.ndCopyOf(new int[]{3, 0, 1}))) {
556+
try (TInt32 t = TInt32.vectorOf(3, 0, 1)) {
557557
String actual = t.dataToString(Tensor.maxWidth(6));
558558
// Do not insert ellipses if it increases the length
559559
assertEquals("[3, 0, 1]", actual);
560560
}
561-
try (TInt32 t = TInt32.tensorOf(StdArrays.ndCopyOf(new int[]{3, 0, 1, 2}))) {
561+
try (TInt32 t = TInt32.vectorOf(3, 0, 1, 2)) {
562562
String actual = t.dataToString(Tensor.maxWidth(11));
563563
// Limit may be surpassed if first or last element are too long
564564
assertEquals("[3, ..., 2]", actual);
565565
}
566-
try (TInt32 t = TInt32.tensorOf(StdArrays.ndCopyOf(new int[]{3, 0, 1, 2}))) {
566+
try (TInt32 t = TInt32.vectorOf(3, 0, 1, 2)) {
567567
String actual = t.dataToString(Tensor.maxWidth(12));
568568
assertEquals("[3, 0, 1, 2]", actual);
569569
}
@@ -574,10 +574,27 @@ public void dataToString() {
574574
+ " [3, 2, 1]\n"
575575
+ "]", actual);
576576
}
577-
try (RawTensor t = TInt32.tensorOf(StdArrays.ndCopyOf(new int[]{3, 0, 1, 2})).asRawTensor()) {
577+
try (RawTensor t = TInt32.vectorOf(3, 0, 1, 2).asRawTensor()) {
578578
String actual = t.dataToString(Tensor.maxWidth(12));
579579
assertEquals("[3, 0, 1, 2]", actual);
580580
}
581+
// different data types
582+
try (RawTensor t = TFloat32.vectorOf(3.0101f, 0, 1.5f, 2).asRawTensor()) {
583+
String actual = t.dataToString();
584+
assertEquals("[3.0101, 0.0, 1.5, 2.0]", actual);
585+
}
586+
try (RawTensor t = TFloat64.vectorOf(3.0101, 0, 1.5, 2).asRawTensor()) {
587+
String actual = t.dataToString();
588+
assertEquals("[3.0101, 0.0, 1.5, 2.0]", actual);
589+
}
590+
try (RawTensor t = TBool.vectorOf(true, true, false, true).asRawTensor()) {
591+
String actual = t.dataToString();
592+
assertEquals("[true, true, false, true]", actual);
593+
}
594+
try (RawTensor t = TString.vectorOf("a", "b", "c").asRawTensor()) {
595+
String actual = t.dataToString();
596+
assertEquals("[\"a\", \"b\", \"c\"]", actual);
597+
}
581598
}
582599

583600
// Workaround for cross compiliation

0 commit comments

Comments
 (0)